The Dynamic TanH (DyT)

Introduction

A really interesting new paper by Jiachen Zhu, Xinlei Chen, Kaiming He, Yann LeCun, and Zhuang Liu (https://arxiv.org/abs/2503.10622) introduces a novel approach that replaces the normalization layers in Transformers with a simple operation called Dynamic Tanh (DyT), challenging the need for normalization altogether.

Understanding Normalization

In the world of neural networks, reducing covariate shift is important. Covariate shift basically means that the data you trained your model on no longer matches the data the model sees later when you use it. A related problem we commonly run into inside neural networks is internal covariate shift: as data passes through the layers of the network and the weights are tweaked, the distribution of the activations keeps changing.

This internal covariate shift is not tied to the diversity of the data you are training on; it is tied to the architecture of the network and how the activations behave as they move through the layers. One of the most common ways to reduce internal covariate shift is normalization, which keeps the data's distribution more consistent as it moves through the network. Let's look at a brief example of normalization from my Decision Transformer project code:
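The project's snippet isn't reproduced here, but a minimal sketch of that kind of switchboard might look like this (build_norm and its option names are illustrative stand-ins rather than the exact project code):

```python
import torch.nn as nn

def build_norm(norm_type: str, dim: int) -> nn.Module:
    """Return the chosen normalization layer, or an identity module to skip it."""
    if norm_type == "layer":
        return nn.LayerNorm(dim)      # per-token statistics over the feature dimension
    if norm_type == "batch":
        return nn.BatchNorm1d(dim)    # statistics computed across the batch
    if norm_type == "none":
        return nn.Identity()          # pass x through untouched
    raise ValueError(f"unknown norm_type: {norm_type}")
```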

Think of this as a normalization switchboard. It allows us to do either layer normalization or batch normalization, or to skip normalization altogether. After an option is selected, the chosen normalization is applied to the input value x (a tensor) during the forward pass. It is important to note that there are plenty of flavors of normalization beyond these. For example, root mean square normalization (RMSNorm) is a simpler flavor of layer normalization, and instance normalization is another, where we basically do batch normalization per instance. The reason this "switchboard" even existed in my team's code was that we were using the transformer for robotic simulation.

In modern transformer architectures like GPT, LLaMA, etc., layer normalization is the most popular choice and is used quite religiously. For each data point (like a token in a sequence), it computes the mean and variance across the features (e.g., the hidden dimensions) rather than across a batch of examples. It then scales and shifts the result using learnable parameters.

Think about it like this: if I have a batch of sentences that are 10 tokens long and another batch of sentences that are 50 tokens long, I immediately run into problems using something like batch normalization with transformers, because the dimensions of my batches no longer line up. Layer normalization works consistently with no alignment issues, since it never looks across the batch at all.

For example, let's say I have the sentence "Omeed likes to eat potatoes", and I am looking at the "Omeed" token. Hypothetically, let's apply the BERT embedding model to yield a 768-dimensional vector representing my name. When I normalize with layer normalization, it only looks at the embedding vector for my name and calculates the mean and variance from that vector alone, completely independent of the remaining tokens in my sequence ("likes", "to", "eat", "potatoes"). By normalizing per token, we essentially preserve the individual characteristics of each word, which is important in natural language.
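To make that concrete, here is a tiny sketch (the 768-dimensional vector is random stand-in data, not an actual BERT embedding) of what layer normalization computes for a single token:

```python
import torch

token_embedding = torch.randn(768)          # stand-in for the 768-dim "Omeed" vector

mean = token_embedding.mean()               # one mean over the 768 features
var = token_embedding.var(unbiased=False)   # one variance over the same features
normalized = (token_embedding - mean) / torch.sqrt(var + 1e-5)

# The other tokens ("likes", "to", "eat", "potatoes") never enter this computation.
print(normalized.mean().item(), normalized.std(unbiased=False).item())  # ~0 and ~1
```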

To illustrate with a concrete example, let’s examine Andrej Karpathy’s implementation of LayerNorm. What’s brilliant about his approach is that he wrote it in Python, unlike the underlying PyTorch implementation, which relies on low-level languages for optimal performance. This choice makes it much easier to understand:
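His exact code isn't reproduced here, but a simplified sketch in the same spirit, with the math spelled out in plain tensor operations, looks like this:

```python
import torch

def layernorm_forward(x, weight, bias, eps=1e-5):
    # x has shape (batch, time, channels); statistics are taken over the last dimension
    mean = x.mean(dim=-1, keepdim=True)                 # center of each token's features
    var = x.var(dim=-1, unbiased=False, keepdim=True)   # spread of each token's features
    rstd = (var + eps).rsqrt()                          # reciprocal standard deviation
    norm = (x - mean) * rstd                            # zero mean, unit variance
    return norm * weight + bias                         # learnable scale and shift
```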

As you can see... we center the data, scale it, and adjust it with learnable weights and biases. Think of it as turning a messy pile of data into a neat, organized, polished format. To learn more about this snippet, I highly recommend visiting this link.

Now, all this normalization talk does come at a cost: there is computational overhead at both training time and inference time. With batch normalization, the mean and variance have to be computed over every batch during training, so the compute cost at training time is quite significant (at inference, precomputed running statistics are used instead). With layer normalization, the statistics are computed on the fly from each input, so depending on the input size there can be significant overhead at inference time as well. But to a lot of scientists and researchers, this is a small price to pay for training stability and model performance.

The DyT

Now imagine you could go into your code, replace all the layer normalizations with DyT, and significantly cut latency and memory usage. Let's start with the math; it will be easier to explain after you've seen the equation representing Dynamic Tanh:

DyT(x) = γ * tanh(αx) + β

So let's start with the replacement. In a typical transformer architecture, you apply layer normalization before the multi-head attention, before the feed-forward network, and before the final output projection. Using DyT is as simple as going into those parts of your architecture and swapping them out. Now, what is actually happening here? You still have your input tensor x. It is multiplied by a learnable scalar α, which can stretch or compress your input values. Think of α as an automatic knob that is learned during training and optimized via backpropagation; it is adjusted automatically during gradient descent, a simple way to minimize error by taking tiny steps towards an optimum. We care about this because the paper shows that layer normalization in transformers produces a tanh-like, S-shaped input-output mapping.
Meaning that if we optimize for a good tanh output, because it is S-shaped, bounded between -1 and 1, and smooth, it will naturally replicate layer normalization's key effects. But a scalar like α matters because the inputs to this function might have very large or very small values during training, leaving the tanh either too flat or too saturated, so we are ideally looking for a sweet spot.
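A quick way to build intuition for why α matters is to push the same inputs through tanh at different scales (the numbers below are just toy values):

```python
import torch

x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])   # a mix of small and large inputs
for alpha in (0.1, 0.5, 2.0):
    print(alpha, torch.tanh(alpha * x).tolist())
# A small alpha keeps tanh in its gentle, near-linear region;
# a large alpha pushes the big inputs into the saturated +/-1 tails.
```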

Without overcomplicating it, γ and β are learnable vectors, much like the scale and shift in layer normalization, and they are simply optimized during training via backpropagation. Putting it all together, this is the final code snippet.
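The snippet below is a PyTorch sketch that follows the pseudocode in the paper rather than a verbatim copy of the official implementation:

```python
import torch
import torch.nn as nn

class DyT(nn.Module):
    def __init__(self, num_features: int, alpha_init: float = 0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * alpha_init)    # learnable scalar α
        self.weight = nn.Parameter(torch.ones(num_features))     # γ, starts as all ones
        self.bias = nn.Parameter(torch.zeros(num_features))      # β, starts as all zeros

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # DyT(x) = γ * tanh(αx) + β
        return torch.tanh(self.alpha * x) * self.weight + self.bias
```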

The weight γ starts as an all-ones vector, the bias β starts as an all-zeros vector, and the alpha value (our scalar) starts at 0.5, the default value suggested in the paper. In the forward pass it is all simply put together, and it matches the formula we discussed above.
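And the swap itself really is a drop-in replacement. In a hypothetical pre-norm transformer block, the change amounts to something like this:

```python
import torch.nn as nn

hidden_dim = 768  # hypothetical model width

# Before: standard normalization layers inside a transformer block
attn_norm = nn.LayerNorm(hidden_dim)
ffn_norm = nn.LayerNorm(hidden_dim)

# After: the same block, with DyT (defined above) dropped in
attn_norm = DyT(hidden_dim)
ffn_norm = DyT(hidden_dim)
```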

For years, we believed normalization was indispensable for stable neural network training, but this groundbreaking discovery of Dynamic Tanh has me captivated by the innovative ways researchers will now design efficient, high-performing models.