Date: 2025-02-22

Authors: (1) David Raposo, Google DeepMind and with equal contribution; (2) Sam Ritter, Google DeepMind; (3) Blake Richards, Google DeepMind and McGill University & Mila; (4) Timothy Lillicrap, Google DeepMind; (5) Peter Conway Humphreys, Google DeepMind; (6) Adam Santoro, Google DeepMind and with equal contribution.

5. Discussion

Mixture-of-Depths transformers empirically demonstrate that one can improve on isoFLOP-optimal baseline performance with models that use fewer FLOPs per forward pass. This means that—for a given training FLOP budget—we can train models that are both faster and better performing than their baseline counterparts. Previously, to train models that are both faster and as- or better-performing than isoFLOP-optimal models, one would have to use surplus compute to overtrain smaller models (notably, this overtraining technique is still possible with MoD transformers, and speed gains should compound).





While MoD transformers require fewer FLOPs per forward pass, one cannot forgo FLOPs indiscriminately. Rather, it is crucial to use learned routing decisions, much like in Mixture-of-Experts transformers, to determine whether a token should participate in self-attention and the subsequent MLP (requiring FLOPs), or not (saving FLOPs). We can then use any saved FLOPs by, for example, making the model bigger or training it for longer. Our results show that FLOPs may indeed be used inefficiently in vanilla transformer models, and that there may be more efficient ways to expend them.
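
As a rough illustration of the routing decision described above, the following minimal sketch applies a block only to the top-scoring tokens in a sequence and lets the rest pass through unchanged on the residual path. It is not the implementation used in this work: tokens are scalars rather than vectors, and the names `mod_route`, `scores`, `capacity`, and `block` are hypothetical. Scaling the block output by the router score is an assumption made here so that the router would sit on the gradient path.

```python
from typing import Callable, List


def mod_route(
    tokens: List[float],
    scores: List[float],
    capacity: int,
    block: Callable[[float], float],
) -> List[float]:
    """Apply `block` only to the `capacity` highest-scoring tokens.

    Hypothetical sketch: each token is a single float, and `block` stands in
    for self-attention plus the MLP.
    """
    # Indices of the `capacity` largest router scores (ties broken by position).
    order = sorted(range(len(tokens)), key=lambda i: (-scores[i], i))
    selected = set(order[:capacity])

    out = []
    for i, x in enumerate(tokens):
        if i in selected:
            # Routed token: residual plus the block output, scaled by its score.
            out.append(x + scores[i] * block(x))
        else:
            # Skipped token: residual path only, so the block costs nothing here.
            out.append(x)
    return out
```

With `capacity=2` on a four-token sequence, only the two highest-scoring tokens incur the cost of `block`; the other two are copied through.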





Learned routing mechanisms are sometimes non-causal; that is, information about the future is used to determine a given token’s routing decision. This is generally true for top-k routing mechanisms, which are useful because they forgo the need for auxiliary balancing losses. However, top-k routing mechanisms present difficulties in post-training autoregressive sampling, where it is impossible to use information about future token identities to determine routing decisions. In this work we show that one can successfully use a top-k routing scheme during training, but not require it during later autoregressive sampling. Either a simple auxiliary classifier, or an auxiliary loss on the router, is sufficient to learn the top-k routing decisions, such that the model can mimic them during autoregressive sampling with minimal to no performance degradation.
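
One way to picture the auxiliary-loss option is sketched below, assuming the router outputs are treated as logits and the non-causal top-k selections as binary targets for a cross-entropy loss. The function names (`topk_targets`, `router_aux_loss`, `route_causally`) and the zero-logit threshold are hypothetical choices for illustration, not details taken from this section.

```python
import math
from typing import List


def topk_targets(scores: List[float], k: int) -> List[int]:
    """Binary targets: 1 if a token is among the top-k scores, else 0.

    This step is non-causal, because it compares scores across the sequence.
    """
    order = sorted(range(len(scores)), key=lambda i: (-scores[i], i))
    top = set(order[:k])
    return [1 if i in top else 0 for i in range(len(scores))]


def router_aux_loss(logits: List[float], targets: List[int]) -> float:
    """Mean binary cross-entropy between router logits and top-k targets.

    Uses the numerically stable form max(z, 0) - z*t + log(1 + exp(-|z|)).
    """
    total = 0.0
    for z, t in zip(logits, targets):
        total += max(z, 0.0) - z * t + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


def route_causally(logit: float) -> bool:
    """Sampling-time decision that uses only the current token's logit.

    Assumed rule: route when sigmoid(logit) > 0.5, i.e. when logit > 0.
    """
    return logit > 0.0
```

During training the targets come from the full sequence; at sampling time `route_causally` needs only the current token's own logit, which is what makes autoregressive decoding possible.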





Intuitively, a token might learn to route around blocks because the prediction being made at that step is easier, and hence, does not require as much compute. However, this strategy is undoubtedly not all that the network learns. If a token does not participate in self-attention at a certain block, then later tokens will also not be able to attend to it. Thus, whether tokens decide to route or not impacts both the current step’s prediction and future predictions via causal self-attention, and how the network balances these effects is guided by their influence on the overall language modeling objective.
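
The interaction between routing and causal self-attention can be made concrete with a small sketch. Assuming a per-token boolean routing mask for a single block (the name `attention_pairs` and the mask representation are hypothetical), it lists the (query, key) pairs that remain available once skipped tokens are removed.

```python
from typing import List, Tuple


def attention_pairs(routed: List[bool]) -> List[Tuple[int, int]]:
    """List the (query, key) index pairs that exist in one block.

    A pair exists only if both tokens were routed through the block and the
    key does not come after the query (causal attention).
    """
    pairs = []
    for q, q_routed in enumerate(routed):
        if not q_routed:
            continue  # A skipped token issues no query in this block.
        for k in range(q + 1):
            if routed[k]:
                pairs.append((q, k))  # Skipped tokens never appear as keys.
    return pairs
```

For example, `attention_pairs([True, False, True])` yields `[(0, 0), (2, 0), (2, 2)]`: because token 1 skipped the block, token 2 cannot attend to it there.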





This insight opens the door to MoD variants that decouple the routing for queries, keys and values. For example, perhaps a token would prefer to be among the queries, but not the keys, for a given self-attention computation. One can imagine extending this idea even further into the domain of "long-term memory": perhaps there are tokens that would be extremely valuable as keys, regardless of whether it is useful for them to also be among the queries at the step of their occurrence. Learned routing could be a powerful mechanism for deciding which tokens these might be, perhaps funnelling them into a long-term memory buffer that is available during future self-attention. One advantage of such an approach to long-term memory is that tokens decide once, at the moment of "memory encoding", whether they should be retrieved in the future. This is more computationally efficient than performing a full content-based lookup across an entire memory buffer for each step in the future, and could be one step towards drastically increasing the context-length available for making a prediction.





Unlike MoE transformers that route between effectively the same computation (usually MLPs), MoD transformers demonstrate the value of routing among different types of computations. In this work the types were either the conventional transformer block, or a null computation (functionally equivalent to multiplying by zero). However, one can imagine extending this idea further by routing between even more types of computation. For example, perhaps some tokens are routed to "memory lookup" functions, and others are routed to "tool use" functions. In general, the routing machinery we deployed provides a knob for adjusting the types of computations available to the network and their relative cost (in total FLOPs); if one wants to introduce an expensive computation, then this can be offset by setting its capacity to some small amount, and hence, by routing only a small number of tokens to it.
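
The capacity knob can be illustrated with simple arithmetic. The sketch below assumes each computation type serves a fixed fraction of tokens at a fixed per-token cost; the function name and the cost units are hypothetical, and the linear model deliberately ignores the fact that self-attention cost grows faster than linearly with the number of participating tokens.

```python
from typing import List, Tuple


def expected_flops_per_token(paths: List[Tuple[float, float]]) -> float:
    """Average per-token cost across routing paths.

    Each path is (capacity_fraction, flops_per_routed_token). The fractions
    must sum to 1, since every token takes exactly one path in this sketch.
    """
    total_fraction = sum(fraction for fraction, _ in paths)
    if abs(total_fraction - 1.0) > 1e-9:
        raise ValueError("capacity fractions must sum to 1")
    return sum(fraction * cost for fraction, cost in paths)
```

Under these assumptions, a computation that costs 8 units per token but has a capacity of 12.5% of tokens, paired with a free null path for the rest, averages `expected_flops_per_token([(0.125, 8.0), (0.875, 0.0)])`, that is, 1 unit per token.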





Altogether, MoD transformers are another tool one can use to tune a model’s compute per forward pass (and hence inference time). The machinery used to implement MoD is also generic, and opens the doors to many extensions and integration with other techniques, such as MoE.

