A few weeks ago I implemented the Self-Attention mechanism from scratch using PyTorch, and this post is a sequel to that one.
If you haven't already, read that article first: Building Self-Attention from scratch.
In this post we will explore the steps involved in building Multi-Head Attention, how it differs from Self-Attention, and why it's needed.
What is Multi-Head Attention
In simple words, Multi-Head Attention is an extension of Self-Attention. The main idea is to run Self-Attention several times in parallel on the same input sequence, so the model can capture different, subtle relationships between tokens.
You might argue this is duplication and unnecessary computational overhead. Not really; let me explain with an analogy.
The foundational economics book The Wealth of Nations has a section on the Division of Labor. The concept is simple: if we divide a job into small sub-tasks and assign a worker to each one, productivity improves significantly.
When a person focuses on a single sub-task, he eventually becomes a specialist, because he doesn't need to worry about the other steps in the process. In fact, this is how modern factories operate today. The same idea applies to Multi-Head Attention.
Instead of using one Attention head for everything, why not use multiple heads so that the model learns more efficiently?
Let's say we have two heads: one head might focus on grammar and the other on sentence structure, so each can learn its own patterns independently (remember: division of labor). In practice, what each head specializes in is learned from the data rather than assigned by us; the grammar/structure split is just an illustration.
Think of a garment factory: one person is assigned to cutting, the next to sorting, then to sewing, and finally to finishing and packaging.
The person doing the cutting doesn't need to worry about packaging, so he can sharpen his cutting skills and achieve precise cuts in a fraction of the time.
Prepare the ground
For simplicity, I'll use the same inputs as before. The shape of this input is 9 x 3 —> 9 tokens, and each token has 3 dimensions.
import torch
import torch.nn as nn
inp = torch.tensor([[0.43, 0.15, 0.89], # Boy
[0.55, 0.87, 0.66], # is
[0.57, 0.85, 0.64], # crying
[0.22, 0.58, 0.33], # because
[0.77, 0.25, 0.10], # he
[0.05, 0.80, 0.55], # wants
[0.02, 0.30, 0.47], # an
[0.47, 0.67, 0.64], # ice
[0.77, 0.33, 0.70]]) # cream
Since we work with batches, let's stack the same input twice to create a batch. The shape of this batch is 2 x 9 x 3 —> 2 sequences in the batch, 9 tokens per sequence, and 3 dimensions per token.
inp_batch = torch.stack([inp,inp],dim=0)
print(inp_batch)
print(inp_batch.shape)
# -----------result--------------
tensor([[[0.4300, 0.1500, 0.8900],
[0.5500, 0.8700, 0.6600],
[0.5700, 0.8500, 0.6400],
[0.2200, 0.5800, 0.3300],
[0.7700, 0.2500, 0.1000],
[0.0500, 0.8000, 0.5500],
[0.0200, 0.3000, 0.4700],
[0.4700, 0.6700, 0.6400],
[0.7700, 0.3300, 0.7000]],
[[0.4300, 0.1500, 0.8900],
[0.5500, 0.8700, 0.6600],
[0.5700, 0.8500, 0.6400],
[0.2200, 0.5800, 0.3300],
[0.7700, 0.2500, 0.1000],
[0.0500, 0.8000, 0.5500],
[0.0200, 0.3000, 0.4700],
[0.4700, 0.6700, 0.6400],
[0.7700, 0.3300, 0.7000]]])
torch.Size([2, 9, 3])
Our inputs are now ready, so let's move on to the next section, where we will build up to the Multi-Head Attention class.
Simple Attention
To keep things simple, here is an Attention block without Masking and Dropout; we will improve it as we go.
class SimpleAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # context_length is not used yet; it will matter once we add masking
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        k = self.W_key(x)  # [b, num_tokens, d_out]
        q = self.W_query(x)
        v = self.W_value(x)
        attn_scores = q @ k.transpose(1,2)  # [b, num_tokens, num_tokens]
        attn_weights = torch.softmax(attn_scores / k.shape[-1]**0.5, dim=-1)  # scaled; each row sums to 1
        context_vec = attn_weights @ v  # [b, num_tokens, d_out]
        return context_vec
This is a single-head Attention class. If we initialize it and pass our inputs through it, we get the following results.
torch.manual_seed(123)
b, context_len, d_in = inp_batch.shape
SA = SimpleAttention(d_in=d_in, d_out=2, context_length=context_len)
res = SA(inp_batch)
print(res)
print(res.shape)
#----------------result------------------
tensor([[[-0.5221, -0.0740],
[-0.5193, -0.0762],
[-0.5194, -0.0762],
[-0.5159, -0.0756],
[-0.5180, -0.0750],
[-0.5160, -0.0760],
[-0.5165, -0.0747],
[-0.5193, -0.0757],
[-0.5221, -0.0747]],
[[-0.5221, -0.0740],
[-0.5193, -0.0762],
[-0.5194, -0.0762],
[-0.5159, -0.0756],
[-0.5180, -0.0750],
[-0.5160, -0.0760],
[-0.5165, -0.0747],
[-0.5193, -0.0757],
[-0.5221, -0.0747]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([2, 9, 2])
The output shape is 2 x 9 x 2 —> 2 sequences, 9 tokens each, and 2 dimensions per token. Notice that the output dimension per token is controlled by the d_out variable (keep that in mind; I'll come back to it).
Simple Multi-Head Attention
We can now create multiple heads using the class above.
The idea is simple: we create several Attention modules in a loop and store them in an nn.ModuleList.
In the forward pass, we iterate over the module list, pass our inputs to each module, and then concatenate the outputs along the last dimension using torch.cat.
class MHA_Simple(nn.Module):
    def __init__(self, n_heads, d_in, d_out, context_len, qkv_bias=False):
        super().__init__()
        # one independent SimpleAttention module per head
        self.heads = nn.ModuleList([SimpleAttention(d_in=d_in, d_out=d_out, context_length=context_len, qkv_bias=qkv_bias) for _ in range(n_heads)])
    def forward(self, x):
        # run every head on the same input, then concatenate along the feature dimension
        res = torch.cat([head(x) for head in self.heads], dim=-1)
        return res
If we initialize this class and pass on our inputs, the results will be as below.
torch.manual_seed(123)
b, context_len, d_in = inp_batch.shape
mha_simple = MHA_Simple(n_heads=2, d_in=d_in, d_out=2, context_len=context_len)
res = mha_simple(inp_batch)
print(res)
print(res.shape)
#----------------result------------------
tensor([[[-0.5221, -0.0740, 0.5001, 0.3234],
[-0.5193, -0.0762, 0.4999, 0.3235],
[-0.5194, -0.0762, 0.4999, 0.3233],
[-0.5159, -0.0756, 0.4985, 0.3198],
[-0.5180, -0.0750, 0.4988, 0.3178],
[-0.5160, -0.0760, 0.4988, 0.3219],
[-0.5165, -0.0747, 0.4984, 0.3201],
[-0.5193, -0.0757, 0.4997, 0.3228],
[-0.5221, -0.0747, 0.5003, 0.3228]],
[[-0.5221, -0.0740, 0.5001, 0.3234],
[-0.5193, -0.0762, 0.4999, 0.3235],
[-0.5194, -0.0762, 0.4999, 0.3233],
[-0.5159, -0.0756, 0.4985, 0.3198],
[-0.5180, -0.0750, 0.4988, 0.3178],
[-0.5160, -0.0760, 0.4988, 0.3219],
[-0.5165, -0.0747, 0.4984, 0.3201],
[-0.5193, -0.0757, 0.4997, 0.3228],
[-0.5221, -0.0747, 0.5003, 0.3228]]], grad_fn=<CatBackward0>)
torch.Size([2, 9, 4])
Let's unpack:
- First, the output shape has changed to 2 x 9 x 4, i.e. 2 sequences, 9 tokens, and 4 dimensions per token
- But d_out = 2, so why do we get a 4-dimensional output? Each head produces 2 output dimensions per token, and since we have two Attention heads, their outputs are concatenated into 2 + 2 = 4 dimensions per token
- Similarly, if we initialize this module with three heads, the output shape will be 2 x 9 x 6 —> (2 + 2 + 2)
In case you didn't notice, the first two output columns match our previous single-head Attention outputs. That's because we reset the random seed to 123 before both runs, so the first head is initialized with exactly the same weights as the single-head module.
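Both claims above (the output width is n_heads * d_out, and the first head reproduces the single-head result under the same seed) can be checked with a small sketch. It re-declares compact versions of the two classes so it runs on its own, and it uses a random stand-in tensor instead of inp_batch; the three-head configuration is an assumption chosen to show the 2 + 2 + 2 case.

```python
import torch
import torch.nn as nn

class SimpleAttention(nn.Module):
    # compact copy of the single-head block above (context_length kept only for API parity)
    def __init__(self, d_in, d_out, context_length, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        q, k, v = self.W_query(x), self.W_key(x), self.W_value(x)
        w = torch.softmax(q @ k.transpose(1, 2) / k.shape[-1] ** 0.5, dim=-1)
        return w @ v

class MHA_Simple(nn.Module):
    def __init__(self, n_heads, d_in, d_out, context_len, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [SimpleAttention(d_in, d_out, context_len, qkv_bias) for _ in range(n_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

x = torch.rand(2, 9, 3)  # stand-in for inp_batch: 2 sequences, 9 tokens, 3 dims

torch.manual_seed(123)
single = SimpleAttention(d_in=3, d_out=2, context_length=9)(x)

torch.manual_seed(123)
multi3 = MHA_Simple(n_heads=3, d_in=3, d_out=2, context_len=9)(x)

print(multi3.shape)                              # torch.Size([2, 9, 6]) -> 3 heads x 2 dims
print(torch.allclose(multi3[..., :2], single))   # True: head 0 got the same initial weights
```

Note that the heads are still evaluated one after another inside the list comprehension, which is exactly the inefficiency discussed next.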
We could already call this Multi-Head Attention, but there is a big drawback. Because the heads are stored in a list, we have to call each head separately, run its forward pass, and only then concatenate the results.
This happens sequentially, one head after another, with each head running its own small matrix multiplications. As a result, this approach is computationally inefficient, especially as the number of heads grows.
So, let’s take this concept forward and create a more efficient Multi-Head Attention class.
Multi-Head Attention from scratch
The core idea is the same as in our simple Multi-Head Attention class, but since we will deal with a lot of data and a lot of computation, we need to make it efficient.
I remember Andrej Karpathy once saying that the core idea of an algorithm is usually simple and most of the code around it is optimization (paraphrased); this is especially true of what's coming next.
Step 1:
The first simplification I'd like to propose is to initialize a single Linear layer (instead of three separate layers) and then split its output into queries, keys and values, as below.
d_in = 3
d_out = 2
qkv = nn.Linear(d_in, d_out*3)
print(qkv.weight.T.shape)
#----------------result------------------
torch.Size([3, 6])
q,k,v = torch.chunk(input=qkv(inp_batch), chunks=3, dim=-1)
print(q.shape)
#----------------result------------------
torch.Size([2, 9, 2])
Let's unpack:
- We have initialized one nn.Linear layer with d_in = 3, which is the embedding dimension of inp_batch
- We need 2 output dimensions each for q, k and v, hence we multiply d_out * 3; this results in 2 * 3 = 6 output dimensions overall (the transposed weight, qkv.weight.T, has shape 3 x 6)
- Then we project our input batch through the layer with qkv(inp_batch)
- Finally, we use torch.chunk to split the projected tensor into three equal parts along the last dimension and assign them to q, k and v
So, the final shape of each variable is 2 x 9 x 2 —> 2 sequences, 9 tokens, and 2 dimensions per token, exactly the shape we would get from three separate Linear layers.
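The fused layer is only a packing trick. The sketch below makes that concrete: if we copy the three separate weight matrices into one Linear(d_in, 3 * d_out) layer, torch.chunk gives back exactly what the separate layers would produce. It assumes bias=False for brevity and a random stand-in input instead of inp_batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_out = 3, 2
x = torch.rand(2, 9, d_in)  # 2 sequences, 9 tokens, 3 dims (stand-in for inp_batch)

# three separate projections, as in SimpleAttention
W_q = nn.Linear(d_in, d_out, bias=False)
W_k = nn.Linear(d_in, d_out, bias=False)
W_v = nn.Linear(d_in, d_out, bias=False)

# one fused projection: its weight is the three [d_out, d_in] weights stacked row-wise
qkv = nn.Linear(d_in, d_out * 3, bias=False)
with torch.no_grad():
    qkv.weight.copy_(torch.cat([W_q.weight, W_k.weight, W_v.weight], dim=0))

q, k, v = torch.chunk(qkv(x), chunks=3, dim=-1)
print(q.shape)                     # torch.Size([2, 9, 2])
print(torch.allclose(q, W_q(x)))   # True
print(torch.allclose(k, W_k(x)))   # True
print(torch.allclose(v, W_v(x)))   # True
```

During training the fused layer simply learns all three matrices at once; one larger matrix multiplication is usually cheaper than three small ones.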
Step 2:
The next simplification is the concept of a head dimension.
In our simple Multi-Head Attention class, we initialized an entire Attention module per head; as a result, with d_out = 2 and two heads, the output was 4-dimensional.
With this approach, we have to initialize the q, k, v layers for each head separately and then compute their outputs separately, which becomes a bottleneck.
But we can take an alternative approach: what if we simply divide the available d_out dimensions among the heads?
Let's say d_out = 2 and we want two Attention heads. We can compute d_out / num_heads —> 2 / 2 = 1, which gives each head an equal number of dimensions.
But there is a catch: d_out must be evenly divisible by the number of heads, because we cannot assign 0.5 dimensions to one head and 1.5 dimensions to the other.
A few examples:
- d_out = 2, n_heads = 2 —> 2/2 = 1 (1 dimension per head)
- d_out = 4, n_heads = 2 —> 4/2 = 2 (2 dimensions per head)
- d_out = 6, n_heads = 3 —> 6/3 = 2 (2 dimensions per head)
- d_out = 6, n_heads = 2 —> 6/2 = 3 (3 dimensions per head)
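The divisibility rule can be written as a tiny guard. `head_dim_for` is a hypothetical helper used only for illustration; the class later in this post does the same check with an assert.

```python
def head_dim_for(d_out: int, n_heads: int) -> int:
    """Hypothetical helper: split d_out evenly across heads, refusing uneven splits."""
    if d_out % n_heads != 0:
        raise ValueError(f"d_out={d_out} is not divisible by n_heads={n_heads}")
    return d_out // n_heads

for d_out, n_heads in [(2, 2), (4, 2), (6, 3), (6, 2)]:
    print(d_out, n_heads, "->", head_dim_for(d_out, n_heads))   # 1, 2, 2, 3

try:
    head_dim_for(3, 2)  # would need 1.5 dimensions per head
except ValueError as err:
    print("error:", err)
```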
Step 3:
The third improvement (notice I'm not saying simplification) is to reshape our tensors so the inputs are laid out for efficient computation.
For reshaping we will use the tensor.view function; let's implement it step by step.
For simplicity, I won't use separate Query, Key and Value tensors here, as I'm going to unpack the entire model into small, simple steps.
I have initialized a random input and assigned it to a variable called inp, as below.
torch.manual_seed(123)
inp = torch.rand(2,5,2) # batch of 2 sequences, 5 tokens each, 2 dimensions per token
print(inp.shape)
#----------------result------------------
torch.Size([2, 5, 2])
batch, tokens, d_in = inp.shape
d_out = 2 # define output dimension
n_heads = 2 # let's say we want two heads
head_dim = d_out // n_heads # floor division: a division whose result is rounded down to an integer
inp_transformed = inp.view(batch, tokens, n_heads, head_dim) # batch x tokens x num_heads x head_dim
print(inp_transformed.shape)
#----------------result------------------
torch.Size([2, 5, 2, 1])
Let's unpack:
- Our input tensor from the above example has a shape of 2 x 5 x 2 —> 2 sequences, 5 tokens, and 2 dimensions per token
- We want to reshape this tensor to introduce heads, so we split the 2 dimensions of each token into n_heads chunks of head_dim = 1 dimension each
- With tensor.view, we keep the batch and token dimensions as they are, insert the number of heads, and split the last dimension using head_dim
- So, the final shape after this operation is 2 x 5 x 2 x 1 —> 2 sequences, 5 tokens, 2 heads, and 1 dimension per head
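What view actually does to the features is worth a closer look. The sketch below shows that head h simply owns a contiguous slice of each token's features, namely columns h * head_dim up to (h + 1) * head_dim. It assumes d_out = 4 (instead of the article's 2) so that each head gets 2 dimensions and the split is easier to see.

```python
import torch

torch.manual_seed(123)
batch, tokens, d_out, n_heads = 2, 5, 4, 2   # d_out = 4 is an assumption for visibility
head_dim = d_out // n_heads
x = torch.rand(batch, tokens, d_out)

split = x.view(batch, tokens, n_heads, head_dim)   # [2, 5, 2, 2]

# head h owns features h*head_dim ... (h+1)*head_dim - 1 of every token
for h in range(n_heads):
    same = torch.equal(split[:, :, h, :], x[..., h * head_dim:(h + 1) * head_dim])
    print(f"head {h} == slice {h * head_dim}:{(h + 1) * head_dim} ->", same)   # True
```

No data is copied here: view only reinterprets the same memory with a new shape.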
Step 4:
In this step we need one more reshaping operation, but this time we will use the tensor.transpose function, as below.
The main idea here is to transform the current shape from batch x tokens x num_heads x head_dim —> batch x num_heads x tokens x head_dim.
print(inp_transformed.shape) # batch x tokens x num_heads x head_dim
#----------------result------------------
torch.Size([2, 5, 2, 1])
final_inp = inp_transformed.transpose(1,2) # batch x num_heads x tokens x head_dim
print(final_inp.shape)
#----------------result------------------
torch.Size([2, 2, 5, 1])
Let's unpack:
- Notice that we applied the transpose only to the second and third dimensions (Python indexing starts at zero, hence 1 and 2), i.e. only to tokens and number of heads
- This operation keeps batch and head_dim as they are, but swaps tokens and num_heads
- So the final output tensor has a shape of 2 x 2 x 5 x 1 —> 2 sequences, 2 heads, 5 tokens, and 1 dimension per head
Let's select only the first sequence in the batch and look at the results.
print(final_inp[0])
print(final_inp[0].shape) # num_heads x tokens x head_dim
#----------------result------------------
tensor([[[0.2961],
[0.2517],
[0.0740],
[0.1366],
[0.1841]],
[[0.5166],
[0.6886],
[0.8665],
[0.1025],
[0.7264]]])
torch.Size([2, 5, 1])
Step 5:
In this step we start computing Attention scores; the process is the same as in our Simple Attention class, including the transpose operation.
The main point to note is that we now transpose dimensions 2 and 3 (zero-indexed), i.e. tokens and head_dim, so each head computes its own 5 x 5 matrix of scores.
attn_scores = final_inp @ final_inp.transpose(2,3)
print(attn_scores.shape)
#----------------result------------------
torch.Size([2, 2, 5, 5])
Step 6:
In this step we convert the computed attention scores into Attention weights using torch.softmax (further details are in my previous article).
The output shape stays the same, but the values are now probability distributions, i.e. each row sums to 1.
attn_weights = torch.softmax(attn_scores / inp_transformed.shape[-1] ** 0.5 , dim=-1)
print(attn_weights.shape)
#----------------result------------------
torch.Size([2, 2, 5, 5])
Step 7:
Next, we compute the context vector by multiplying the Attention weights by our inputs (which play the role of the values here), as below. The output shape is 2 x 2 x 5 x 1 —> 2 sequences, 2 heads, 5 tokens, and 1 dimension per head.
context_vec = attn_weights @ final_inp
print(context_vec.shape)
#----------------result------------------
torch.Size([2, 2, 5, 1])
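Steps 5 to 7 are where the reshaping pays off. Because heads now sit next to the batch dimension, a single @ broadcasts over both, which gives the same result as looping over heads one by one the way MHA_Simple did. The sketch checks this. Like the walkthrough, it uses the inputs themselves as queries, keys and values; d_out = 4 and head_dim = 2 are assumptions.

```python
import torch

torch.manual_seed(123)
x = torch.rand(2, 5, 4)            # 2 sequences, 5 tokens, d_out = 4
n_heads, head_dim = 2, 2

# Steps 3-4: split into heads and move heads next to the batch dimension
heads = x.view(2, 5, n_heads, head_dim).transpose(1, 2)       # [2, 2, 5, 2]

# Steps 5-7 in one shot: matmul broadcasts over (batch, head)
scores = heads @ heads.transpose(2, 3)                          # [2, 2, 5, 5]
weights = torch.softmax(scores / head_dim ** 0.5, dim=-1)
batched = weights @ heads                                       # [2, 2, 5, 2]

# The same computation, one head at a time (what MHA_Simple did sequentially)
looped = []
for h in range(n_heads):
    xh = x[..., h * head_dim:(h + 1) * head_dim]                # [2, 5, 2]
    w = torch.softmax(xh @ xh.transpose(1, 2) / head_dim ** 0.5, dim=-1)
    looped.append(w @ xh)
looped = torch.stack(looped, dim=1)                             # [2, 2, 5, 2]

print(torch.allclose(batched, looped))                          # True
print(torch.allclose(weights.sum(dim=-1), torch.ones(2, 2, 5))) # True: each row sums to 1
```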
Step 8:
So far we have computed all the Attention-related steps, but as you can see, the output shape no longer matches the input shape, so we need to convert the results back.
As I mentioned before, the Attention mechanism produces one enhanced, context-aware vector per token; the heads are only an internal split. So the output should return to the [batch, tokens, d_out] layout, which here matches our input shape because d_out equals d_in.
context_vec = context_vec.transpose(1,2) # batch x tokens x heads x head_dim
print(context_vec.shape)
#----------------result------------------
torch.Size([2, 5, 2, 1])
- With the above transpose operation, we move tokens back next to the batch dimension
- The output shape is 2 x 5 x 2 x 1 —> 2 sequences, 5 tokens, 2 heads, and 1 dimension per head
context_vec = context_vec.contiguous().view(batch, tokens, d_out) # merge heads back: n_heads * head_dim = d_out
print(context_vec.shape)
#----------------result------------------
torch.Size([2, 5, 2])
- With this code, we merge the heads back together so that the output matches the input layout
- transpose() only swaps strides and returns a view that is no longer contiguous in memory, while view() requires a compatible memory layout; contiguous() copies the data into a contiguous block so that view() can work
- Finally, the output shape is 2 x 5 x 2 —> 2 sequences, 5 tokens, and 2 dimensions per token, the same as our original input shape
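The role of contiguous() is easy to see when view() fails without it. The sketch below assumes head_dim = 2, which is larger than the walkthrough's toy value, so that the transposed memory layout cannot be merged by view(). It also shows tensor.reshape as an alternative that copies only when it has to.

```python
import torch

t = torch.rand(2, 2, 5, 2)          # [batch, heads, tokens, head_dim], contiguous
tt = t.transpose(1, 2)              # [batch, tokens, heads, head_dim], a strided view
print(t.is_contiguous(), tt.is_contiguous())    # True False

try:
    tt.view(2, 5, 4)                # heads * head_dim = 4
except RuntimeError as err:
    print("view failed:", type(err).__name__)    # view failed: RuntimeError

merged = tt.contiguous().view(2, 5, 4)   # copy into contiguous memory, then view
alt = tt.reshape(2, 5, 4)                # reshape copies only when necessary
print(torch.equal(merged, alt))          # True
```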
In simple words, we reshaped our original tensor, computed the attention weights and then the context vector, and finally transformed the output back to the original shape.
This concludes the steps involved in computing Multi-Head Attention efficiently. Although we added several steps, please note that they exist only for efficient computation; the core idea remains the same.
Now we need to turn these steps into a reusable class. The version below adds a divisibility check and uses proper Query, Key and Value projections. It also adds an output projection, out_proj, a Linear layer that mixes the concatenated head outputs; this is a standard part of Multi-Head Attention.
The code might look intimidating at first, but if you look closely, it's essentially the same steps as above.
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.qkv = nn.Linear(d_in, d_out*3, bias=qkv_bias) # fused q, k, v projection
        self.out_proj = nn.Linear(d_out, d_out) # mixes the concatenated head outputs
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys, queries, values = torch.chunk(self.qkv(x), 3, dim=-1) # each [batch, num_tokens, d_out]
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) # [batch, num_tokens, d_out] --> [batch, num_tokens, num_heads, head_dim]
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.transpose(1,2) # [batch, num_tokens, num_heads, head_dim] --> [batch, num_heads, num_tokens, head_dim]
        queries = queries.transpose(1,2)
        values = values.transpose(1,2)
        attn_scores = queries @ keys.transpose(2,3) # keys.transpose --> [batch, num_heads, head_dim, num_tokens]
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = (attn_weights @ values).transpose(1,2) # [batch, num_tokens, num_heads, head_dim]
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out) # [batch, num_tokens, d_out]
        context_vec = self.out_proj(context_vec)
        return context_vec
Now if we initialize this class and pass our input batch through it, the results are as below.
torch.manual_seed(123)
b, tok, d_in = inp_batch.shape
d_out = 2
num_heads = 2
MHA = MultiHeadAttention(d_in=d_in, d_out=d_out, context_length=tok, num_heads=num_heads)
res = MHA(inp_batch)
print(res)
print(res.shape)
#----------------result------------------
tensor([[[0.2644, 0.4137],
[0.2641, 0.4117],
[0.2641, 0.4118],
[0.2630, 0.4134],
[0.2637, 0.4139],
[0.2630, 0.4128],
[0.2629, 0.4144],
[0.2639, 0.4124],
[0.2647, 0.4129]],
[[0.2644, 0.4137],
[0.2641, 0.4117],
[0.2641, 0.4118],
[0.2630, 0.4134],
[0.2637, 0.4139],
[0.2630, 0.4128],
[0.2629, 0.4144],
[0.2639, 0.4124],
[0.2647, 0.4129]]], grad_fn=<ViewBackward0>)
torch.Size([2, 9, 2])
Why do we need normalization before softmax activation
In the above class, we used torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) to scale (normalize) the attention scores before passing them through softmax.
But why do we need this, and what is the rationale behind it?
Let's first understand how softmax behaves with inputs of different magnitudes. Here is a tensor whose values differ only slightly.
torch.set_printoptions(sci_mode=False)
logits = torch.tensor([20.0, 21.5, 22.0])
print(logits)
#----------------result------------------
tensor([20.0000, 21.5000, 22.0000])
If we pass it through the softmax function, the result is as below:
torch.softmax(logits, dim=-1)
#----------------result------------------
tensor([0.0777, 0.3482, 0.5741])
A few things to note:
- Although the difference between 21.5 and 22.0 is just 0.5 (or 2.33%), after softmax the difference grows to 22.59 percentage points (57.41% - 34.82%)
- Similarly, the difference between 20.0 and 21.5 is only 1.5, yet 20.0 receives just 7.77%
In summary, softmax depends on the absolute differences between its inputs, not on their relative size: a gap of a point or two already produces a lopsided distribution, and as the gaps grow, softmax pushes almost all of the probability onto the largest value.
Now, if we normalize our inputs (here, by dividing them by their sum) before passing them through softmax, the results are as follows, and you can already see the difference in the probability distribution.
torch.softmax(logits / logits.sum(), dim=-1)
#----------------result------------------
tensor([0.3272, 0.3351, 0.3377])
The largest number still gets the highest probability, but the gap between it and the next number is now very small.
Now, if we increase the largest number further to create a bigger gap, softmax still responds smoothly.
logits_x = torch.tensor([20.0, 21.5, 26.0])
torch.softmax(logits_x / logits_x.sum(), dim=-1)
#----------------result------------------
tensor([0.3210, 0.3282, 0.3508])
Had we not normalized the above, the results would have been very different, as below; you can see how disproportionate the outcome is.
logits_x = torch.tensor([20.0, 21.5, 26.0])
torch.softmax(logits_x, dim=-1)
#----------------result------------------
tensor([0.0024, 0.0110, 0.9866])
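In our class we used a slightly different technique: instead of dividing by the sum, we divide the scores by the square root of the key dimension (keys.shape[-1]**0.5, i.e. the square root of head_dim). The dot product of two vectors whose head_dim components are independent with zero mean and unit variance has a variance of about head_dim, so dividing by its square root brings the variance back to about 1. This keeps softmax out of its saturated, almost one-hot regime, where gradients become tiny; the core idea is the same.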
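The variance argument can be checked numerically. The sketch below assumes random query/key vectors with independent N(0, 1) components and head_dim = 64 (a typical size, not the toy value used in this post). It compares the spread of raw and scaled dot products and how peaked softmax is over one row of each.

```python
import torch

torch.manual_seed(0)
n, d = 10_000, 64                   # 10k query/key pairs; head_dim = 64 is an assumption
q = torch.randn(n, d)               # components ~ N(0, 1)
k = torch.randn(n, d)

scores = (q * k).sum(dim=-1)        # one dot product per pair
scaled = scores / d ** 0.5

print(f"var(raw)    = {scores.var().item():.1f}")   # close to d = 64
print(f"var(scaled) = {scaled.var().item():.2f}")   # close to 1

# Softmax over one "row" of 9 scores: scaling can only make it flatter
row = scores[:9]
p_raw = torch.softmax(row, dim=-1).max().item()
p_scaled = torch.softmax(row / d ** 0.5, dim=-1).max().item()
print(p_raw, p_scaled)              # the raw row is the more peaked of the two
```

Dividing by a constant larger than 1 acts like raising the softmax temperature, so the largest probability can only go down, never up.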
The concept of Masked Multi-Head Attention
So far we have created a simplified version of Multi-Head attention, but it’s not yet fully ready. In this section I’ll cover some other core ideas to build a complete Masked Multi-Head Attention class that can be used for training.
As a first step, let me introduce a concept called Dropout.
Dropout
Dropout is a regularization technique used to reduce overfitting.
In simple words, we randomly zero out elements of a tensor with a user-defined probability (a hyperparameter), so that at each training step the model has to learn from whichever elements remain instead of relying on all of them.
Here is an interesting analogy: a group of 10 students has been selected for a project, and the group has a mix of highly talented and average students.
The class teacher is fully aware of this mix, and he also knows that if he assigns the project to this group, the talented students will most likely do most of the work while the others just follow them.
So he decides to randomly remove and add a few students every day. With this rotation, the group dynamics change daily, so there is hardly any room to depend on the talented students.
Since there is no chance of such dependency, every student in the group has to learn and come up with solutions on their own rather than just following the top few.
The same analogy applies to model training: we don't want only a few weights to do all the work to get the expected outcome; we want the model to make use of all of its weights, and that is where Dropout plays a key role.
Enough theory, here is an example.
torch.manual_seed(123)
weight_matrix = torch.rand(5,3)
weight_matrix
#----------------result------------------
tensor([[0.2961, 0.5166, 0.2517],
[0.6886, 0.0740, 0.8665],
[0.1366, 0.1025, 0.1841],
[0.7264, 0.3153, 0.6871],
[0.0756, 0.1966, 0.3164]])
weight_matrix_drop = torch.nn.functional.dropout(weight_matrix, p=0.5)
weight_matrix_drop
#----------------result------------------
tensor([[0.5922, 0.0000, 0.5033],
[0.0000, 0.0000, 1.7330],
[0.0000, 0.2050, 0.0000],
[0.0000, 0.6305, 0.0000],
[0.1513, 0.0000, 0.0000]])
Let's unpack:
- We have initialized a random 5 x 3 matrix and assigned it to a variable called weight_matrix
- Then we passed this matrix through PyTorch's dropout function with a probability of 0.5, or 50%
- As a result, some of the matrix elements have been randomly set to zero
- And if you look closely, the non-zero elements after dropout are not the same as the original matrix elements
- The reason is a scaling factor that preserves the expected value of the inputs: the surviving elements are divided by 1 - p, i.e. —> 0.2961 / (1 - 0.5) = 0.5922
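One detail the example above hides is that dropout only runs during training. The functional call torch.nn.functional.dropout defaults to training=True, so it always drops. The nn.Dropout module, which is what the final class uses, follows the module's train()/eval() mode. The sketch checks both the 1 / (1 - p) scaling and the evaluation-mode no-op.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
x = torch.rand(5, 3)
drop = nn.Dropout(p=0.5)

drop.train()                         # training mode: dropout is active
y = drop(x)
kept = y != 0
print(torch.allclose(y[kept], x[kept] / (1 - 0.5)))   # True: survivors scaled by 1/(1-p)

drop.eval()                          # evaluation mode: dropout does nothing
print(torch.equal(drop(x), x))       # True
```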
Masking
Masking is a method to hide future tokens in a context window, so that the model cannot access them while processing the current token.
Here is an example: given only the first word, Boy, the model needs to predict the next word, is; given Boy and is as context, it needs to predict crying.
Likewise for the other words.

But as you can see, in our Attention class above we computed Attention weights for all the words in the context window without any blocker, so the model can access contextual information from future tokens as well.
Let's take this topic further with code. To keep things simple, I'll consider only the first sequence of the batch, inp_batch[0], and use the raw inputs directly as queries and keys (no projection layers).
Step 1:
With the step below, we compute the attention scores using matrix multiplication. The result is a 9 x 9 matrix —> one row per token (the query) and one column per token (the key), as shown in the picture.
attn_scores = inp_batch[0] @ inp_batch[0].T
print(attn_scores)
#----------------result------------------
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310, 0.4719, 0.8722, 1.0036],
[0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865, 0.5822, 1.2638, 1.1726],
[0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605, 0.5672, 1.2470, 1.1674],
[0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565, 0.3335, 0.7032, 0.5918],
[0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935, 0.1374, 0.5934, 0.7454],
[0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450, 0.4995, 0.9115, 0.6875],
[0.4719, 0.5822, 0.5672, 0.3335, 0.1374, 0.4995, 0.3113, 0.5112, 0.4434],
[0.8722, 1.2638, 1.2470, 0.7032, 0.5934, 0.9115, 0.5112, 1.0794, 1.0310],
[1.0036, 1.1726, 1.1674, 0.5918, 0.7454, 0.6875, 0.4434, 1.0310, 1.1918]])
Step 2:
As a next step we need to create a mask so that we can hide future tokens.
context_len = inp_batch.shape[1]
print(context_len)
#----------------result------------------
9
simple_mask = torch.triu(torch.ones(context_len, context_len),diagonal=1)
print(simple_mask)
#----------------result------------------
tensor([[0., 1., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 0., 1., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 0., 1., 1., 1., 1.],
[0., 0., 0., 0., 0., 0., 1., 1., 1.],
[0., 0., 0., 0., 0., 0., 0., 1., 1.],
[0., 0., 0., 0., 0., 0., 0., 0., 1.],
[0., 0., 0., 0., 0., 0., 0., 0., 0.]])
First, we created a 9 x 9 matrix of ones with torch.ones, using the context length, and then used torch.triu to extract its upper triangle.
torch.triu keeps the upper triangle and replaces everything below it with zeros, so the shape of the matrix does not change. With diagonal=1, the main diagonal is zeroed as well, so each token can still attend to itself; only the positions above the diagonal (the future tokens) are marked with 1.
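Another way to read this mask is "everything that is not in the lower triangle, diagonal included". The sketch checks that the two constructions agree and that row i masks exactly the tokens after position i. The torch.tril comparison is just an illustrative alternative, not something the class below depends on.

```python
import torch

context_len = 9
mask = torch.triu(torch.ones(context_len, context_len), diagonal=1).bool()

# Equivalent construction: positions outside the lower triangle (diagonal included)
alt = torch.tril(torch.ones(context_len, context_len)) == 0
print(torch.equal(mask, alt))          # True

# Row i masks exactly the tokens after position i
print(mask.sum(dim=-1).tolist())       # [8, 7, 6, 5, 4, 3, 2, 1, 0]
```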
Step 3:
Since the matrix contains only zeros and ones, we can easily convert it into a boolean matrix and use it as a mask, like below.
attn_weights = attn_scores.masked_fill(simple_mask.bool(), -torch.inf)
print(attn_weights)
#----------------result------------------
tensor([[0.9995, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
[0.9544, 1.4950, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
[0.9422, 1.4754, 1.4570, -inf, -inf, -inf, -inf, -inf, -inf],
[0.4753, 0.8434, 0.8296, 0.4937, -inf, -inf, -inf, -inf, -inf],
[0.4576, 0.7070, 0.7154, 0.3474, 0.6654, -inf, -inf, -inf, -inf],
[0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450, -inf, -inf, -inf],
[0.4719, 0.5822, 0.5672, 0.3335, 0.1374, 0.4995, 0.3113, -inf, -inf],
[0.8722, 1.2638, 1.2470, 0.7032, 0.5934, 0.9115, 0.5112, 1.0794, -inf],
[1.0036, 1.1726, 1.1674, 0.5918, 0.7454, 0.6875, 0.4434, 1.0310, 1.1918]])
We used PyTorch's masked_fill to fill the upper triangle with negative infinity.
When I first learned about masking, I didn't understand why -inf, so I experimented by replacing -inf with 0.0 and then applying softmax.
To me, 0.0 meant zero probability, but after softmax those positions were assigned a non-zero probability.
The reason is that softmax exponentiates every score: e^0 = 1, while e^(-inf) = 0. So only -inf guarantees that the masked positions receive exactly zero probability. Applying softmax to the -inf-masked scores confirms this: every future-token position is exactly 0 and each row sums to 1.
attn_weights = torch.softmax(attn_weights, dim=-1)
print(attn_weights)
#----------------result------------------
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.3680, 0.6320, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.2284, 0.3893, 0.3822, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.2046, 0.2956, 0.2915, 0.2084, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.1753, 0.2250, 0.2269, 0.1570, 0.2158, 0.0000, 0.0000, 0.0000, 0.0000],
[0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896, 0.0000, 0.0000, 0.0000],
[0.1496, 0.1671, 0.1646, 0.1303, 0.1071, 0.1538, 0.1274, 0.0000, 0.0000],
[0.1176, 0.1740, 0.1711, 0.0993, 0.0890, 0.1223, 0.0820, 0.1447, 0.0000],
[0.1200, 0.1421, 0.1414, 0.0795, 0.0927, 0.0875, 0.0685, 0.1234, 0.1449]])
Mask without negative infinity:
Now let's see what softmax does with masked entries that hold 0.0 instead of -inf. The weights above are a handy example, because their masked entries are exactly 0.0. If we pass them through softmax (a second time, purely for demonstration), softmax treats each 0.0 as an ordinary number and assigns it a probability.
demo = torch.softmax(attn_weights, dim=-1)
print(demo)
#----------------result------------------
tensor([[0.2536, 0.0933, 0.0933, 0.0933, 0.0933, 0.0933, 0.0933, 0.0933, 0.0933],
[0.1399, 0.1822, 0.0968, 0.0968, 0.0968, 0.0968, 0.0968, 0.0968, 0.0968],
[0.1232, 0.1447, 0.1437, 0.0981, 0.0981, 0.0981, 0.0981, 0.0981, 0.0981],
[0.1210, 0.1325, 0.1320, 0.1215, 0.0986, 0.0986, 0.0986, 0.0986, 0.0986],
[0.1179, 0.1239, 0.1241, 0.1157, 0.1227, 0.0989, 0.0989, 0.0989, 0.0989],
[0.1138, 0.1232, 0.1225, 0.1142, 0.1093, 0.1197, 0.0991, 0.0991, 0.0991],
[0.1153, 0.1173, 0.1170, 0.1131, 0.1105, 0.1157, 0.1127, 0.0992, 0.0992],
[0.1117, 0.1182, 0.1178, 0.1097, 0.1085, 0.1122, 0.1078, 0.1148, 0.0993],
[0.1121, 0.1146, 0.1145, 0.1076, 0.1090, 0.1085, 0.1064, 0.1124, 0.1149]])
Every future-token position now receives roughly 0.09 to 0.10 of the probability, which is exactly the leak the mask was supposed to prevent.
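The masking steps can be folded into one small function, and the most convincing check of a causal mask is behavioural: changing future tokens must not change the outputs of earlier tokens. `causal_attention_weights` is a hypothetical helper written for this sketch. Like the walkthrough, it feeds the raw inputs in as queries, keys and values, and it uses a random stand-in for inp_batch[0].

```python
import torch

def causal_attention_weights(q, k):
    """Hypothetical helper: scaled, causally masked softmax for [..., tokens, dim] inputs."""
    n = q.shape[-2]
    scores = q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5
    future = torch.triu(torch.ones(n, n, device=q.device), diagonal=1).bool()
    scores = scores.masked_fill(future, float("-inf"))   # e^(-inf) = 0 after softmax
    return torch.softmax(scores, dim=-1)

torch.manual_seed(123)
x = torch.rand(9, 3)                      # 9 tokens, 3 dims (stand-in for inp_batch[0])
w = causal_attention_weights(x, x)
out = w @ x

# 1) each row is a probability distribution with nothing above the diagonal
print(torch.allclose(w.sum(-1), torch.ones(9)))        # True
print(torch.triu(w, diagonal=1).abs().max().item())    # 0.0

# 2) overwriting "future" tokens leaves the earlier outputs untouched
x2 = x.clone()
x2[5:] = torch.rand(4, 3)                 # replace tokens 5..8
out2 = causal_attention_weights(x2, x2) @ x2
print(torch.allclose(out[:5], out2[:5]))  # True: tokens 0..4 never saw tokens 5..8
```

The diagonal is never masked, so every row keeps at least one finite score and softmax never has to normalize a row that is entirely -inf (which would produce NaN).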
Multi-Head Attention with Dropout and Masking
Here is the final MHA class with all the functionality needed to start training a language model. Of course, the class might look confusing, but remember that every step has been unpacked and discussed above.
class MultiHeadAttentionFinal(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.qkv = nn.Linear(d_in, d_out*3, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # causal mask (1 = future token); a buffer moves with the module across devices but is not trained
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys, queries, values = torch.chunk(self.qkv(x), 3, dim=-1)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) # [batch, num_tokens, d_out] --> [batch, num_tokens, num_heads, head_dim]
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.transpose(1,2) # [batch, num_tokens, num_heads, head_dim] --> [batch, num_heads, num_tokens, head_dim]
        queries = queries.transpose(1,2)
        values = values.transpose(1,2)
        attn_scores = queries @ keys.transpose(2,3) # keys.transpose --> [batch, num_heads, head_dim, num_tokens]
        mask_bool = self.mask.bool()[:num_tokens,:num_tokens] # trim the mask to the current sequence length
        attn_scores.masked_fill_(mask_bool, -torch.inf) # in-place: hide future tokens
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights) # randomly drop attention weights (training mode only)
        context_vec = (attn_weights @ values).transpose(1,2) # [batch, num_tokens, num_heads, head_dim]
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out) # [batch, num_tokens, d_out]
        context_vec = self.out_proj(context_vec)
        return context_vec
Now we can initialize this class and run our inputs through it as below. Note the call to MHA.eval(): it switches Dropout off, so the output is deterministic. During training, keep the module in train mode (the default) so that Dropout is active.
torch.manual_seed(123)
b, tok, d_in = inp_batch.shape
d_out = 2
num_heads = 2
MHA = MultiHeadAttentionFinal(d_in=d_in, d_out=d_out, context_length=tok, dropout=0.5, num_heads=num_heads)
MHA.eval() # disable Dropout for a deterministic output
res = MHA(inp_batch)
print(res)
print(res.shape)
#----------------result------------------
tensor([[[0.3190, 0.4858],
[0.2940, 0.3947],
[0.2853, 0.3637],
[0.2695, 0.3879],
[0.2643, 0.3944],
[0.2577, 0.4025],
[0.2554, 0.4284],
[0.2581, 0.4190],
[0.2647, 0.4129]],
[[0.3190, 0.4858],
[0.2940, 0.3947],
[0.2853, 0.3637],
[0.2695, 0.3879],
[0.2643, 0.3944],
[0.2577, 0.4025],
[0.2554, 0.4284],
[0.2581, 0.4190],
[0.2647, 0.4129]]], grad_fn=<ViewBackward0>)
torch.Size([2, 9, 2])
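As a first step toward the comparison listed under Next steps, the attention core of our class (everything between splitting into heads and merging them back) can be checked against PyTorch's fused torch.nn.functional.scaled_dot_product_attention. This assumes PyTorch 2.0 or newer, random stand-in q, k, v tensors, and no dropout. It does not cover the qkv and out_proj layers.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
b, heads, tokens, head_dim = 2, 2, 9, 4
q = torch.randn(b, heads, tokens, head_dim)
k = torch.randn(b, heads, tokens, head_dim)
v = torch.randn(b, heads, tokens, head_dim)

# Manual core of MultiHeadAttentionFinal.forward (without dropout)
mask = torch.triu(torch.ones(tokens, tokens), diagonal=1).bool()
scores = (q @ k.transpose(2, 3)).masked_fill(mask, -torch.inf)
manual = torch.softmax(scores / head_dim ** 0.5, dim=-1) @ v

# PyTorch's fused kernel: default scale is 1/sqrt(head_dim); is_causal masks future tokens
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))   # expected: True
```

If the two agree, the fused call can stand in for the manual score, mask, softmax and matmul lines, and it may pick a faster kernel on supported hardware.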
Summary
- In simple words, Multi-Head Attention extends single-head Attention by running several heads in parallel, each on its own slice of d_out, and concatenating their outputs
- Scaling (normalizing) the attention scores by the square root of the head dimension keeps softmax from becoming overly peaked and helps training
- Dropout is a regularization technique that stops the model from depending on a few weights and encourages it to use all of them
- Masking prevents the model from looking at future tokens
Next steps:
- Compare our manually created MHA class with the PyTorch implementation and review the outputs