Introduction
Understanding Normalization
In the world of neural networks, reducing covariate shift is important. Covariate shift means that the data you trained your model on does not match the data the model sees later when you actually use it. A related problem we run into inside neural networks is internal covariate shift: as activations pass through the layers of the network and the weights are tweaked during training, the distribution of each layer's inputs continuously changes.
This internal covariate shift is not tied to the diversity of the data you are training on; it is tied to the architecture of the network and how activations behave as they move through the layers. One of the most common ways to reduce internal covariate shift is normalization, which keeps the distribution of activations more consistent as we move through the network. Let's look at a brief example of normalization from my Decision Transformer project code:
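The snippet itself isn't reproduced inline here, so below is a minimal sketch of the pattern; the class name `NormSwitchboard` and its arguments are illustrative rather than the exact project code.

```python
import torch
import torch.nn as nn

class NormSwitchboard(nn.Module):
    """Pick a normalization strategy by name: 'layer', 'batch', or 'none'."""

    def __init__(self, norm_type: str, dim: int):
        super().__init__()
        if norm_type == "layer":
            # Normalize each sample across its feature dimension.
            self.norm = nn.LayerNorm(dim)
        elif norm_type == "batch":
            # Normalize each feature across the batch. Note: BatchNorm1d expects the
            # feature dimension second, so (batch, seq, dim) inputs need a transpose.
            self.norm = nn.BatchNorm1d(dim)
        elif norm_type == "none":
            # Skip normalization altogether.
            self.norm = nn.Identity()
        else:
            raise ValueError(f"Unknown norm_type: {norm_type}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x)
```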
Think of this as a normalization switchboard. It allows us to do either layer normalization, batch normalization, or to skip normalization altogether. After an option is selected, it applies the chosen normalization to the input x (a tensor) during the forward pass. It is important to note that there are a bunch of flavors of normalization beyond these. For example, root mean square normalization (RMSNorm) is a simpler flavor of layer normalization. Instance normalization is another one, where we essentially do batch normalization but per instance. The reason this "switchboard" even existed in my team's code was that we were using the transformer for robotic simulation.
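To give a flavor of the RMS variant, here is a minimal sketch: it skips the mean subtraction entirely and just divides each token's features by their root mean square, with a single learnable gain.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Root mean square normalization: rescale by 1/RMS(x), no mean subtraction."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learnable gain

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS over the feature dimension, computed per token.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)
```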
In modern transformer architectures like GPT, LLaMA, etc., layer normalization is the most popular choice and is used quite religiously. For each data point (like a token in a sequence), it computes the mean and variance across the features (e.g., the hidden dimensions) rather than across a batch of examples. Then it scales and shifts the result using learnable parameters.
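Concretely, for one token's feature vector $x$ of dimension $d$, with learnable scale $\gamma$, shift $\beta$, and a small $\epsilon$ for numerical stability:

$$
\mu = \frac{1}{d}\sum_{i=1}^{d} x_i, \qquad
\sigma^2 = \frac{1}{d}\sum_{i=1}^{d} (x_i - \mu)^2, \qquad
\mathrm{LN}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
$$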
Think about it like this: if I have one batch of sentences that are 10 tokens long and another batch that are 50 tokens long, I immediately run into problems using something like batch normalization with transformers... because now the statistics have to be computed across examples whose dimensions are misaligned. Layer normalization works consistently with no alignment issues, since each token only looks at its own features.
For example, let's say I have the sentence "Omeed likes to eat potatoes", and I am looking at the "Omeed" token. Hypothetically, let's apply the BERT embedding model to yield a 768-dimensional vector representing my name. When I normalize with layer normalization, it will only look at the embedding vector for my name and calculate the mean and variance of that to normalize - completely independent from the remaining tokens in my batch ("likes", "to", "eat", "potatoes", etc.). By normalizing per token, we essentially preserve the individual characteristics of each word, which is important for natural language.
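A quick sanity check makes this concrete; the random tensors below just stand in for real BERT embeddings, and the 768 size mirrors the example above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Pretend these are embeddings for the 5 tokens in "Omeed likes to eat potatoes":
# shape (sequence_length=5, hidden_dim=768).
tokens = torch.randn(5, 768)

ln = nn.LayerNorm(768)

out_full = ln(tokens)         # normalize the whole sequence at once
out_single = ln(tokens[0:1])  # normalize only the "Omeed" token

# The "Omeed" row comes out identical either way: its mean and variance are
# computed from its own 768 features, never from the other tokens.
print(torch.allclose(out_full[0], out_single[0], atol=1e-6))  # True
```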
To illustrate with a concrete example, let’s examine Andrej Karpathy’s implementation of LayerNorm. What’s brilliant about his approach is that he wrote it in Python, unlike the underlying PyTorch implementation, which relies on low-level languages for optimal performance. This choice makes it much easier to understand:
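His exact snippet isn't embedded here, but the forward pass boils down to something like the following paraphrase (same spirit, not his verbatim code):

```python
import torch

def layernorm_forward(x, weight, bias, eps=1e-5):
    """LayerNorm forward over the last dimension, written in plain PyTorch.

    x: activations of shape (B, T, C); weight, bias: learnable tensors of shape (C,).
    """
    mean = x.mean(dim=-1, keepdim=True)                  # center of each feature vector
    var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)   # its spread
    rstd = (var + eps) ** -0.5                           # reciprocal standard deviation
    norm = (x - mean) * rstd                             # zero mean, unit variance
    return norm * weight + bias                          # learnable scale and shift
```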
As you can see... we center the data, scale it, and adjust it with learnable weights and biases. Think of it as going from a messy pile of data to a neat, organized, polished format. To learn more about this snippet, I highly recommend visiting this link.
Now, all this normalization talk does come at a cost: there is computational overhead during both training and inference. With batch normalization, the mean and variance have to be computed over the whole batch at every training step (only at inference can it fall back on pre-computed running statistics), so the compute cost during training is quite significant. With layer normalization, the statistics are computed on the fly even at inference, so depending on the input size there can be significant computational overhead there too. But to a lot of scientists and researchers, this is a small price to pay for training stability and model performance.
Dynamic Tanh (DyT)
So let's start with the replacement. In a typical transformer architecture, you apply layer normalization before the multi-head attention, before the feed-forward network, and before the output projection. Using DyT is as simple as going into those parts of your architecture and replacing them. Now what is actually happening here? Well, you still have your input tensor x. It is multiplied by a learnable scalar α, which can stretch or compress your input values. Think of this as an automatic knob that is learned during training and optimized via backpropagation: it is adjusted automatically during gradient descent, a simple way to minimize error by taking tiny steps toward an optimum. We care about this because the paper shows that layer normalization in transformers produces a tanh-like mapping.
Meaning that if we optimize for a good tanh output, then because tanh is S-shaped, bounded between -1 and 1, and smooth, it will naturally replicate layer normalization's key effects: squashing extreme values while leaving moderate ones roughly linear. But a scalar like α matters because the inputs to this function might have very large or very small values during training, which can leave the curve too flat or push it into saturation - so we are ideally looking for a sweet spot.
$$
\mathrm{DyT}(x) = \gamma \odot \tanh(\alpha x) + \beta
$$
The weight γ starts as an all-ones vector, the bias β starts as an all-zeros vector, and the α value (our scalar) starts at 0.5, the default value given in the paper. In the forward pass, it is all simply put together, matching the formula above.
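Putting the pieces together, a DyT layer is only a few lines; this is a sketch based on the formula and initialization described above, not the paper's reference code verbatim.

```python
import torch
import torch.nn as nn

class DyT(nn.Module):
    """Dynamic Tanh: a drop-in replacement for a normalization layer."""

    def __init__(self, dim: int, init_alpha: float = 0.5):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(init_alpha))  # learnable scalar "knob"
        self.gamma = nn.Parameter(torch.ones(dim))            # weight, all ones at init
        self.beta = nn.Parameter(torch.zeros(dim))            # bias, all zeros at init

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Squash the scaled input with tanh, then apply the learnable scale and shift.
        return self.gamma * torch.tanh(self.alpha * x) + self.beta
```

Swapping it in is then just a matter of replacing the `nn.LayerNorm(dim)` modules before the attention block, the feed-forward network, and the output projection with `DyT(dim)`.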