Transformer Architecture
Transformer architectures, originally introduced by Vaswani et al. (2017) in "Attention Is All You Need", are the foundation for many of the most advanced NLP models in use today. This neural network design is quite intricate, and its complexity pays off as the sequence-related tasks it is applied to become more complex.
- RNNs had issues with vanishing gradients, which made it hard to capture long-range dependencies.
- LSTMs and GRUs, with their gated mechanisms, significantly improved the management of information flow, addressing some of the fundamental shortcomings of traditional RNNs.
Despite these advances, these architectures are still inherently sequential, processing inputs one unit at a time. This sequential processing acts as a bottleneck, with each unit's output dependent on the calculation of all preceding units.
Contrast this with a ConvNet, which can take in a lot of pixels (or a lot of words) at once and compute representations for them in parallel.
The major innovation of the transformer architecture is in integrating attention-based representations with a processing style similar to CNNs, handling multiple sequence elements concurrently. This parallel processing capability means that an entire sentence can be taken in at once and computed in parallel, rather than being digested word by word, from left to right.
Two main concepts in transformers are self-attention and multi-head attention, which together facilitate the generation of contextually rich word representations.
- Self-attention: This mechanism computes contextual representations for each word in a sentence, taking into account the entire sequence. For a five-word sentence, you obtain five such representations, one per word.
- Multi-head attention: Essentially, this involves a for loop over the self-attention process, each time capturing different aspects of the sequence's information. The result is a set of rich, nuanced representations that enhance the model's performance across various NLP tasks, including machine translation.
Self-Attention
The power of the Transformer architecture lies in its ability to learn the relevance and context of all of the words in the prompt with respect to each other. The context is learnt not just from each word's immediate neighbours, but from every other word in the prompt.
For example, in a sentence like "The teacher taught the student ...", the model learns how the word teacher is associated with every other word - The, taught, the, student, etc.
The model applies "attention weights" to the relationships so that it learns the relevance of each word to every other word. These "attention weights" are learned during training. This is called self-attention. The term originates from the fact that each word in the prompt attends to other words in the same prompt, including itself. This mechanism is what enables Transformers to capture relationships and dependencies between words regardless of their distance from each other in the prompt.
We've seen attention used in sequential neural networks such as RNNs. To use attention with a simultaneous input processing style like that of ConvNets, you need to calculate self-attention, where you create an attention-based representation for each of the words in the input sentence. Take this sentence for example; for each word, we aim to compute an attention-based vector representation, denoted by $A^{<t>}$.
| Jane | visite | l'Afrique | en | septembre |
|---|---|---|---|---|
| $x^{<1>}$ | $x^{<2>}$ | $x^{<3>}$ | $x^{<4>}$ | $x^{<5>}$ |
Our goal is to compute an attention-based representation $A^{<t>}$ for each word, so we'll end up with five of these, $A^{<1>}$ through $A^{<5>}$, since our sentence has five words.
We've used word embeddings in a similar way. One way to represent l'Afrique would be to just look up the word embedding for it. But depending on the context, this representation could vary: are we thinking of l'Afrique as a site of historical interests, a holiday destination, or as the world's second largest continent?
Depending on how you're thinking of l'Afrique, you may choose to represent it differently, and that's what this representation will do. It will look at the surrounding words to try to figure out what's actually going on and how we're talking about Africa in this sentence, and find the most appropriate representation for this.
The computation process for self-attention in transformers isn't too different from the attention mechanism applied to the sequential computation in RNNs, except we'll parallelize it across all words in a sentence. When we did sequential RNN attention, the attention weights were computed as:

$$\alpha^{<t, t'>} = \frac{\exp(e^{<t, t'>})}{\sum_{t'=1}^{T_x} \exp(e^{<t, t'>})}$$
With the self-attention mechanism, the attention equation becomes:

$$A(q, K, V) = \sum_{i} \frac{\exp(q \cdot k^{<i>})}{\sum_{j} \exp(q \cdot k^{<j>})} \, v^{<i>}$$
You can see the equations have some similarity, particularly in the softmax denominator. You can think of the exponent terms as being akin to attention values. The main difference is the transformer's use of queries ($q$), keys ($k$), and values ($v$) for each word. These vectors are the key inputs to computing the attention value for each word. First, we associate each of the words with three vectors:

$$q^{<t>} = W^Q x^{<t>}, \qquad k^{<t>} = W^K x^{<t>}, \qquad v^{<t>} = W^V x^{<t>}$$

These vectors are derived from the input through learnable weight matrices $W^Q$, $W^K$, and $W^V$, providing the necessary components to calculate attention values for each word.
For intuition, what are these query, key and value vectors supposed to do? They were named using a loose analogy to the concept of a database, where you can have queries and also key-value pairs.
To illustrate, the query vector $q^{<3>}$ is a question that you get to ask about l'Afrique. $q^{<3>}$ may represent a question like, "what's happening there?". What we're going to do is compute the inner product between Query 3 ($q^{<3>}$) and Key 1 ($k^{<1>}$), and this will tell us how good of an answer Word 1 (Jane) is to the question of what's happening in Africa.
Then we compute the inner product between $q^{<3>}$ and $k^{<2>}$, which is intended to tell us how good of an answer visite is to the question of what's happening in Africa, and so on for the other words in the sequence. The goal of this operation is to pull up the information that's most needed to help us compute the most useful attention-based representation for this word, $A^{<3>}$.
Again, just for intuition building, if $k^{<1>}$ represents that the first word, Jane, is a person, and $k^{<2>}$ represents that the second word, visite, is an action, then you may find that $q^{<3>} \cdot k^{<2>}$ has the largest value. This suggests that visite gives you the most relevant context for what's happening in Africa, i.e. it's viewed as a destination for a visit.
Overall, the inner products between $q^{<3>}$ and all keys ($k^{<1>}, \dots, k^{<5>}$) across the sequence gauge how relevant each word is to the context of l'Afrique. This relevance is then used to weight and sum up the corresponding value vectors ($v^{<1>}, \dots, v^{<5>}$), leading to a context-aware representation of l'Afrique as $A^{<3>}$.
For the five attention-based representations for these words, $A^{<1>}, \dots, A^{<5>}$, we can write the equation as:

$$A^{<t>} = A(q^{<t>}, K, V) = \sum_{i} \frac{\exp(q^{<t>} \cdot k^{<i>})}{\sum_{j} \exp(q^{<t>} \cdot k^{<j>})} \, v^{<i>}$$
Step by step, what we will do is take the products between $q^{<3>}$ and the other $k^{<t>}$'s and compute a softmax over them. Then, finally, we take these softmax values, multiply them with the corresponding $v^{<t>}$'s, and take the element-wise sum of these to get $A^{<3>}$, or more formally, $A^{<3>} = \sum_{t} \text{softmax}(q^{<3>} \cdot k^{<t>}) \, v^{<t>}$.
This shows that a word does not have a fixed representation and can actually adapt to how it is used in the sentence.
This needs to be done for every word. You can summarize all of these computations for all the words in the sequence by writing

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

where the $Q$, $K$ and $V$ matrices stack all of the individual query, key and value vectors as rows.
The formula given above represents a vectorized form of the individual attention computations. The normalization term, $\sqrt{d_k}$, scales the dot products, preventing them from growing too large and exploding. Another name for this type of attention is scaled dot-product attention.
Ultimately, this method offers a nuanced and adaptive representation for each word, far surpassing the limitations of fixed word embeddings by incorporating contextual information from the words on both sides of each word in the sequence.
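As a toy numeric illustration (made-up numbers, and only two of the five words considered for brevity), suppose the dot products for $q^{<3>}$ were:

$$q^{<3>} \cdot k^{<1>} = 2, \qquad q^{<3>} \cdot k^{<2>} = 0 \;\;\Rightarrow\;\; \text{softmax} = \left[\frac{e^{2}}{e^{2} + e^{0}},\ \frac{e^{0}}{e^{2} + e^{0}}\right] \approx [0.88,\ 0.12]$$

$$A^{<3>} \approx 0.88\, v^{<1>} + 0.12\, v^{<2>}$$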
Summary
Overall:
- We feed $X \in \mathbb{R}^{L \times d}$ to the network, where $L$ is the context window length and $d$ is the dimension of the embedding.
- We project $X$ into three matrices $Q$, $K$ and $V$:
- $Q = XW^Q$, where $W^Q$ is a matrix of dimension $d \times d_k$.
- $K = XW^K$, where $W^K$ is a matrix of dimension $d \times d_k$.
- $V = XW^V$, where $W^V$ is a matrix of dimension $d \times d_v$.
- We compute the attention using the following vectorized equation:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

$Q$ and $K$ need to have the same dimension $d_k$ since we take a dot product between $Q$ and $K$. The output dimension depends on the dimension $d_v$ of $V$.
For a full picture of the dimensions:
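As an illustrative sketch (hypothetical sizes, single attention head), here is how those dimensions line up in PyTorch:

```python
import torch
import torch.nn.functional as F

L, d = 5, 8          # sequence length and embedding dimension (hypothetical)
dk, dv = 4, 4        # query/key and value dimensions (hypothetical)

X = torch.randn(L, d)                                    # input embeddings, L x d
WQ, WK, WV = torch.randn(d, dk), torch.randn(d, dk), torch.randn(d, dv)

Q, K, V = X @ WQ, X @ WK, X @ WV                         # L x dk, L x dk, L x dv
scores = Q @ K.T / dk ** 0.5                             # L x L scaled dot products
A = F.softmax(scores, dim=-1) @ V                        # L x dv attention-based representations
print(A.shape)                                           # torch.Size([5, 4])
```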
Multi-Head Attention
This process involves computing self-attention multiple times, with each computation known as a "head". Intuitively, we have multiple questions we'd like to find the best answer for.
The number of attention heads included in the attention layer varies from model to model, but numbers in the range of 12-100 are common. The intuition here is that each self-attention head will learn a different aspect of language, such as the relationships between the people entities in our sentence, the activity of the sentence, or even other properties such as whether the words rhyme.
While the implementation differs for efficiency reasons, there are essentially $h$ sets of $W^Q$, $W^K$ and $W^V$ matrices, one for each question we'd like to answer. $h$ is called the number of heads.
Self-attention is computed with each of these $h$ sets of matrices to obtain an $L \times d_h$ matrix per head. The $h$ and $d_h$ dimensions are concatenated to get an $L \times hd_h$ matrix. This is finally multiplied with an $hd_h \times d_O$ matrix $W^O$ to obtain the final output of dimension $L \times d_O$.
How Multi-Head Attention Works
The vectors $q$, $k$, and $v$ are initially derived for each input term by multiplying it with the weight matrices $W^Q$, $W^K$, and $W^V$. In the context of multi-head attention, you take that same set of vectors as inputs, but they are further processed with additional per-head weight matrices $W_i^Q$, $W_i^K$, and $W_i^V$ to generate new sets of query, key, and value vectors for each head.
So with this computation, the word visite gives the best answer to "what's happening?", meaning the inner product between the key for visite and the query for l'Afrique has the highest value. You do the same for the remaining words, so you end up with five vectors to represent the five words in the sequence.
Iterative Process of Multi-Head Attention
Rather than performing this operation just once, the transformer model repeats it multiple times, producing multiple "heads" of attention. Each head can be thought of as focusing on different features, or asking different questions, within the sentence, thus building a comprehensive representation.
For example, the model might use eight different sets of weight matrices $W_i^Q$, $W_i^K$, and $W_i^V$ to perform this computation eight times, resulting in eight different attention outputs that are concatenated and projected:

$$\text{MultiHead}(Q, K, V) = \text{concat}(\text{head}_1, \dots, \text{head}_h)\, W^O$$

where each individual head is defined as:

$$\text{head}_i = \text{Attention}(W_i^Q Q,\ W_i^K K,\ W_i^V V)$$

and $h$ denotes the number of heads used in the process.
Computation
The idea is to stack all the weight matrices required for computing the $Q$, $K$ and $V$ matrices for each head into one single matrix. This ensures that we can obtain the $Q$, $K$ and $V$ matrices using a single matrix multiplication instead of multiple multiplications.
Consider that each $Q$, $K$ and $V$ matrix will have dimension $d_h$ (say). Suppose we have $h$ heads. Thus, we need $3h$ weight matrices (one each for $Q$, $K$ and $V$, for each head). In other words, we need a $d \times 3hd_h$ matrix, where $3hd_h$ represents the stacked matrix dimension. Let this matrix be $W^{QKV}$.
We then multiply $X$ ($L \times d$) with $W^{QKV}$ to obtain an $L \times 3hd_h$ matrix.
We then reshape this to obtain an $h \times L \times 3d_h$ tensor. Finally, we can take three chunks from the last dimension to obtain the $Q$, $K$ and $V$ matrices, each of dimension $h \times L \times d_h$.
These three matrices are passed to the self-attention block to obtain an $h \times L \times d_h$ output, which is concatenated along the head dimension to obtain an $L \times hd_h$ output. Finally, this is multiplied with an $hd_h \times d_O$ matrix ($W^O$) to obtain the final output of dimension $L \times d_O$.
In the actual implementation, we pass in three inputs:
- $d$ - Input embedding size.
- $h$ - Number of heads.
- $d_O$ - Expected output dimension of multi-headed attention.
From this, $d_h$ is computed as $d_h = d_O / h$, since $hd_h = d_O$. In other words, $d_O$ should be chosen such that $d_O \bmod h = 0$. The rest is the same as above.
In PyTorch, this is implemented as follows:
```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    # Applying mask for masked multi-headed attention
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention


def expand_mask(mask):
    assert mask.ndim >= 2, "Mask must be at least 2-dimensional with seq_length x seq_length"
    if mask.ndim == 3:
        mask = mask.unsqueeze(1)
    while mask.ndim < 4:
        mask = mask.unsqueeze(0)
    return mask


class MultiheadAttention(nn.Module):
    def __init__(self, d, dO, h):
        super().__init__()
        assert dO % h == 0, "Embedding dimension must be 0 modulo number of heads."
        self.dO = dO
        self.h = h
        # Compute dh
        self.dh = dO // h
        # Create the stacked weight matrix using a linear layer
        # It will receive a d-dim input
        # And produce a 3.h.dh = 3.dO dimensional output
        self.qkv_proj = nn.Linear(d, 3 * dO)
        # Create WO using a linear layer
        # It will receive an h.dh = dO-dim input and
        # Produce a dO-dim output
        self.o_proj = nn.Linear(dO, dO)
        self._reset_parameters()

    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)

    def forward(self, x, mask=None, return_attention=False):
        batch_size, seq_length, _ = x.size()
        if mask is not None:
            mask = expand_mask(mask)
        # [Batch, L, 3*h.dh] = [Batch, L, 3.dO]
        qkv = self.qkv_proj(x)
        # Reshape to [Batch, L, h, 3*dh]
        qkv = qkv.reshape(batch_size, seq_length, self.h, 3 * self.dh)
        # Permute as [Batch, h, L, 3*dh]
        qkv = qkv.permute(0, 2, 1, 3)
        # Take out [Batch, h, L, dh] chunks to obtain Q, K and V
        q, k, v = qkv.chunk(3, dim=-1)
        # Apply self-attention - [Batch, h, L, dh]
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        # Permute to [Batch, L, h, dh]
        values = values.permute(0, 2, 1, 3)
        # Concatenate to [Batch, L, h.dh] = [Batch, L, dO]
        values = values.reshape(batch_size, seq_length, self.dO)
        # Multiply with WO for final output
        o = self.o_proj(values)
        return (o, attention) if return_attention else o
```
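A quick usage sketch of the module above (hypothetical sizes), reusing the imports and class defined earlier:

```python
# Hypothetical sizes: batch of 2, sequence length 5, embedding size 16, 4 heads
mha = MultiheadAttention(d=16, dO=16, h=4)
x = torch.randn(2, 5, 16)

# Optional causal mask: position t may only attend to positions <= t
mask = torch.tril(torch.ones(2, 5, 5))

out, attn = mha(x, mask=mask, return_attention=True)
print(out.shape)   # torch.Size([2, 5, 16])
print(attn.shape)  # torch.Size([2, 4, 5, 5]) - one L x L attention map per head
```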
Parallel Computation Advantage
A key advantage of this design is that each head operates independently, allowing several attention layers to run in parallel. This parallelism improves efficiency and speeds up training and inference compared to a sequential for-loop approach. Once all heads have been computed, their outputs are concatenated and then transformed by multiplying with the weight matrix $W^O$ to produce the final multi-head attention output.
This approach enables the transformer model to capture a more nuanced understanding of the input sequence, making it highly effective for complex tasks in NLP.
Transformer Network
This is the simplified overall structure of the encoder (left) and decoder (right) of the Transformer model, and you stack this layer $N$ times. In one layer of the Transformer, there are three multi-head attention blocks: one in the encoder and two in the decoder.
To address sequence translation tasks, the model incorporates tokens to signify the start (<SOS>) and end (<EOS>) of sentences, which are crucial for defining sequence boundaries.
| <SOS> | Jane | visite | l'Afrique | en | septembre | <EOS> |
|---|---|---|---|---|---|---|
The first step in the transformer is that these embeddings get fed into an encoder block, which has a multi-head attention layer. This stage utilizes the matrices $Q$, $K$, and $V$, derived from the embeddings and the corresponding weight matrices $W^Q$, $W^K$ and $W^V$.
This layer then produces a matrix that is passed into a feed-forward neural network, which helps determine what interesting features there are in the sentence. In the transformer paper this encoder block is repeated $N$ times, typically with $N = 6$.
After $N$ passes through the encoder block, we then feed the encoder output into a decoder block. The decoder block's job is to output the English translation. The decoder initiates with the <SOS> token.
The <SOS> token gets fed into the first multi-head attention block. Just this one <SOS> token is used to compute $Q$, $K$ and $V$ for this first multi-headed attention block in the decoder.
The first multi-head attention output within the decoder generates the $Q$ matrix for the subsequent attention block, while $K$ and $V$ come straight from the encoder's output. This structured design ensures that the decoding process leverages the context from both the already-translated sequence and the encoder's representation of the source sentence.
Why is it structured this way? One piece of intuition is that the input at the bottom of the decoder is what you have translated of the sentence so far. It then pulls context from $K$ and $V$, which are derived from the French version of the sentence, to decide what the next word in the sequence to generate should be.
To finish the description of the decoder block, the multi-head attention block outputs values which are fed to a feed-forward neural network. This decoder block is also repeated $N$ times, maybe six times, where you take the output and feed it back to the input.
The job of this network is to predict the next word in the sentence. So hopefully it will decide that the first word in the English translation is Jane, and what we do is then feed Jane to the input as well. The next query comes from <SOS> and Jane, and it asks what the most appropriate next word is, using the $K$ and $V$ context from the encoder.
But beyond these main ideas, there are a few extra bells and whistles, such as positional encoding, add & norm, and masked multi-head attention.
Positional Encoding
A unique aspect of the transformer model is its use of positional encoding, necessary due to the self-attention mechanism's insensitivity to word order. The position within the sentence can be extremely important to translation. The position is encoded using the following functions:

$$PE(pos, 2i) = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \qquad PE(pos, 2i + 1) = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)$$
where $pos$ is the numerical position of the word being encoded, $i$ indexes into the dimensions of the embedding, and $d$ is the embedding dimension.
Take the sentence:
Jane visite l'Afrique en Septembre
Consider the word Jane. Here, $pos = 1$. Assuming a 4-dimensional word embedding ($d = 4$) with dimension indices $0, 1, 2, 3$, the positional encoding of Jane would be:

$$p^{<1>} = \left[\sin\!\left(\tfrac{1}{10000^{0/4}}\right),\ \cos\!\left(\tfrac{1}{10000^{0/4}}\right),\ \sin\!\left(\tfrac{1}{10000^{2/4}}\right),\ \cos\!\left(\tfrac{1}{10000^{2/4}}\right)\right] \approx [0.84,\ 0.54,\ 0.01,\ 1.00]$$
This creates a vector with alternating sine and cosine waves, which have different frequencies. This vector is added to the original embedding for the word so that the original embedding also has information about the word's position.
Each word is paired with a positional embedding vector of the same dimensionality, capturing its unique position within the sentence.
A 4-dimensional positional vector like this exists for every word $x^{<t>}$. In this example, we create a positional embedding vector of the same dimension for all words and call it $p^{<t>}$. In the equation, $pos$ denotes the numerical position of the word; so for the word Jane, $pos$ is equal to 1. $i$ refers to the different dimensions of the encoding.
What positional encoding does with sine and cosine is create a unique position encoding vector for each word. So the vector that encodes the position of l'Afrique, the third word, will be a set of four distinct values, different from the four values used in the positional encoding of the first word, Jane.
Positional encodings are directly added to the input embeddings $x^{<t>}$, so that each of the word vectors is also influenced by where in the sentence the word appears.
The output of the encoding block contains contextual semantic embedding and positional encoding information. The output of the embedding layer is then of size $d$ (which in this case is 4) by the maximum length the sequence can take. The outputs of all these layers in the encoder and decoder are also of this shape.
In addition to adding these position encodings to the embeddings, you'd also pass them through the network with residual connections, similar to those you see in ResNet. Their purpose in this case is to pass along positional information through the entire architecture.
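A minimal sketch of the sinusoidal encoding described above (hypothetical sizes, positions indexed from 0):

```python
import math
import torch

def positional_encoding(max_len, d):
    pe = torch.zeros(max_len, d)
    pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)               # positions 0 .. max_len-1
    div = torch.exp(torch.arange(0, d, 2).float() * (-math.log(10000.0) / d))
    pe[:, 0::2] = torch.sin(pos * div)   # even dimensions use sine
    pe[:, 1::2] = torch.cos(pos * div)   # odd dimensions use cosine
    return pe

pe = positional_encoding(max_len=5, d=4)
print(pe[1])   # encoding for pos = 1 (Jane in the worked example): ~[0.8415, 0.5403, 0.0100, 1.0000]
```

These vectors would simply be added to the word embeddings before the first encoder or decoder layer.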
Residual Connections
The output matrix of the multi-headed attention is added to the original embedded input using a residual connection.
This requires the output dimension of the multi-headed attention layer to match the original dimension of the input. In other words, $d_O = d$, so that the output is $L \times d$.
This residual connection is important since:
- It helps with the depth of the model by allowing information to be passed across greater depths.
- Multi-headed attention does not have any information about the position of tokens in the input sequence. With the residual connection (and [[#Positional Encoding|positional encoding]]), it is possible to pass this information to the rest of the model instead of the information being lost after the first multi-headed attention pass. It gives the model a chance to distinguish which information came from which element of the input sequence.
Add & Norm
Similar to batch normalization, the Add & Norm layer facilitates passing positional information throughout the network. Layer normalization is preferred over batch normalization since the batch size is often small, and batch normalization tends to perform poorly with text because words tend to have high variance (due to rare words needing to be considered for a good distribution estimate).
This batch-norm-like layer (the Add & Norm layer) is repeated throughout the architecture. Layer normalization:
- Speeds up training.
- Provides a small regularization.
- Ensures features are in a similar magnitude among the elements in the input sequence.
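A sketch of the Add & Norm step using PyTorch's nn.LayerNorm (hypothetical sizes; the sub-layer output is a stand-in for multi-head attention):

```python
import torch
import torch.nn as nn

d = 16
norm = nn.LayerNorm(d)                 # normalizes over the feature dimension of each token

x = torch.randn(2, 5, d)               # input to the sub-layer (e.g. embeddings + positional encoding)
sublayer_out = torch.randn(2, 5, d)    # stand-in for the multi-head attention output

out = norm(x + sublayer_out)           # "Add" (residual connection), then "Norm" (layer normalization)
print(out.shape)                       # torch.Size([2, 5, 16])
```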
Feed-Forward Neural Network (FFNN)
The normalized output is fed to an FFNN. It is applied to each element in the input sequence separately and identically. The encoder uses a Linear → ReLU → Linear model. Usually, the inner dimension of the FFNN is 2-8 times larger than $d$, the size of the input embedding.
The FFNN adds complexity to the model and can be thought of as an extra step of "pre-processing" applied on the output of multi-headed attention.
There is also a residual connection between the output of the multi-headed attention and the output of the FFNN, with layer normalization.
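A position-wise FFNN sketch matching the description above (hypothetical $d = 16$ with a 4x inner dimension):

```python
import torch
import torch.nn as nn

d = 16
ffn = nn.Sequential(
    nn.Linear(d, 4 * d),   # expand to the inner dimension (2-8x d is typical)
    nn.ReLU(),
    nn.Linear(4 * d, d),   # project back to d so the residual connection lines up
)

x = torch.randn(2, 5, d)   # applied to each of the 5 positions independently
print(ffn(x).shape)        # torch.Size([2, 5, 16])
```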
Masked Multi-Head Attention
So previously we've looked at how the transformer performs prediction, one word at a time, but how does it train? Masked multi-head attention is crucial during training - this component ensures the model cannot prematurely access future words in the sequence, simulating the conditions of sequence prediction.
During the training process where you're using a data set of correct French to English translations to train your transformer, you'd need to mask these words.
Let's say the data set has the correct French to English translation, "Jane Visite l'Afrique en Septembre" and "Jane visits Africa in September". When training you have access to the entire correct English translation, the correct output, and the respective correct input. And because you have the full correct output you don't have to generate the words one at a time during training.
Instead, what masking does is block out the last part of the sentence to mimic what the network will need to do at test time/during prediction. In other words, all that masked multi-head attention does is repeatedly pretend that the network had perfectly translated, say, the first few words, and hide the remaining words to see whether, given a perfect first part of the translation, the neural network can predict the next word in the sequence accurately.
In detail, a part of the matrix that is obtained after computing $\frac{QK^T}{\sqrt{d_k}}$ is masked. In particular, a word at index $t$ should only attend to words from indices $1$ to $t$. Thus, all entries at indices $t+1$ to $L$ are set to $-\infty$ so that they become $0$ when the softmax is applied. Consider the sentence:
Je suis un étudiant
After computing $\frac{QK^T}{\sqrt{d_k}}$ and applying the softmax, the matrix might look like this (✓ marks a non-zero attention weight, $0$ a masked-out position):

| | Je | suis | un | étudiant |
|---|---|---|---|---|
| Je | ✓ | 0 | 0 | 0 |
| suis | ✓ | ✓ | 0 | 0 |
| un | ✓ | ✓ | ✓ | 0 |
| étudiant | ✓ | ✓ | ✓ | ✓ |
For suis, $t = 2$. Thus, the elements at indices $3$ and $4$ are $0$, since suis should attend only to Je and to suis itself. On the other hand, the word étudiant should attend to all the words since it's the last word. Thus, no element in that row is $0$.
Note that the output is shifted to the right during inference, where we do not have the entire output sequence and are actually predicting one word at a time. We start the decoder with the single token <SOS> as the input and then, as the decoder predicts the next word, we add this new word to the input. This is what's referred to as "shifted right" in the diagram.
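A sketch of building the causal mask described above (it needs an extra batch dimension, e.g. mask.unsqueeze(0), before being passed to the MultiheadAttention module from earlier):

```python
import torch

L = 4  # "Je suis un étudiant"
causal_mask = torch.tril(torch.ones(L, L))
print(causal_mask)
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
# Row t attends only to positions 0..t; the zeros are filled with a large
# negative value before the softmax, so they contribute (almost) nothing.
```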
Cross-Attention
Cross-attention is a generalization of self-attention. In self-attention, we are computing the attention matrix for a single sequence. The same sequence is used to compute the $Q$, $K$ and $V$ matrices.
In cross-attention, we have two different sequences $X_1$ and $X_2$. $X_1$ is used to compute the $Q$ matrix while $X_2$ is used to compute the $K$ and $V$ matrices. When $X_1 = X_2$, cross-attention reduces to self-attention.
In the decoder, cross-attention is used in the second multi-headed attention layer. The $Q$ matrix comes from the output of the first multi-headed attention layer, which in turn uses the input to the decoder as its input. The $K$ and $V$ matrices are computed from the output of the final encoder block.
In essence, the input to the decoder is acting as $X_1$ and the output of the final encoder block is acting as $X_2$.
Cross-attention can work with sequences of different lengths. When computing $Q$ from $X_1$, we get an $L_1 \times d_k$ matrix. When computing $K$ from $X_2$, we get an $L_2 \times d_k$ matrix. When we take the dot product of $Q$ and $K^T$, we get an $L_1 \times L_2$ matrix. Since $V$ is also computed from $X_2$, its dimension is $L_2 \times d_v$. Thus, the overall result of cross-attention will be $L_1 \times d_v$ after applying the softmax and multiplying with $V$. See the sketch below:
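A shape-level sketch of cross-attention between two sequences of different lengths (hypothetical $L_1 = 3$ decoder positions and $L_2 = 5$ encoder positions):

```python
import torch
import torch.nn.functional as F

L1, L2, d, dk, dv = 3, 5, 8, 4, 4          # hypothetical sizes
X1 = torch.randn(L1, d)                    # decoder-side sequence (provides Q)
X2 = torch.randn(L2, d)                    # encoder-side sequence (provides K and V)

WQ, WK, WV = torch.randn(d, dk), torch.randn(d, dk), torch.randn(d, dv)
Q = X1 @ WQ                                # L1 x dk
K, V = X2 @ WK, X2 @ WV                    # L2 x dk, L2 x dv

scores = Q @ K.T / dk ** 0.5               # L1 x L2
out = F.softmax(scores, dim=-1) @ V        # L1 x dv: one vector per decoder position
print(out.shape)                           # torch.Size([3, 4])
```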
This technique of cross-attention is also used in diffusion models. See High-Resolution Image Synthesis with Latent Diffusion Models.
Linear, Softmax & Prediction
There is a final linear layer followed by a softmax activation. This converts the output of the decoder to a probability distribution over all the words. In other words, if the dictionary has $n$ words, this layer has $n$ units, one for each word.
The next word can be predicted from this probability distribution by either taking the one with the maximum probability or using other techniques. These other techniques can affect how creative the model is. For example, a technique might lead to the model not choosing the most "obvious" word every time and going for slightly eccentric choices.
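A minimal sketch of this final projection (hypothetical vocabulary of 4 words and $d = 16$; real vocabularies have tens of thousands of entries):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d, vocab_size = 16, 4                 # hypothetical sizes
lm_head = nn.Linear(d, vocab_size)    # one output unit per word in the vocabulary

decoder_out = torch.randn(1, d)       # decoder output for the current position
probs = F.softmax(lm_head(decoder_out), dim=-1)
print(probs)                          # four probabilities that sum to 1
```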
Greedy Sampling
The technique of using the word with the maximum probability is called greedy sampling. This is the most commonly used technique for many models. Consider the following softmax output:
| Word | Probability |
|---|---|
| cake | |
| donut | |
| banana | |
| apple | |
The model would output the word cake since it has the highest probability.
Random Sampling
Another approach is called random-weighted sampling. It introduces some variability to the model's output. Instead of taking the word with the maximum probability, the probabilities are used as weights to sample one word at random. For example, consider the following softmax output:
| Word | Probability |
|---|---|
| cake | 0.20 |
| donut | |
| banana | 0.02 |
| apple | |
The word cake has a 20% chance of being selected while the word banana has a 2% chance of being selected. It might be that the model selects the word banana.
It is possible that this technique leads to the model becoming too creative, where it generates words or wanders into topics that do not make sense with respect to the prompt.
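A sketch contrasting the two strategies on a toy distribution (the donut and apple probabilities here are made up; only the cake and banana values come from the example above):

```python
import torch

vocab = ["cake", "donut", "banana", "apple"]
probs = torch.tensor([0.20, 0.10, 0.02, 0.01])   # partly hypothetical softmax outputs
                                                 # (weights need not sum to 1 for torch.multinomial)

# Greedy sampling: always pick the highest-probability word
greedy = vocab[torch.argmax(probs).item()]                        # "cake"

# Random-weighted sampling: use the probabilities as sampling weights
sampled = vocab[torch.multinomial(probs, num_samples=1).item()]   # usually "cake", occasionally others

print(greedy, sampled)
```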
Types of Configurations of Transformers
Encoder-only Models
They only have encoders. Without further changes, these models always produce an output of the same length as the input sequence. It is possible to modify them so that they can be used for tasks such as sentiment analysis. An example of such a model is BERT.
Encoder-Decoder Models
This is the model originally described in the Transformers paper and the one detailed here. The output sequence and the input sequence can be of different lengths. It is useful for sequence-to-sequence tasks such as machine translation. Examples of such models are BART and FLAN-T5.
Decoder-only Models
These are some of the most commonly used models. As they have scaled, they have gained the ability to generalize to pretty much any task. Examples of such models are GPT, BLOOM, LLaMA.
Resources:
- Vaswani et al. 2017, Attention Is All You Need.
- Multi-head attention mechanism: “queries”, “keys”, and “values,” over and over again.
- Understanding and Coding the Self-Attention Mechanism of Large Language Models From Scratch.
- Tutorial 6: Transformers and Multi-Head Attention.
- Lecture on Transformers from the Course.
- Lectures on Transformers from the Deep Learning Specialization's Sequence Models course on Coursera.
- Layer Normalization.