Attention Mechanism, Core of Transformer Models
In this blog post, I will focus on the core principles of transformer models, specifically the self-attention mechanism. To keep the discussion straightforward, I will approach the concepts from the perspective of decoder-only models like GPT.
Generative Pretrained Transformer
GPT (Generative Pretrained Transformer) is a decoder-only transformer model specifically designed for text generation. It functions as an autoregressive model, generating one token at a time by predicting the next token based on the sequence of tokens it has already processed.
For example, if the model receives the input “cats climb trees during”, it should predict “storms” as the next token.
Thus, during training the input chunk would be [“cats”, “climb”, “trees”, “during”], and the corresponding target chunk would be [“climb”, “trees”, “during”, “storms”]. Notice that the target chunk is the input chunk shifted one position to the right along the text: each target is the token that follows the input token at the same position.
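As a minimal sketch, assuming a hypothetical toy vocabulary in which every word is its own token (real tokenizers usually split text into sub-word pieces), the text can be mapped to integer IDs and split into an input chunk and a target chunk that sits one position further along:

```python
# Hypothetical toy vocabulary: one integer ID per word (real tokenizers use sub-word pieces).
vocab = {"cats": 0, "climb": 1, "trees": 2, "during": 3, "storms": 4}

text = "cats climb trees during storms"
token_ids = [vocab[word] for word in text.split()]  # [0, 1, 2, 3, 4]

# The input chunk drops the last token and the target chunk drops the first,
# so target_ids[i] is the token that follows input_ids[i].
input_ids = token_ids[:-1]   # cats, climb, trees, during
target_ids = token_ids[1:]   # climb, trees, during, storms

for inp, tgt in zip(input_ids, target_ids):
    print(inp, "->", tgt)
```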
Assumptions and Simplifications
Before diving into the details, here are a few assumptions made for simplicity and to focus on the key points of the blog.
- Tokenization: Machines cannot process raw text directly, so text is converted into tokens (numerical representations) through a process called tokenization.
- For simplicity, in this article, I will refer to words like “cats” and “climb” as tokens to make explanations easier to follow, but in practice, tokens are integer IDs that the model processes (and a token is often a piece of a word rather than a whole word).
- I will use the input chunk [“cats”, “climb”, “trees”, “during”] as an example to explain the process. When this chunk is passed to the model, it is expected to generate the next token: “storms”.
- A language model learns an embedding table that contains an embedding vector for every token in its vocabulary. For instance, the embedding for “cats” is represented as emb(“cats”), and the embedding for “trees” is represented as emb(“trees”).
The Necessity of Understanding Relationships Between Tokens
The core and most critical property required for a generative model is its ability to understand the relationships between tokens (text). To achieve this, there must be a communication mechanism that enables tokens to interact and share information.
Let’s say [cats, climb, trees, during] is an input chunk provided to the generative model during training. The model must be able to communicate and understand the relationships among the tokens within the input chunk. Additionally, since generative models are autoregressive by nature, a token can only receive information from itself and the tokens before it, never from future tokens. For example, the token trees can gather information only from cats, climb, and itself, not from during.
So within one input chunk, there are multiple examples for the model to learn. Here’s how it works:
- For each token in the input chunk, the model learns to predict the next token based on all the preceding tokens (context).
- This effectively creates multiple training examples within a single chunk.
For example, given the full sequence [cats, climb, trees, during, storms] (the input chunk followed by its final target):
The model can create the following examples:
1. Input: [cats] → Target: climb
2. Input: [cats, climb] → Target: trees
3. Input: [cats, climb, trees] → Target: during
4. Input: [cats, climb, trees, during] → Target: storms
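The same idea fits in a few lines of Python: every prefix of the sequence is a context, and the token right after it is the target. The function name `training_examples` is illustrative, not from any library.

```python
def training_examples(sequence):
    """Yield (context, target) pairs: each prefix predicts the token after it."""
    for i in range(1, len(sequence)):
        yield sequence[:i], sequence[i]


sequence = ["cats", "climb", "trees", "during", "storms"]
for context, target in training_examples(sequence):
    print(context, "->", target)
# ['cats'] -> climb
# ['cats', 'climb'] -> trees
# ['cats', 'climb', 'trees'] -> during
# ['cats', 'climb', 'trees', 'during'] -> storms
```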
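To establish a mathematical relationship between the input and output tokens, we can write the output for each token as a weighted sum of the token embeddings, where \(w_{ij}\) is the weight that token \(i\) gives to token \(j\):
\(\text{out}(\text{cats}) = w_{11}\,\text{emb}(\text{cats}) + w_{12}\,\text{emb}(\text{climb}) + w_{13}\,\text{emb}(\text{trees}) + w_{14}\,\text{emb}(\text{during})\)
\(\text{out}(\text{climb}) = w_{21}\,\text{emb}(\text{cats}) + w_{22}\,\text{emb}(\text{climb}) + w_{23}\,\text{emb}(\text{trees}) + w_{24}\,\text{emb}(\text{during})\)
and similarly for “trees” (row 3) and “during” (row 4).
Alternatively, I can collect all of these relationships in matrix form, as shown below, with the leftmost matrix referred to as the “attention weights matrix”:
\(\begin{bmatrix} w_{11} & w_{12} & w_{13} & w_{14} \\ w_{21} & w_{22} & w_{23} & w_{24} \\ w_{31} & w_{32} & w_{33} & w_{34} \\ w_{41} & w_{42} & w_{43} & w_{44} \end{bmatrix} \begin{bmatrix} \text{emb}(\text{cats}) \\ \text{emb}(\text{climb}) \\ \text{emb}(\text{trees}) \\ \text{emb}(\text{during}) \end{bmatrix} = \begin{bmatrix} \text{out}(\text{cats}) \\ \text{out}(\text{climb}) \\ \text{out}(\text{trees}) \\ \text{out}(\text{during}) \end{bmatrix}\)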
Imagine each token in a sentence trying to decide how much importance it should give to the tokens that came before it. The weight matrix acts as a guide, reflecting the amount of attention each token should pay to its past tokens. But how does it figure this out? That’s where self-attention comes in—it’s like a thoughtful process where each token evaluates its relationship with the others and determines these attention weights, ensuring the context is captured effectively.
Before we dive into self-attention, there’s one important tweak to the weight matrix:
Applying Causal Mask to the Attention Weights
- To ensure future context is blocked, the elements above the main diagonal of the matrix (the weights that tokens would give to later tokens) need to be zero. This is crucial for training autoregressive models, as it prevents tokens from seeing future information.
- During training, this is achieved by applying a causal mask that “hides” these future weights, ensuring the model learns to predict based only on past tokens.
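For the four-token chunk, writing \(w_{ij}\) for the weight that token \(i\) gives to token \(j\) (row 1 is “cats”, row 4 is “during”), the masked attention weights matrix has the following lower-triangular shape; every entry above the main diagonal is zero:

\(W_{\text{masked}} = \begin{bmatrix} w_{11} & 0 & 0 & 0 \\ w_{21} & w_{22} & 0 & 0 \\ w_{31} & w_{32} & w_{33} & 0 \\ w_{41} & w_{42} & w_{43} & w_{44} \end{bmatrix}\)

Row 3, for example, lets “trees” draw on “cats”, “climb”, and itself, but not on “during”.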
So far, we have discussed how the relationships among tokens within an input sequence can be mathematically represented using an attention weight matrix. Next, we will explore a straightforward method to formulate and compute the attention weight matrix.
Simple Average Self Attention: A Conceptual Approach
Let’s assume we want a model where, for any token, the information from previous tokens is represented as a simple average.
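To achieve this, we could begin with an all-ones matrix and then apply the causal mask as discussed earlier.
Now, if we normalize the masked matrix along its rows, we obtain a matrix that can be used for simple averaging: each row gives equal weight to the current token and every token before it. This normalization can be achieved with the softmax function, provided the masked (future) positions are set to \(-\infty\) rather than 0 before the softmax, since \(e^{-\infty} = 0\) whereas \(e^{0} = 1\) would still give future tokens a share of the weight.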
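Worked through for the four-token chunk: starting from an all-ones matrix, the masked positions are set to \(-\infty\), and a row-wise softmax turns each row into equal weights over the current and preceding tokens:

\(\begin{bmatrix} 1&1&1&1\\ 1&1&1&1\\ 1&1&1&1\\ 1&1&1&1 \end{bmatrix} \xrightarrow{\text{mask}} \begin{bmatrix} 1&-\infty&-\infty&-\infty\\ 1&1&-\infty&-\infty\\ 1&1&1&-\infty\\ 1&1&1&1 \end{bmatrix} \xrightarrow{\text{row-wise softmax}} \begin{bmatrix} 1&0&0&0\\ \tfrac12&\tfrac12&0&0\\ \tfrac13&\tfrac13&\tfrac13&0\\ \tfrac14&\tfrac14&\tfrac14&\tfrac14 \end{bmatrix}\)

Multiplying this matrix by the stacked embeddings gives, for example, \(\tfrac12\big(\text{emb}(\text{cats}) + \text{emb}(\text{climb})\big)\) as the output for “climb”.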
Code Snippet For Simple Average Self Attention
Below is a code snippet demonstrating how to apply the causal mask and softmax. The process is straightforward, and for additional details, you can refer to the PyTorch documentation if needed.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleAverageSelfAttention(nn.Module):
    """
    Compute a simple average of past tokens for communicating information to a token.
    Every past token gets equal weight, so varying importance is not modeled.
    """

    def __init__(self, config):
        super().__init__()
        chunk_size = config["data"]["chunk_size"]
        # register_buffer adds a non-trainable constant that does not change with training
        self.register_buffer("causal_mask", torch.tril(torch.ones(chunk_size, chunk_size)))

    def forward(self, x):
        # x: (batch_size, chunk_size, token_emb_dim)
        _, chunk_size, _ = x.shape
        # equal starting scores; softmax is shift-invariant, so zeros and ones behave the same
        attention_weights = torch.zeros(chunk_size, chunk_size, device=x.device, dtype=x.dtype)
        # applying the causal mask for autoregressive nature (slice to the actual length)
        mask = self.causal_mask[:chunk_size, :chunk_size]
        attention_weights = attention_weights.masked_fill(mask == 0, float("-inf"))
        attention_weights = F.softmax(attention_weights, dim=-1)  # row-wise normalization
        # (chunk_size, chunk_size) @ (batch_size, chunk_size, emb_dim) broadcasts over the batch
        attention_output = attention_weights @ x
        return attention_output
```
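A quick sanity check, sketched without importing the class above (it re-creates the same masked-softmax weights inline): the output at position \(t\) should equal the running mean of the first \(t\) embeddings, which can also be computed with `torch.cumsum`. The tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, chunk_size, emb_dim = 2, 4, 3
x = torch.randn(batch_size, chunk_size, emb_dim)

# Same recipe as SimpleAverageSelfAttention.forward: zeros -> causal mask -> softmax
causal_mask = torch.tril(torch.ones(chunk_size, chunk_size))
weights = torch.zeros(chunk_size, chunk_size).masked_fill(causal_mask == 0, float("-inf"))
weights = F.softmax(weights, dim=-1)
averaged = weights @ x  # (batch_size, chunk_size, emb_dim)

# Running mean via cumulative sum: divide the sum of the first t tokens by t
counts = torch.arange(1, chunk_size + 1, dtype=x.dtype).view(1, chunk_size, 1)
running_mean = torch.cumsum(x, dim=1) / counts

print(torch.allclose(averaged, running_mean, atol=1e-6))
```

The masked-softmax form is more verbose than a cumulative sum, but it has exactly the shape that learned attention weights will later slot into.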
Limitations of the Simple Average Self-Attention Approach
- The weights are uniform along each row. While information from past tokens is passed to predict the next token, all past tokens are given equal importance.
- The weights are fixed and non-trainable, limiting the model’s ability to adapt.
- The importance of past tokens should not be uniform. Instead, it should be determined by factors such as what the current token is focusing on, the content of the past tokens, and the value they can contribute. This dynamic weighting is learned from the training data, and it is precisely what the Self-Attention mechanism in transformers is designed to address!
The Self-Attention Head of a Transformer
Database Analogy
Before diving into self-attention, let’s imagine a simple scenario involving a tabular database.
Picture this: you’re working with a database that has two important columns, key and value. The keys act like labels or tags that identify the type of information stored in the value column. Now, let’s say you’re the user, and you’re looking for some specific information. You write down your request, a query, which in this case is something like: “Give me the value corresponding to key2.”
The database hears your request and starts searching. It looks at all the keys in its rows, trying to match your query (key = key2) with one of its stored keys. Once it finds a match, it goes to the corresponding row and retrieves the value stored there, handing it over to you.
It’s a simple and straightforward process—just matching a query to a key and pulling out the value. This basic idea forms the foundation for something much more powerful in self-attention, where we’ll see how this retrieval concept is taken to the next level.
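To make the analogy concrete, here is a toy sketch in which all keys, values, and the query are made up: a Python dictionary performs the exact-match lookup described above, while the "soft" version scores every key against the query with a dot product and returns a softmax-weighted blend of all values. The soft version is the pattern self-attention builds on.

```python
import math

# Hard lookup: the query must match one key exactly, and exactly one value comes back.
database = {"key1": 10.0, "key2": 20.0, "key3": 30.0}
print(database["key2"])  # 20.0

# Soft lookup (toy numbers): keys and the query are vectors, and every value contributes,
# weighted by the softmax of its key's dot product with the query.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [10.0, 20.0, 30.0]
query = [0.0, 2.0]


def soft_lookup(query, keys, values):
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    max_score = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    result = sum(w * v for w, v in zip(weights, values))
    return result, weights


result, weights = soft_lookup(query, keys, values)
print([round(w, 3) for w in weights])  # [0.063, 0.468, 0.468]
```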
Self-Attention Intuition
Self-attention works on a very similar principle to the database example, but let’s imagine it as a conversation between tokens.
Picture this: every token in a sequence is like a person at a party, trying to gather useful information from others. Let’s focus on Token 1—it has something specific it’s looking for. To communicate this, Token 1 sends out a query vector, like asking, “Who here has the information I need?”
Now, every other token at the party, including Token 2, has a key vector that represents what they know or what they’re about. Token 1 goes around the party, comparing its query vector with the key vectors of all the other tokens. If the dot product (essentially, how similar they are) between Token 1’s query and Token 2’s key is high, it means Token 2 has relevant information for Token 1.
Token 2, seeing that it has something useful, then offers Token 1 its value vector, which is the actual information it can provide. Token 1 takes this value and combines it with similar contributions from other tokens, weighted by how relevant they are, to gather all the context it needs for its task. In this way, the tokens communicate dynamically to share what matters most, building a deeper understanding of the sequence as a whole. In a decoder-only model like GPT, the party has one extra rule: a token may consult only itself and the tokens that came before it.
Self-Attention Mechanism
Now that we’ve grasped the concept of self-attention intuitively, let’s delve into the mathematics behind it. The process can be broken down into a few key steps:
1. Representing Tokens as Embeddings
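- Each word in the sentence is first converted into a vector representation using an embedding layer. For simplicity, let’s denote the embeddings as:
\(x_{\text{cats}} = \text{emb}(\text{cats}), \quad x_{\text{climb}} = \text{emb}(\text{climb}), \quad x_{\text{trees}} = \text{emb}(\text{trees}), \quad x_{\text{during}} = \text{emb}(\text{during})\)
- Stacking these embeddings as rows forms the input matrix, of shape (chunk_size, token_emb_dim):
\(X = \begin{bmatrix} x_{\text{cats}} \\ x_{\text{climb}} \\ x_{\text{trees}} \\ x_{\text{during}} \end{bmatrix}\)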
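In PyTorch, the embedding table can be sketched with `nn.Embedding`: looking up the chunk's token IDs returns one row per token, which is the matrix \(X\). The vocabulary size, embedding dimension, and token IDs below are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, token_emb_dim = 5, 8  # illustrative sizes
embedding_table = nn.Embedding(vocab_size, token_emb_dim)  # learned lookup table

# Hypothetical IDs for [cats, climb, trees, during]
token_ids = torch.tensor([0, 1, 2, 3])
X = embedding_table(token_ids)  # (chunk_size, token_emb_dim)
print(X.shape)  # torch.Size([4, 8])
```

Row \(i\) of \(X\) is simply row `token_ids[i]` of the table's weight matrix, which is why the table is trained like any other parameter.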
2. Creating Query, Key, and Value Vectors
For each token, we compute three vectors:
1. Query (Q): What the token is looking for in other tokens.
2. Key (K): What the token represents or contains.
3. Value (V): The information the token provides.
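These vectors are obtained by multiplying the token embeddings with learned weight matrices:
\(Q = XW^Q, \quad K = XW^K, \quad V = XW^V\)
where each of \(W^Q\), \(W^K\), and \(W^V\) has shape (token_emb_dim, head_size), so \(Q\), \(K\), and \(V\) each have one row per token.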
Length of Query, Key, and Value vectors:
The length (or dimensionality) of these vectors is determined by the head size, which is the number of hidden dimensions allocated to an attention head.
What is the head size of an attention head?
The head size is the number of dimensions each attention head uses for its query, key, and value vectors. When a transformer block uses several heads, each head works in its own space of this size, which lets different heads focus on different aspects of the sequence.
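A minimal sketch of step 2, assuming illustrative sizes (`token_emb_dim = 8`, `head_size = 4`) and a random stand-in for \(X\): each of \(W^Q\), \(W^K\), \(W^V\) can be an `nn.Linear` layer without bias, which maps every token's embedding to a vector of length `head_size`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
chunk_size, token_emb_dim, head_size = 4, 8, 4  # illustrative sizes

X = torch.randn(chunk_size, token_emb_dim)  # stand-in for the stacked embeddings

# Learned projections W^Q, W^K, W^V (bias=False, so each is a pure matrix multiply)
W_q = nn.Linear(token_emb_dim, head_size, bias=False)
W_k = nn.Linear(token_emb_dim, head_size, bias=False)
W_v = nn.Linear(token_emb_dim, head_size, bias=False)

Q, K, V = W_q(X), W_k(X), W_v(X)
print(Q.shape, K.shape, V.shape)  # each torch.Size([4, 4])
```

Note that `nn.Linear` stores its weight as (out_features, in_features), so `W_q(X)` equals `X @ W_q.weight.T`; the transpose of that stored weight plays the role of \(W^Q\).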
3. Computing the "Scaled Dot-Product Attention"
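- Let’s say we want to compute how much “climb” should focus on “cats”. This involves calculating the dot product between the query vector of “climb” and the key vector of “cats”, where \(d_h\) is the head size:
\(q_{\text{climb}} \cdot k_{\text{cats}} = \sum_{i=1}^{d_h} q_{\text{climb},i}\, k_{\text{cats},i}\)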
- If the components of the query and key vectors are independent with zero mean and unit variance, their dot product does not have unit variance. It is a sum of \(d_h\) products, each with variance 1, so its variance equals the head size \(d_h\): it grows in proportion to the dimensionality of the query and key vectors.
For instance, with a head size of 64, unit-variance queries and keys produce dot products with a variance of about 64 (a standard deviation of about 8). To keep the variance consistent regardless of head size, we scale the dot product.
By dividing by \(\sqrt{d_h}\), we bring the variance of the dot product back to approximately 1, keeping the values in a reasonable range:
\(\text{score}(\text{climb}, \text{cats}) = \dfrac{q_{\text{climb}} \cdot k_{\text{cats}}}{\sqrt{d_h}}\)
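A quick empirical check, sketched with illustrative sizes: draw query and key vectors whose components are independent standard normal values, then compare the variance of their dot products before and after dividing by \(\sqrt{d_h}\) (here `head_size = 64`).

```python
import torch

torch.manual_seed(0)
head_size = 64
num_samples = 10_000

# Query and key vectors whose components are independent, zero-mean, unit-variance
q = torch.randn(num_samples, head_size)
k = torch.randn(num_samples, head_size)

dot = (q * k).sum(dim=-1)        # one unscaled dot product per sample
scaled = dot / head_size ** 0.5  # scaled dot product

print(f"var(q) = {q.var().item():.2f}, var(k) = {k.var().item():.2f}")
print(f"var(q . k) = {dot.var().item():.2f}  (close to head_size = {head_size})")
print(f"var(q . k / sqrt(head_size)) = {scaled.var().item():.2f}  (close to 1)")
```

Without the scaling, scores with a variance around 64 would push the softmax toward putting almost all of its weight on a single token.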
Scaled Dot Product Attention weights
The scaled dot product for the single pair above, which represents the (unnormalized) attention weight for “climb” → “cats”, generalizes to every pair of tokens in the sequence. In matrix form, with \(d_h\) the head size:
\(S = \dfrac{QK^\top}{\sqrt{d_h}}\)
where entry \(S_{ij}\) says how much token \(i\) should attend to token \(j\). The result has shape (chunk_size, chunk_size), with chunk_size being the sequence length, and it is the attention weights matrix before masking and normalization.
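In code, the whole score matrix is a single matrix product. The sketch below uses random stand-ins for \(Q\) and \(K\) (index 0 is “cats”, index 1 is “climb”) and checks that the entry in row “climb”, column “cats” matches the single-pair scaled dot product from step 3.

```python
import torch

torch.manual_seed(0)
chunk_size, head_size = 4, 8
Q = torch.randn(chunk_size, head_size)  # rows: cats, climb, trees, during
K = torch.randn(chunk_size, head_size)

scores = Q @ K.T / head_size ** 0.5     # (chunk_size, chunk_size)

# Row 1 ("climb") attending to column 0 ("cats"), computed as a single pair
pair_score = torch.dot(Q[1], K[0]) / head_size ** 0.5
print(scores.shape)  # torch.Size([4, 4])
print(scores[1, 0].item(), pair_score.item())
```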
4. Enforcing causality and applying Softmax for normalization
As we observed earlier, we mask out the elements above the main diagonal to prevent tokens from accessing information from future tokens. This masking is crucial when working with self-attention in a decoder-only model. In practice, the masked positions are filled with \(-\infty\), so that after the Softmax operation, applied along each row of the matrix, they become exactly 0.
Normalization ensures that the attention weights for each token are non-negative and sum to 1, so each token's output is a weighted average over the tokens it is allowed to see, with the weights interpretable as probabilities.
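A sketch of step 4 on a random stand-in score matrix: positions above the diagonal are filled with \(-\infty\) before the softmax, so they come out as exactly zero, and every row sums to 1.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
chunk_size = 4
scores = torch.randn(chunk_size, chunk_size)  # stand-in for Q @ K.T / sqrt(d_h)

causal_mask = torch.tril(torch.ones(chunk_size, chunk_size))  # 1 = allowed, 0 = future
masked_scores = scores.masked_fill(causal_mask == 0, float("-inf"))
attention_weights = F.softmax(masked_scores, dim=-1)

print(attention_weights)
print(attention_weights.sum(dim=-1))  # every row sums to 1
```

The first row always becomes [1, 0, 0, 0]: the first token can only attend to itself.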
5. Arriving at the Attention Output
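- To compute the final attention output, the attention weights are multiplied by the Value (V) matrix. This operation integrates the contextual information provided by the values, weighted by the relevance determined by the attention weights. Putting steps 3 to 5 together:
\(\text{Attention}(Q, K, V) = \text{softmax}\!\left(\text{mask}\!\left(\dfrac{QK^\top}{\sqrt{d_h}}\right)\right)V\)
where \(d_h\) is the head size and mask sets the entries above the main diagonal to \(-\infty\).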
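A sketch of step 5, again with random stand-ins: multiplying the masked, normalized weights by \(V\) gives one output row per token, and the row for “climb” (index 1) is just a weighted mix of the value vectors of “cats” and “climb”.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
chunk_size, head_size = 4, 8
scores = torch.randn(chunk_size, chunk_size)  # stand-in for the scaled Q @ K.T
V = torch.randn(chunk_size, head_size)        # stand-in for X @ W^V

causal_mask = torch.tril(torch.ones(chunk_size, chunk_size))
weights = F.softmax(scores.masked_fill(causal_mask == 0, float("-inf")), dim=-1)
output = weights @ V                          # (chunk_size, head_size)

# Row 1 ("climb") only mixes the values of tokens 0 ("cats") and 1 ("climb")
manual_climb = weights[1, 0] * V[0] + weights[1, 1] * V[1]
print(output.shape)  # torch.Size([4, 8])
```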
Code snippet for Single Head Self-Attention
Below is a code snippet that demonstrates a single self-attention head, implementing the exact steps discussed above.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SingleSelfAttentionHead(nn.Module):
    """One head of causal (masked) self-attention."""

    def __init__(self, chunk_size, token_emb_dim, head_size):
        super().__init__()
        self.head_size = head_size
        # learned projections W^K, W^Q, W^V (no bias)
        self.key = nn.Linear(token_emb_dim, head_size, bias=False)
        self.query = nn.Linear(token_emb_dim, head_size, bias=False)
        self.value = nn.Linear(token_emb_dim, head_size, bias=False)
        # lower-triangular mask: 1 = current or past token, 0 = future token
        self.register_buffer("causal_mask", torch.tril(torch.ones(chunk_size, chunk_size)))

    def forward(self, x):
        batch_size, chunk_size, token_emb_dim = x.shape
        key_vector = self.key(x)      # (batch_size, chunk_size, head_size)
        query_vector = self.query(x)  # (batch_size, chunk_size, head_size)
        # scaled dot-product scores: (batch_size, chunk_size, chunk_size)
        attention_weights = query_vector @ key_vector.transpose(-2, -1)
        attention_weights = attention_weights * self.head_size ** -0.5
        # block future tokens; slicing lets the head accept sequences shorter than chunk_size
        mask = self.causal_mask[:chunk_size, :chunk_size]
        attention_weights = attention_weights.masked_fill(mask == 0, float("-inf"))
        attention_weights = F.softmax(attention_weights, dim=-1)  # each row sums to 1
        value_vector = self.value(x)  # (batch_size, chunk_size, head_size)
        # weighted sum of values: (batch_size, chunk_size, head_size)
        contextualised_embeddings = attention_weights @ value_vector
        return contextualised_embeddings
```
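As a cross-check, a sketch assuming PyTorch 2.0 or later: `torch.nn.functional.scaled_dot_product_attention` with `is_causal=True` applies the same scaling by \(1/\sqrt{\text{head size}}\) (its default), causal mask, and softmax, so it should match the manual steps on the same query, key, and value tensors. The manual computation is repeated inline so the snippet runs on its own.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, chunk_size, head_size = 2, 4, 8
q = torch.randn(batch_size, chunk_size, head_size)
k = torch.randn(batch_size, chunk_size, head_size)
v = torch.randn(batch_size, chunk_size, head_size)

# Manual version: the same steps as SingleSelfAttentionHead.forward
scores = q @ k.transpose(-2, -1) * head_size ** -0.5
causal_mask = torch.tril(torch.ones(chunk_size, chunk_size))
scores = scores.masked_fill(causal_mask == 0, float("-inf"))
manual = F.softmax(scores, dim=-1) @ v

# Built-in version (PyTorch >= 2.0); the default scale is 1 / sqrt(q.size(-1))
builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, builtin, atol=1e-5))
```

The built-in function can dispatch to fused, memory-efficient kernels, while the manual version is easier to read and modify.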
Final Thoughts
In this blog, I aimed to provide a detailed explanation of the single-head self-attention mechanism used in transformer models. In the original Transformer paper, each transformer block contains multiple self-attention heads (multi-head attention). However, this blog focused specifically on introducing and explaining the self-attention mechanism in depth.
In a future blog, I plan to cover the remaining components of the transformer architecture. Meanwhile, if you’re interested, I already have a code implementation that trains a transformer from scratch. You can find it here.