Transformer Networks

Transformer Architecture

Transformer architectures, originally introduced by Vaswani et al. (2017) in "Attention Is All You Need", are the foundation for many of the most advanced NLP algorithms in use today. The design is quite intricate, so the sections below build it up one component at a time, starting from the limitations of the sequential models that preceded it.

RNN, GRU and LSTM

  • RNNs suffer from vanishing gradients, which makes it hard to capture long-range dependencies.
  • LSTMs and GRUs, with their gating mechanisms, significantly improved the management of information flow, addressing some of the fundamental shortcomings of traditional RNNs.

Despite these advances, these architectures are still inherently sequential, processing inputs one unit at a time. This sequential processing acts as a bottleneck, with each unit's output dependent on the calculation of all preceding units.

Contrast this with a ConvNet, which can take in many pixels (or many words) at once and compute representations for them in parallel.

The major innovation of the transformer architecture is in integrating attention-based representations with a processing style similar to CNNs, handling multiple sequence elements concurrently. This parallel processing capability means that an entire sentence can be taken in at once and computed in parallel, rather than being digested word by word, from left to right.

[Figure: RNN & CNN]

Two main concepts in transformers are self-attention and multi-head attention, which together facilitate the generation of contextually rich word representations.

  1. Self-attention: This mechanism computes a contextual representation for each word in a sentence, taking the entire sequence into account. For a five-word sentence, you obtain five such representations $[A^{<1>}, \cdots, A^{<5>}]$.
  2. Multi-head attention: Essentially, this involves a for loop over the self-attention process, each time capturing different aspects of the sequence's information. The result is a set of rich, nuanced representations that enhance the model's performance across various NLP tasks, including machine translation.

Self-Attention

The power of the Transformer architecture lies in its ability to learn the relevance and context of all of the words in the prompt with each other. The context is learnt not just with their immediate neighbour, but with every other word.

[Figure: self-attention example sentence]

For example, in the above image, the model learns how the word teacher is associated with every other word - The, taught, the, student, etc.

The model applies "attention weights" to the relationships so that it learns the relevance of each word to every other word. These "attention weights" are learned during training. This is called self-attention. The term originates from the fact that each word in the prompt attends to other words in the same prompt, including itself. This mechanism is what enables Transformers to capture relationships and dependencies between words regardless of their distance from each other in the prompt.

We've seen attention used in sequential neural networks such as RNNs. To use attention with a simultaneous input processing style like ConvNets, you need to calculate self-attention, where you create an attention-based representation for each of the words in the input sentence. Take the sentence below as an example: for each word, we aim to compute an attention-based vector representation, denoted by $A(q, K, V)$.

| Jane | visite | l'Afrique | en | septembre |
|---|---|---|---|---|
| $x^{<1>}$ | $x^{<2>}$ | $x^{<3>}$ | $x^{<4>}$ | $x^{<5>}$ |
| $A^{<1>}$ | $A^{<2>}$ | $A^{<3>}$ | $A^{<4>}$ | $A^{<5>}$ |

Our goal is to compute an attention-based representation for each word. Since the sentence has five words, we'll end up with five of these: $A^{<1>} \cdots A^{<5>}$.

We've used word embeddings in a similar way. One way to represent l'Afrique would be to just look up its word embedding. But depending on the context, this representation could vary: are we thinking of l'Afrique as a site of historical interest, a holiday destination, or the world's second-largest continent?

Depending on how you're thinking of l'Afrique, you may choose to represent it differently, and that's what the representation $A^{<3>}$ will do. It will look at the surrounding words to try to figure out what's actually going on and how we're talking about Africa in this sentence, and find the most appropriate representation for it.

The computation process for self-attention in transformers isn't too different from the attention mechanism applied to sequential computation in RNNs, except that we parallelize it across all words in a sentence. Recall the attention weights used in sequential RNN attention:

$$\alpha^{<t,t'>} = \frac{\exp(e^{<t,t'>})}{\sum_{t''=1}^{T_x} \exp(e^{<t,t''>})}$$

With the self-attention mechanism, the attention equation becomes:

$$A(q, K, V) = \sum_{i} \frac{\exp(q \cdot k^{<i>})}{\sum_{j} \exp(q \cdot k^{<j>})} \, v^{<i>}$$

You can see the equations have some similarity, particularly in the softmax form with its normalizing denominator: the dot product $q \cdot k^{<i>}$ plays the role that the score $e^{<t,t'>}$ played in RNN attention. The main difference is that the transformer uses a query ($q$), a key ($k$), and a value ($v$) vector for each word. These vectors are the key inputs to computing the attention value for each word. First, we associate each of the words with three vectors:

$$\begin{align*} q^{<t>} &= W^Q \cdot x^{<t>} \\ k^{<t>} &= W^K \cdot x^{<t>} \\ v^{<t>} &= W^V \cdot x^{<t>} \end{align*}$$

These vectors are derived from the input through learnable weight matrices $W^Q$, $W^K$, and $W^V$, providing the necessary components to calculate attention values for each word.
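Here is a minimal sketch of these projections. It multiplies a made-up 3-dimensional embedding for l'Afrique by small hypothetical weight matrices `W_Q`, `W_K` and `W_V`. The sizes ($d = 3$, $d_k = d_v = 2$) and all numbers are assumptions chosen only to keep the arithmetic easy to follow. In a real model these matrices are learned.

```python
def matvec(W, x):
    """Multiply a matrix W (a list of rows) by a vector x."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]


# Hypothetical weights, each of shape d_k x d = 2 x 3 (not learned values).
W_Q = [[1.0, 0.0, 1.0],
       [0.0, 1.0, 0.0]]
W_K = [[0.5, 0.5, 0.0],
       [0.0, 0.0, 1.0]]
W_V = [[1.0, 1.0, 1.0],
       [0.0, 2.0, 0.0]]

x_3 = [1.0, 2.0, 3.0]  # made-up embedding x^<3> for l'Afrique

q_3 = matvec(W_Q, x_3)  # [4.0, 2.0]
k_3 = matvec(W_K, x_3)  # [1.5, 3.0]
v_3 = matvec(W_V, x_3)  # [6.0, 4.0]
```

The same three matrices are reused for every position $t$. Only $x^{<t>}$ changes.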

For intuition, what are these query, key and value vectors supposed to do? They were named using a loose analogy to databases, where you can have queries and also key-value pairs.

To illustrate, the query vector $q^{<3>}$ is a question that you get to ask about l'Afrique. $q^{<3>}$ may represent a question like "what's happening there?". We compute the inner product between query 3 ($q^{<3>}$) and key 1 ($k^{<1>}$), which tells us how good an answer word 1 ($x^{<1>}$, Jane) is to the question of what's happening in Africa.

Then we compute the inner product between $q^{<3>}$ and $k^{<2>}$, which is intended to tell us how good an answer visite is to the question of what's happening in Africa, and so on for the other words in the sequence. The goal of this operation is to pull up the information that is most needed to compute the most useful attention-based representation for this word, $A^{<3>}$.

Again, just for intuition building, suppose $k^{<1>}$ represents that this word is a person, because Jane is a person, and $k^{<2>}$ represents that the second word, visite, is an action. Then you may find that $q^{<3>} \cdot k^{<2>}$ has the largest value. This intuitive example suggests that visite gives you the most relevant context for what's happening in Africa, i.e., Africa is viewed as a destination for a visit.

Overall, the inner products between $q^{<3>}$ and all keys ($k$) across the sequence gauge how relevant each word is to the context of 'Africa'. These relevance scores are passed through a softmax and used to weight the corresponding value vectors ($v$). The weighted vectors are then summed to give a context-aware representation of $x^{<3>}$, namely $A^{<3>}$.
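The single-query computation can be sketched in plain Python as follows. The 2-dimensional keys and values for the five words are invented so that visite's key lines up with the query for l'Afrique. They do not come from any trained model. As in the formula above, no $\sqrt{d_k}$ scaling is applied here.

```python
import math


def attention_for_query(q, keys, values):
    """A(q, K, V): softmax over q . k<i>, then a weighted sum of the v<i>."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]  # q . k<i> for every word
    m = max(scores)                                             # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]                         # attention weights, summing to 1
    dim = len(values[0])
    out = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return out, weights


words = ["Jane", "visite", "l'Afrique", "en", "septembre"]
q_3 = [0.0, 1.0]                                                # made-up query for l'Afrique
keys = [[1.0, 0.0], [0.0, 3.0], [0.5, 0.5], [0.0, 0.0], [1.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [2.0, 2.0]]

A_3, weights = attention_for_query(q_3, keys, values)
print(words[weights.index(max(weights))])  # visite
```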

[Figure: Attention Mechanism Visualization]

For the five attention-based representations of these words, $A^{<1>} \cdots A^{<5>}$, we can write the equation as:

$$\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right) V$$

Step by step, we take the inner products between $q^{<3>}$ and every $k^{<t>}$ and compute a softmax over them. Finally, we multiply each softmax value by the corresponding $v^{<t>}$ and sum the results to get $A^{<3>}$, or more formally, $A(q^{<3>}, K, V)$.

💡

This shows that a word does not have a fixed representation and can actually adapt to how it is used in the sentence.

This needs to be done for every word. You can summarize all of these computations for all the words in the sequence by writing $\text{Attention}(Q, K, V)$. The rows of the $Q$, $K$ and $V$ matrices are these vectors:

$$Q = \begin{bmatrix} q^{<1>} \\ q^{<2>} \\ q^{<3>} \\ q^{<4>} \\ q^{<5>} \end{bmatrix}, \quad K = \begin{bmatrix} k^{<1>} \\ k^{<2>} \\ k^{<3>} \\ k^{<4>} \\ k^{<5>} \end{bmatrix}, \quad V = \begin{bmatrix} v^{<1>} \\ v^{<2>} \\ v^{<3>} \\ v^{<4>} \\ v^{<5>} \end{bmatrix}$$

The formula above is a vectorized form of the individual attention computations. The normalization term $\sqrt{d_k}$ scales down the dot products. Without it, their magnitude grows with the dimension $d_k$, which pushes the softmax toward a near one-hot output where the gradients become very small. Another name for this type of attention is scaled dot-product attention.
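A quick way to see why the $\sqrt{d_k}$ factor matters is to compare the softmax of raw and scaled dot products for random vectors. This sketch assumes that query and key entries are drawn independently with mean 0 and variance 1. In that case a raw dot product has variance $d_k$. The choice $d_k = 512$ and the random seed are arbitrary.

```python
import math
import random


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


random.seed(0)
d_k, seq_len = 512, 5
# Assumption: entries are i.i.d. with mean 0 and variance 1.
q = [random.gauss(0.0, 1.0) for _ in range(d_k)]
keys = [[random.gauss(0.0, 1.0) for _ in range(d_k)] for _ in range(seq_len)]

raw = [sum(a * b for a, b in zip(q, k)) for k in keys]  # std. dev. around sqrt(512), about 22.6
scaled = [s / math.sqrt(d_k) for s in raw]               # std. dev. around 1

print("largest weight, unscaled:", max(softmax(raw)))
print("largest weight, scaled:  ", max(softmax(scaled)))
```

With unscaled scores, the largest weight is typically very close to 1, so the softmax is nearly one-hot. Dividing every score by a constant greater than 1 can never increase the largest softmax weight, so the scaled weights are always at least as spread out.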

Ultimately, this method offers a nuanced and adaptive representation for each word, far surpassing the limitations of fixed word embeddings by incorporating contextual information from the words on both sides of each word in the sequence.

Summary

Overall:

  • We feed $X \in R^{L \times d}$ to the network, where $L$ is the context window length and $d$ is the dimension of the embedding.
  • We project $X$ into three matrices $Q$, $K$ and $V$:
    • $Q = (W_Q X^T)^T = X W_Q^T \in R^{L \times d_K}$, where $W_Q$ is a matrix of dimension $d_K \times d$.
    • $K = (W_K X^T)^T = X W_K^T \in R^{L \times d_K}$, where $W_K$ is a matrix of dimension $d_K \times d$.
    • $V = (W_V X^T)^T = X W_V^T \in R^{L \times d_V}$, where $W_V$ is a matrix of dimension $d_V \times d$.
  • We compute the attention using the following vectorized equation: $A(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_K}}\right)V \in R^{L \times d_V}$

💡

$W_Q$ and $W_K$ need to have the same output dimension $d_K$, since we take dot products between the rows of $Q$ and $K$. The output dimension depends on the dimension of $W_V$.

For a full picture of the dimensions:

[Figure: self-attention dimension summary]
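The vectorized equations in this summary can be sketched with plain Python lists standing in for matrices. The helper names (`matmul`, `transpose`, `softmax_rows`, `self_attention`) and the random toy sizes are illustrative assumptions. A real implementation would use a tensor library such as PyTorch, as in the code later in this note.

```python
import math
import random


def matmul(A, B):
    """(n x m) times (m x p) -> (n x p), with matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def softmax_rows(S):
    out = []
    for row in S:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def self_attention(X, W_Q, W_K, W_V):
    """X: L x d; W_Q, W_K: d_K x d; W_V: d_V x d. Returns A: L x d_V."""
    Q = matmul(X, transpose(W_Q))  # L x d_K
    K = matmul(X, transpose(W_K))  # L x d_K
    V = matmul(X, transpose(W_V))  # L x d_V
    d_K = len(W_Q)
    scores = [[s / math.sqrt(d_K) for s in row] for row in matmul(Q, transpose(K))]  # L x L
    return matmul(softmax_rows(scores), V)                                          # L x d_V


def random_matrix(rows, cols):
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(1)
L, d, d_K, d_V = 5, 4, 3, 2
X = random_matrix(L, d)
A = self_attention(X, random_matrix(d_K, d), random_matrix(d_K, d), random_matrix(d_V, d))
print(len(A), len(A[0]))  # 5 2
```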

Multi-Head Attention

This process involves computing self-attention multiple times, with each computation known as a "head". Intuitively, we have multiple questions we'd like to find the best answer for.

The number of attention heads included in the attention layer varies from model to model, but numbers in the range of 12-100 are common. The intuition here is that each self-attention head will learn a different aspect of language, such as the relationship between the people entities in our sentence, the activity of the sentence or even other properties such as if the words rhyme.

While the implementation differs for efficiency reasons, there are essentially $h$ sets of $W_Q$, $W_K$ and $W_V$ matrices, one for each question we'd like to answer. $h$ is called the number of heads.

Self-attention is computed with each of these sets of matrices to obtain an $L \times h \times d_V$ tensor. The $h$ and $d_V$ dimensions are concatenated to get an $L \times h \cdot d_V$ matrix. This is finally multiplied by (the transpose of) a $d_O \times h \cdot d_V$ matrix $W_O$ to obtain the final output of dimension $L \times d_O$.
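The concatenate-then-project step can be illustrated with tiny hand-written head outputs. The sizes $h = 2$, $L = 3$, $d_V = 2$ and $d_O = 3$ are assumptions, and `W_O` is a hypothetical $d_O \times h \cdot d_V$ matrix. As in the text, the output is the concatenated matrix multiplied by $W_O^T$.

```python
def concat_heads(head_outputs):
    """h lists of shape L x d_V -> one L x (h * d_V) list (heads placed side by side)."""
    L = len(head_outputs[0])
    return [[value for head in head_outputs for value in head[t]] for t in range(L)]


def project(A, W_O):
    """Compute A @ W_O^T: each output entry is a dot product with one row of W_O."""
    return [[sum(w * a for w, a in zip(w_row, a_row)) for w_row in W_O] for a_row in A]


head_1 = [[1, 2], [3, 4], [5, 6]]      # L x d_V output of head 1
head_2 = [[7, 8], [9, 10], [11, 12]]   # L x d_V output of head 2

concat = concat_heads([head_1, head_2])  # [[1, 2, 7, 8], [3, 4, 9, 10], [5, 6, 11, 12]]

W_O = [[1, 0, 0, 0],                     # hypothetical d_O x (h * d_V) = 3 x 4 weights
       [0, 0, 1, 0],
       [1, 1, 1, 1]]
output = project(concat, W_O)            # L x d_O = 3 x 3
print(output[0])                         # [1, 7, 18]
```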

How Multi-Head Attention Works

The vectors $q$, $k$, and $v$ are initially derived for each input term by multiplying it with the weight matrices $W^Q$, $W^K$, and $W^V$. In multi-head attention, that same set of vectors is taken as input and further processed with additional weight matrices to generate new sets of query, key, and value vectors:

$$W^Q_1 q^{<1>}, \quad W^K_1 k^{<1>}, \quad W^V_1 v^{<1>}$$

[Figure: Multi Head Diagram]

With this computation, the word visite gives the best answer to "what's happening". That is why it is highlighted with the blue arrow: the inner product between the key for visite and the query for l'Afrique has the highest value. You do the same for the remaining words, so you end up with five vectors representing the five words in the sequence.

Iterative Process of Multi-Head Attention

Rather than performing this operation just once, the transformer model repeats it multiple times, producing multiple "heads" of attention. Each head can be thought of as focusing on different features, or asking different questions, within the sentence, thus building a comprehensive representation.

For example, the model might use eight different sets of weight matrices $W^Q_i$, $W^K_i$, and $W^V_i$. It performs the whole computation eight times and produces eight different attention outputs:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h) W^O$$

where each individual head $\text{head}_i$ is defined as:

$$\text{head}_i = \text{Attention}(Q W^{Q}_i, K W^{K}_i, V W^{V}_i)$$

and $h$ denotes the number of heads.

Computation

The idea is to stack all the weight matrices needed to compute the $Q$, $K$ and $V$ matrices of every head into a single matrix. This lets us obtain the $Q$, $K$ and $V$ matrices with one matrix multiplication instead of several.

Assume each of the $Q$, $K$ and $V$ matrices has dimension $d_Q = d_K = d_V = d_h$, and suppose we have $h$ heads. Then we need $3 \cdot h$ matrices of size $d_h \times d$ (3 for $Q$, $K$ and $V$, for each of the $h$ heads). In other words, we need a $3 \cdot h \cdot d_h \times d$ matrix, where $3 \cdot h \cdot d_h$ is the stacked dimension. Let this matrix be $W$.

We then multiply $X \in R^{L \times d}$ with $W$ as follows:

$$\text{QKV} = (W X^T)^T = X W^T \in R^{L \times 3 \cdot h \cdot d_h}$$

We then reshape this into an $L \times h \times 3 \cdot d_h$ tensor. Finally, we split the last dimension into three chunks to obtain three $L \times h \times d_h$ tensors, which hold the $Q$, $K$ and $V$ matrices of every head.
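To make the reshape-and-chunk step concrete, the helper below (a hypothetical name, not part of any library) splits one position's $3 \cdot h \cdot d_h$ projection. It mirrors what the PyTorch code later in this note does with `reshape(..., h, 3 * dh)` followed by `chunk(3, dim=-1)`. One consequence of this layout is that each head's query, key and value columns sit next to each other in $W$, rather than all queries coming first.

```python
def split_qkv_row(row, h, d_h):
    """Split one position's 3*h*d_h projection into per-head (q, k, v) lists.

    Mirrors reshape(..., h, 3 * d_h) followed by chunk(3, dim=-1).
    """
    assert len(row) == 3 * h * d_h, "row must have 3 * h * d_h entries"
    per_head = []
    for i in range(h):
        block = row[i * 3 * d_h:(i + 1) * 3 * d_h]  # the 3*d_h slice belonging to head i
        q, k, v = block[:d_h], block[d_h:2 * d_h], block[2 * d_h:]
        per_head.append((q, k, v))
    return per_head


# Use the column index as the value so the layout is easy to read.
heads = split_qkv_row(list(range(12)), h=2, d_h=2)
print(heads[0])  # ([0, 1], [2, 3], [4, 5])
print(heads[1])  # ([6, 7], [8, 9], [10, 11])
```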

These three tensors are passed to the self-attention block to obtain an $L \times h \times d_h$ output. The head outputs are then concatenated to obtain the $L \times h \cdot d_h$ output $A$. Finally, $A$ is multiplied with an $h \cdot d_h \times h \cdot d_h$ matrix $W_O$ (so $d_O = h \cdot d_h$) to obtain the final $L \times h \cdot d_h$ output as follows:

$$O = (W_O A^T)^T = A W_O^T \in R^{L \times h \cdot d_h} = R^{L \times d_O}$$

In the actual implementation, we pass in three inputs:

  • $d$ - Input embedding size.
  • $h$ - Number of heads.
  • $d_O$ - Expected output dimension of multi-headed attention.

From these, $d_h$ is computed as $d_h = \frac{d_O}{h}$, since $d_O = h \cdot d_h$. In other words, $d_O$ must satisfy $d_O \bmod h = 0$. The rest is the same as above.
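The bookkeeping for these three inputs can be written as a small function. The name `multihead_shapes` and the returned dictionary are illustrative assumptions. The example values $d = d_O = 512$ and $h = 8$, which give $d_h = 64$, match the base model of the original Transformer paper.

```python
def multihead_shapes(d, h, d_O):
    """Return d_h and the weight shapes for a stacked-QKV multi-head attention layer."""
    if d_O % h != 0:
        raise ValueError(f"d_O={d_O} must be divisible by h={h}")
    d_h = d_O // h
    return {
        "d_h": d_h,
        "W (stacked Q, K, V)": (3 * h * d_h, d),  # = (3 * d_O, d)
        "W_O": (d_O, h * d_h),                    # = (d_O, d_O)
    }


print(multihead_shapes(d=512, h=8, d_O=512))
# {'d_h': 64, 'W (stacked Q, K, V)': (1536, 512), 'W_O': (512, 512)}
```

These shapes match the weights of `nn.Linear(d, 3 * dO)` and `nn.Linear(dO, dO)` in the code below, because `nn.Linear` stores its weight as (out_features, in_features).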

In PyTorch, this is implemented as follows:

import math
 
import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    # Applying mask for masked multi-headed attention
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention
 
 
def expand_mask(mask):
    assert (
        mask.ndim >= 2
    ), "Mask must be at least 2-dimensional with seq_length x seq_length"
    if mask.ndim == 3:
        mask = mask.unsqueeze(1)
    while mask.ndim < 4:
        mask = mask.unsqueeze(0)
    return mask
 
 
class MultiheadAttention(nn.Module):
    def __init__(self, d, dO, h):
        super().__init__()
        assert dO % h == 0, "Embedding dimension must be 0 modulo number of heads."
 
        self.dO = dO
        self.h = h
        # Compute dh
        self.dh = dO // h
 
        # Create the stacked weight matrix using a linear layer
        # It will receive a d-dim input
        # And produce a 3.h.dh = 3.dO dimensional output
        self.qkv_proj = nn.Linear(d, 3 * dO)
 
        # Create WO using a linear layer
        # It will receive an h.dh = dO-dim input and
        # Produce a dO-dim output
        self.o_proj = nn.Linear(dO, dO)
 
        self._reset_parameters()
 
    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        self.qkv_proj.bias.data.fill_(0)
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)
 
    def forward(self, x, mask=None, return_attention=False):
        batch_size, seq_length, _ = x.size()
 
        if mask is not None:
            mask = expand_mask(mask)
 
        # [Batch, L, 3*h.dh] = [Batch, L, 3.dO]
        qkv = self.qkv_proj(x)
 
        # Reshape to [Batch, L, h, 3*dh]
        qkv = qkv.reshape(batch_size, seq_length, self.h, 3 * self.dh)
        # Permute as [Batch, h, L, 3*dh]
        qkv = qkv.permute(0, 2, 1, 3)
        # Take out [Batch, h, L, dh] chunks to obtain Q, K and V
        q, k, v = qkv.chunk(3, dim=-1)
 
        # Apply self-attention - [Batch, h, L, dh]
        values, attention = scaled_dot_product(q, k, v, mask=mask)
        # Permute to [Batch, L, h, dh]
        values = values.permute(0, 2, 1, 3)
        # Concatenate to [Batch, L, h.dh] = [Batch, L, dO]
        values = values.reshape(batch_size, seq_length, self.dO)
        # Multiply with WO for final output
        o = self.o_proj(values)
 
        return (o, attention) if return_attention else o

Parallel Computation Advantage

A key advantage of this design is that each head operates independently, so all of the attention heads can run in parallel. This parallelism speeds up training and inference compared to a sequential for-loop approach. Once all heads have been computed, their outputs are concatenated and then multiplied by the weight matrix $W^O$ to produce the final multi-head attention output.

[Figure: Multi Head Attention]

This approach enables the transformer model to capture a more nuanced understanding of the input sequence, making it highly effective for complex tasks in NLP.

Transformer Network

This is the simplified overall structure of the encoder (left) and decoder (right) of the Transformer model, and each of them is stacked $N$ times. There are three multi-head attention blocks in total, shown as orange boxes: one in the encoder layer and two in the decoder layer.

[Figure: Transformer Architecture]

To address sequence translation tasks, the model incorporates tokens to signify the start (<SOS>) and end (<EOS>) of sentences, crucial to define sequence boundaries.

| | Jane | visite | l'Afrique | en | septembre | |
|---|---|---|---|---|---|---|
| <SOS> | $x^{<1>}$ | $x^{<2>}$ | $x^{<3>}$ | $x^{<4>}$ | $x^{<5>}$ | <EOS> |

The first step in the transformer is to feed these embeddings into an encoder block, which has a multi-head attention layer. This stage uses the matrices $Q$, $K$, and $V$, derived from the embeddings and the corresponding weight matrices $W$.

Encoder

This layer then produces a matrix that is passed into a feed-forward neural network, which helps determine what interesting features there are in the sentence. In the transformer paper, this encoder block is repeated $N$ times, typically with $N = 6$.

Decoder

After $N$ passes through the encoder block, we feed the encoder output into a decoder block. The decoder block's job is to output the English translation. The decoder starts with the <SOS> token.

This token is fed into the first multi-head attention block of the decoder. Just this one <SOS> token is used to compute $Q$, $K$ and $V$ for that first multi-headed attention block.

[Figure: Encoder & Decoder]

The output of the decoder's first multi-head attention block generates the $Q$ matrix for the subsequent attention block, while $K$ and $V$ come straight from the encoder's output. This design ensures that the decoding process uses context from both the already-translated sequence and the encoder's representation of the source sentence.

Why is it structured this way? One piece of intuition is that the input at the bottom of the decoder is what you have translated of the sentence so far. The decoder then pulls context from $K$ and $V$, which are computed from the encoder's representation of the French sentence, to decide which word to generate next.

To finish the description of the decoder block, the output of this multi-head attention block is fed to a feed-forward neural network. The decoder block is also repeated $N$ times, maybe six times, with the output of each block fed as the input to the next.

The job of this network is to predict the next word in the sentence. Hopefully it will decide that the first word in the English translation is Jane, and we then feed Jane to the decoder input as well. The next query comes from <SOS> and Jane. It asks what the most appropriate next word is, using the $K$ and $V$ context from the encoder.

But beyond these main ideas, there are a few extra bells and whistles missing, such as positional encoding, add & norm, and masked multi-head attention.

Positional Encoding

A unique aspect of the transformer model is its use of positional encoding, necessary due to the self-attention mechanism's insensitivity to word order. The position within the sentence can be extremely important to translation. The position is encoded using the following function:

$$\text{PE}(t, i) = \begin{cases} \sin\left(\frac{t}{10000^{2k/d}}\right), & i = 2k \\ \cos\left(\frac{t}{10000^{2k/d}}\right), & i = 2k + 1 \end{cases}$$

where $t$ ($1 \leq t \leq L$) is the numerical position of the word being encoded and $i$ ($0 \leq i < d$) is an index into the embedding for the word.

Take the sentence:

Jane visite l'Afrique en Septembre

Consider the word Jane. Here, $t = 1$. Assuming a 4-dimensional word embedding ($d = 4$, so $0 \leq i \leq 3$), the positional encoding of Jane would be:

$$\left[\sin(1), \cos(1), \sin\left(\frac{1}{\sqrt{10000}}\right), \cos\left(\frac{1}{\sqrt{10000}}\right)\right] = [\sin(1), \cos(1), \sin(0.01), \cos(0.01)]$$

This creates a vector with alternating sine and cosine waves, which have different frequencies. This vector is added to the original embedding for the word so that the original embedding also has information about the word's position.
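The positional-encoding formula and the Jane example can be checked with a short sketch. It follows this note's convention that positions start at $t = 1$. The embedding values for Jane are made up for illustration.

```python
import math


def positional_encoding(t, d):
    """PE(t, i): sin for even i = 2k, cos for odd i = 2k + 1, with frequency 1 / 10000^(2k/d)."""
    pe = []
    for i in range(d):
        k = i // 2
        angle = t / (10000 ** (2 * k / d))
        pe.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
    return pe


p_1 = positional_encoding(t=1, d=4)  # Jane is the first word
# p_1 is approximately [sin(1), cos(1), sin(0.01), cos(0.01)]

x_1 = [0.1, -0.2, 0.3, 0.0]          # made-up 4-dimensional embedding for Jane
x_1_with_position = [x + p for x, p in zip(x_1, p_1)]  # embedding plus positional encoding
```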

Each word $x^{<t>}$ is paired with a positional embedding vector $p^{<t>}$ of the same dimensionality, capturing its unique position within the sentence.

A 4-dimensional embedding vector like the one below exists for every word $x^{<t>}$:

[Figure: 4 Dim Vector]

For each word, we then create a positional embedding vector of the same dimension and call it $p^{<t>}$. In the equation, $t$ denotes the numerical position of the word, so for the word Jane, $t = 1$. $i$ indexes the different dimensions of the encoding:

[Figure: Annotated Vector Diagram]

Positional encoding uses sine and cosine to create a unique position encoding vector for each word. So the vector $p^{<3>}$ that encodes the position of l'Afrique, the third word, will be a set of four values. These differ from the four values that encode the position of the first word, Jane.

[Figure: Transformer Details]

Positional encodings $p^{<t>}$ are added directly to the input embeddings $x^{<t>}$, so that each word vector is also influenced by where in the sentence the word appears.

The result of this addition contains both semantic embedding and positional encoding information. The output of the embedding layer therefore has one $d$-dimensional vector (here $d = 4$) for each position, up to the maximum sequence length. The outputs of all the layers in the encoder and decoder have this same shape.

In addition to adding these position encodings to the embeddings, you'd also pass them through the network with residual connections, similar to those you see in ResNet. Their purpose in this case is to pass along positional information through the entire architecture.

Residual Connections

The output matrix of the multi-headed attention is added to the original embedded input using a residual connection.

💡

This requires the output dimension of the multi-headed attention layer to match the dimension of the input. In other words, $d_O = d$, so that the output is $L \times d$.

This residual connection is important since:

  • It helps with the depth of the model by allowing information to be passed across greater depths.
  • Multi-headed attention has no information about the position of tokens in the input sequence. The residual connection (together with positional encoding) passes this information to the rest of the model, so it is not lost after the first multi-headed attention pass. This gives the model a chance to distinguish which information came from which element of the input sequence.

Add & Norm

The Add & Norm step adds the residual connection and then applies layer normalization, which plays a role similar to batch normalization. The residual part also helps pass positional information throughout the network. Layer normalization is preferred over batch normalization for two reasons. First, the batch size is often small. Second, batch normalization tends to perform poorly on text: word features have high variance, and many rare words would need to be seen to get a good estimate of the distribution.

[Figure: Add & Norm]

This batch-norm like layer (Add & Norm layer) is repeated throughout this architecture. Layer normalization:

  • Speeds up training.
  • Provides a small regularization.
  • Ensures features are in a similar magnitude among the elements in the input sequence.
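An Add & Norm step can be sketched as "add the sublayer output to its input, then layer-normalize each position". This minimal version follows the post-norm arrangement of the original Transformer, LayerNorm(x + Sublayer(x)). It omits layer normalization's learned gain and bias for brevity, and the input numbers are made up.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize one position's features to mean 0 and variance 1 (no learned gain/bias)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def add_and_norm(x, sublayer_out):
    """Residual connection followed by layer normalization, applied position by position."""
    return [layer_norm([a + b for a, b in zip(x_t, s_t)]) for x_t, s_t in zip(x, sublayer_out)]


x = [[1.0, 2.0, 3.0, 4.0],            # L = 2 positions, d = 4 features
     [0.5, -0.5, 1.5, 2.5]]
attn_out = [[0.1, 0.1, 0.1, 0.1],     # stand-in for the multi-head attention output
            [1.0, 0.0, -1.0, 2.0]]

y = add_and_norm(x, attn_out)
```

Unlike batch normalization, the statistics here are computed over the $d$ features of a single position, so the result does not depend on the batch size.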

Feed-Forward Neural Network (FFNN)

The normalized output is fed to an FFNN, which is applied to each element of the input sequence separately and identically. The encoder uses a Linear $\rightarrow$ ReLU $\rightarrow$ Linear model. Usually, the inner dimension of the FFNN is 2-8 $\times$ larger than $d$, the size of the input embedding.

The FFNN adds complexity to the model and can be thought of as an extra "post-processing" step applied to the output of multi-headed attention.

There is also a residual connection between the output of the multi-headed attention and the output of the FFNN, with layer normalization.
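The position-wise FFNN can be sketched as one small function applied to every row of the $L \times d$ input. The weights below are hypothetical, with $d = 2$ and an inner dimension of 4 (2 $\times$ larger than $d$). The last line checks that each position is transformed independently of the others.

```python
def linear(W, b, x):
    """y = W x + b for a single vector x."""
    return [sum(w * a for w, a in zip(row, x)) + bias for row, bias in zip(W, b)]


def relu(v):
    return [max(0.0, a) for a in v]


def position_wise_ffn(X, W1, b1, W2, b2):
    """Linear -> ReLU -> Linear, applied to each position (row of X) separately."""
    return [linear(W2, b2, relu(linear(W1, b1, x_t))) for x_t in X]


# Hypothetical weights: d = 2, inner dimension 4.
W1 = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
b1 = [0.0, 0.0, 0.0, 0.0]
W2 = [[1.0, 0.0, -1.0, 0.5], [0.0, 1.0, 0.5, -1.0]]
b2 = [0.1, -0.1]

X = [[2.0, -3.0], [0.5, 1.0], [-1.0, 4.0]]  # L = 3 positions
out = position_wise_ffn(X, W1, b1, W2, b2)

# Same result whether position 1 is processed with the others or on its own.
print(out[1] == position_wise_ffn([X[1]], W1, b1, W2, b2)[0])  # True
```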

Masked Multi-Head Attention

So far we've looked at how the transformer makes predictions, one word at a time, but how is it trained? Masked multi-head attention is crucial during training. It ensures the model cannot prematurely access future words in the sequence, simulating the conditions of sequence prediction.

[Figure: Masked Multi-Head Attention]

During the training process where you're using a data set of correct French to English translations to train your transformer, you'd need to mask these words.

Let's say the data set has the correct French-to-English translation pair "Jane visite l'Afrique en septembre" and "Jane visits Africa in September". When training, you have access to the entire correct English translation (the correct output) and the corresponding correct input. Because you have the full correct output, you don't have to generate the words one at a time during training.

Instead, masking blocks out the later part of the sentence to mimic what the network will need to do at test time, during prediction. In other words, masked multi-head attention repeatedly pretends that the network has perfectly translated the first few words and hides the remaining words. It then checks whether, given a perfect first part of the translation, the network can accurately predict the next word in the sequence.

In detail, part of the $L \times L$ matrix obtained by computing $\frac{QK^T}{\sqrt{d_K}}$ is masked. In particular, a word at index $i$ should only attend to words at indices $1$ to $i$. Thus, all entries at indices $i + 1$ to $L$ are set to $-\infty$, so that they become $0$ when the softmax is applied. Consider the sentence:

Je suis un étudiant

After computing $\frac{QK^T}{\sqrt{d_K}}$, masking it, and applying the softmax, the matrix might look like this:

| | Je | suis | un | étudiant |
|---|---|---|---|---|
| Je | 1 | 0 | 0 | 0 |
| suis | 0.02 | 0.98 | 0 | 0 |
| un | 0.05 | 0.20 | 0.75 | 0 |
| étudiant | 0.38 | 0.02 | 0.05 | 0.55 |

For suis, $i = 2$. Thus, the elements at indices $(2, 3)$ and $(2, 4)$ are $0$, since suis should attend only to Je and to itself. On the other hand, the word étudiant can attend to all the words because it's the last word. Thus, no element in that row is $0$.
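The masking step can be sketched as follows. Entries above the diagonal are replaced with $-\infty$ before a row-wise softmax, so they receive exactly zero weight. Indices are 0-based in the code (row `i` may attend to columns `0..i`), and the score values are made up. The `scaled_dot_product` function in the PyTorch code earlier in this note uses `-9e15` instead of $-\infty$, which has practically the same effect.

```python
import math


def causal_softmax(scores):
    """Row-wise softmax where row i can only see columns 0..i (future positions masked)."""
    weights = []
    for i, row in enumerate(scores):
        masked = [s if j <= i else -math.inf for j, s in enumerate(row)]
        m = max(masked)                           # always finite: column i is never masked
        exps = [math.exp(s - m) for s in masked]  # exp(-inf) == 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


# Made-up scaled scores QK^T / sqrt(d_K) for "Je suis un étudiant".
scores = [[0.3, 1.2, -0.5, 0.8],
          [0.1, 3.9, 0.4, 2.2],
          [-0.2, 1.1, 2.5, 0.7],
          [1.6, -1.0, -0.3, 1.9]]

for row in causal_softmax(scores):
    print([round(w, 2) for w in row])
```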

Note that the decoder's input is the output sequence shifted one position to the right, with <SOS> at the front. During inference, we do not have the entire output sequence and are actually predicting one word at a time. We start the decoder with the single token <SOS> as the input, and as the decoder predicts each next word, we append it to the input. This is what "shifted right" in the diagram refers to.

Cross-Attention

Cross-attention is a generalization of self-attention. In self-attention, we compute the attention matrix for a single sequence: the same sequence is used to compute the $Q$, $K$ and $V$ matrices.

In cross-attention, we have two different sequences $x_1$ and $x_2$. $x_1$ is used to compute the $Q$ matrix, while $x_2$ is used to compute the $K$ and $V$ matrices. When $x_1 = x_2$, cross-attention reduces to self-attention.

In the decoder, cross-attention is used in the second multi-headed attention layer. The $Q$ matrix comes from the output of the first multi-headed attention layer, which in turn takes the input to the decoder as its input. The $K$ and $V$ matrices are computed from the output of the final encoder block.

In essence, the input to the decoder acts as $x_1$ and the output of the final encoder block acts as $x_2$.

Cross-attention can work with sequences of different lengths. Computing $Q$ from $x_1$ gives an $L_1 \times d_Q$ matrix, and computing $K$ from $x_2$ gives an $L_2 \times d_Q$ matrix. Taking the product $QK^T$ gives an $L_1 \times L_2$ matrix. Since $V$ is also computed from $x_2$, its dimension is $L_2 \times d_V$. Thus, after the softmax and the multiplication with $V$, the overall result of cross-attention is $L_1 \times d_V$. See the image below ($L_1 = n$ and $L_2 = m$):

[Figure: cross-attention dimension summary]
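The shape argument above can be checked with a small sketch in which the decoder side has $L_1 = 2$ positions and the encoder side has $L_2 = 3$. All sizes, weights and helper names are assumptions for illustration. The only difference from self-attention is that $Q$ is computed from `x1`, while $K$ and $V$ come from `x2`.

```python
import math
import random


def matmul(A, B):
    """(n x m) times (m x p) -> (n x p), with matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def cross_attention(x1, x2, W_Q, W_K, W_V):
    """Queries from x1 (L1 x d), keys and values from x2 (L2 x d). Returns L1 x d_V."""
    Q = matmul(x1, transpose(W_Q))  # L1 x d_Q
    K = matmul(x2, transpose(W_K))  # L2 x d_Q
    V = matmul(x2, transpose(W_V))  # L2 x d_V
    d_Q = len(W_Q)
    weights = []
    for row in matmul(Q, transpose(K)):  # one row of L1 x L2 scores at a time
        scaled = [s / math.sqrt(d_Q) for s in row]
        m = max(scaled)
        exps = [math.exp(s - m) for s in scaled]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return matmul(weights, V), weights


def random_matrix(rows, cols):
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(2)
d, d_Q, d_V = 4, 3, 5
x1 = random_matrix(2, d)  # e.g. decoder-side sequence, L1 = 2
x2 = random_matrix(3, d)  # e.g. encoder output, L2 = 3
out, weights = cross_attention(
    x1, x2, random_matrix(d_Q, d), random_matrix(d_Q, d), random_matrix(d_V, d)
)
print(len(weights), len(weights[0]))  # 2 3
print(len(out), len(out[0]))          # 2 5
```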

💡

This technique of cross-attention is also used in diffusion models. See High-Resolution Image Synthesis with Latent Diffusion Models.

Linear, Softmax & Prediction

There is a final linear layer followed by a softmax activation. Together they convert the output of the decoder into a probability distribution over all the words: if the vocabulary has $N$ words, this layer has $N$ units, one for each word.

The next word can be predicted from this probability distribution either by taking the word with the maximum probability or by using other techniques. These other techniques can affect how creative the model is. For example, a technique might lead the model to not choose the most "obvious" word every time and to go for slightly eccentric choices instead.

Greedy Sampling

The technique of using the word with the maximum probability is called greedy sampling. This is the most commonly used technique for many models. Consider the following softmax output:

| Probability | Word |
|---|---|
| 0.20 | cake |
| 0.10 | donut |
| 0.02 | banana |
| 0.01 | apple |
| $\dots$ | $\dots$ |

The model would output the word cake since it has the highest probability.
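Greedy selection is a one-line argmax. The dictionary below copies the four rows shown in the table. The rest of the vocabulary (the "..." row) is simply left out of this sketch.

```python
# Partial softmax output copied from the table above (remaining words omitted).
probs = {"cake": 0.20, "donut": 0.10, "banana": 0.02, "apple": 0.01}


def greedy_pick(probs):
    """Return the word with the highest probability."""
    return max(probs, key=probs.get)


print(greedy_pick(probs))  # cake
```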

Random Sampling

Another approach is called random-weighted sampling. It introduces some variability to the model's output. Instead of taking the word with the maximum probability, the probabilities are used as weights to sample one word at random. For example, consider the following softmax output:

| Probability | Word |
|---|---|
| 0.20 | cake |
| 0.10 | donut |
| 0.02 | banana |
| 0.01 | apple |
| $\dots$ | $\dots$ |

The word cake has a 20% chance of being selected while the word banana has a 2% chance of being selected. It might be that the model selects the word banana.

It is possible that this technique leads to the model becoming too creative, where it generates words or wanders into topics that do not make sense with respect to the prompt.
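Random-weighted sampling can be sketched with Python's standard `random.choices`, which accepts relative weights. Only the four words from the table are included here, so the weights are effectively renormalized over those four words. Cake then gets roughly a 61% share rather than 20%. With the full vocabulary, the probabilities would apply as stated. The seed is fixed only to make the run repeatable.

```python
import random
from collections import Counter

# Partial softmax output copied from the table above (remaining words omitted).
probs = {"cake": 0.20, "donut": 0.10, "banana": 0.02, "apple": 0.01}


def random_weighted_pick(probs, rng):
    """Sample one word, using the probabilities as (relative) weights."""
    words = list(probs)
    return rng.choices(words, weights=[probs[w] for w in words], k=1)[0]


rng = random.Random(42)
counts = Counter(random_weighted_pick(probs, rng) for _ in range(10_000))
print(counts.most_common())  # typically: cake most often, but banana and apple still show up
```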

Types of Configurations of Transformers

Encoder-only Models

They only have encoders. Without some changes, these models always produce an output of the same length as the input sequence. They can be modified for tasks such as sentiment analysis. An example of such a model is BERT.

Encoder-Decoder Models

This is the model originally described in the Transformers paper and the one detailed here. The output sequence and the input sequence can be of different lengths. It is useful for sequence-to-sequence tasks such as machine translation. Examples of such models are BART and FLAN-T5.

Decoder-only Models

These are some of the most commonly used models. As they have scaled, they have gained the ability to generalize to pretty much any task. Examples of such models are GPT, BLOOM, LLaMA.


Resources: