For the last couple of months, I have been learning the core building blocks of LLMs using a well-known book called Build a Large Language Model (From Scratch).

The book is very well written and highly recommended for anyone interested in the internal architecture of LLMs, but please note that you will need to put in serious time and effort to master it.

This post is inspired by the book, so full credit goes to the author, Sebastian Raschka.

So, what is Attention?

Self Attention is one of the core building blocks of modern LLM architectures, and you might have heard of the paper called Attention Is All You Need.

In general, Self Attention is a method for identifying how important every other word in a sentence is to a particular word.

For example, in the sentence Boy is crying because he wants an ice cream, the word he refers to the first word, Boy. Had the sentence been about a Girl instead, we would have said she wants...

How do we do this?

We make this judgment purely from the context of a particular scenario. We do this all the time, effortlessly and naturally, but how can we instruct machines to do the same?

Well, that’s where the concept of Self Attention comes in handy.

How does it work?

Let’s take the above sentence as an example.

As a first step, we need to convert these words into some sort of numeric representation, because we cannot perform calculations on textual data and computers do not understand text the way we do.

Step 1:

I have randomly allocated a 3-dimensional vector to each word; for example, the word Boy is represented as [0.43, 0.15, 0.89], and the other words are shown below.

But why 3 dimensions? It’s an arbitrary choice for now, just to keep things simple; real LLMs use hundreds or thousands of dimensions.

The process of splitting text into tokens (words or pieces of words) and mapping each token to an integer ID is called Tokenization. Each token is then assigned a vector of numbers, randomly initialized and learned during training, called an Embedding.

That’s a lot of jargon; just think of it as splitting a sentence into words and assigning each word a set of numeric values.
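To make the two terms concrete, here is a minimal sketch that assumes a naive whitespace tokenizer and a toy vocabulary built from this one sentence (real LLMs use subword tokenizers such as BPE). The names vocab and token_ids are illustrative only. nn.Embedding is a lookup table holding one randomly initialized, trainable vector per token ID, so its values will differ from the hand-picked numbers used below.

```python
import torch
import torch.nn as nn

sentence = "Boy is crying because he wants an ice cream"

# Tokenization (simplified assumption): split on whitespace and map each word to an ID
words = sentence.split()
vocab = {word: idx for idx, word in enumerate(words)}
token_ids = torch.tensor([vocab[w] for w in words])

# Embedding: one trainable 3-dimensional vector per token ID, randomly initialized
torch.manual_seed(123)
embedding = nn.Embedding(num_embeddings=len(vocab), embedding_dim=3)
vectors = embedding(token_ids)

print(token_ids)      # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
print(vectors.shape)  # torch.Size([9, 3])
```

Every word in this sentence is unique, so the IDs are simply 0 to 8; with repeated words, the same ID (and therefore the same embedding vector) would be reused.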

Let’s assign this entire matrix to a variable called inp.


    import torch

    inp = torch.tensor([[0.43, 0.15, 0.89],  # Boy
                        [0.55, 0.87, 0.66],  # is
                        [0.57, 0.85, 0.64],  # crying
                        [0.22, 0.58, 0.33],  # because
                        [0.77, 0.25, 0.10],  # he
                        [0.05, 0.80, 0.55],  # wants
                        [0.02, 0.30, 0.47],  # an
                        [0.47, 0.67, 0.64],  # ice
                        [0.77, 0.33, 0.70]]) # cream

Step 2:

Let’s say I want to understand how the word he relates to every other word in the sentence above. Let’s call this word the query and assign its vector to a variable named query, as below.

query = inp[4]
print(query)

# -----------result--------------
# tensor([0.7700, 0.2500, 0.1000])

Step 3:

Since we want to understand the relationships between words, we can simply use Matrix Multiplication.

But what is Matrix Multiplication?

In simple words, for two word vectors it means multiplying their corresponding elements and adding the products together (this is called a dot product). In general, the higher the score, the stronger the relationship.

Here is an example where we take the dot product of the Boy vector with the he vector. The result, 0.4576, is positive, so for now we can treat it as a measure of how strongly he relates to Boy; note that it is a raw score, not yet a probability.

boy = [0.43, 0.15, 0.89]
he =  [0.77, 0.25, 0.10]

step1 =  (0.43*0.77) + (0.15*0.25) + (0.89*0.10)
step2 =  0.3311 + 0.0375 + 0.089
result = 0.4576
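The same arithmetic can be checked in PyTorch. This short sketch computes the dot product both by hand (element-wise multiply, then sum) and with torch.dot, which should agree with the manual result of 0.4576.

```python
import torch

boy = torch.tensor([0.43, 0.15, 0.89])
he = torch.tensor([0.77, 0.25, 0.10])

# Dot product: multiply corresponding elements, then add them up
manual = (boy * he).sum()
builtin = torch.dot(boy, he)

print(manual, builtin)  # both are approximately 0.4576
```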

So far we have done the calculation for one word, but we can use Matrix Multiplication to compute the relationship between the query and every other word at once.

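Before we do the Matrix Multiplication, note that we need to Transpose our inp matrix first. query has shape [3] and inp has shape [9, 3], so the inner dimensions only line up once we transpose inp to shape [3, 9]; query @ inp.T then returns one dot product per word.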

The result is a vector with 9 elements, one per word in the sentence; as you can see, the first value exactly matches our manual calculation of 0.4576.

Let’s assign this to a variable called attn_scores.

attn_scores = query @ inp.T
print(attn_scores)

# -----------result--------------
tensor([0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935, 0.1374, 0.5934, 0.7454])
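To see that query @ inp.T is nothing more than nine dot products, here is a sketch that computes the scores both ways: once with the matrix multiplication and once with an explicit loop over the words. The loop version is only for illustration; the matrix form is what you would use in practice.

```python
import torch

inp = torch.tensor([[0.43, 0.15, 0.89],  # Boy
                    [0.55, 0.87, 0.66],  # is
                    [0.57, 0.85, 0.64],  # crying
                    [0.22, 0.58, 0.33],  # because
                    [0.77, 0.25, 0.10],  # he
                    [0.05, 0.80, 0.55],  # wants
                    [0.02, 0.30, 0.47],  # an
                    [0.47, 0.67, 0.64],  # ice
                    [0.77, 0.33, 0.70]]) # cream
query = inp[4]  # he

# query is [3] and inp.T is [3, 9], so the result is [9]: one score per word
scores_matmul = query @ inp.T

# The same scores, one word at a time
scores_loop = torch.empty(inp.shape[0])
for i, word_vec in enumerate(inp):
    scores_loop[i] = torch.dot(query, word_vec)

print(inp.T.shape)                                 # torch.Size([3, 9])
print(torch.allclose(scores_matmul, scores_loop))  # True
```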

Step 4:

Now we need a way to convert these scores into a probability distribution whose values add up to 1; right now the sum of this vector is 4.6625, so it is not a probability distribution.

To do that, we can use the Softmax function, which exponentiates each score and divides it by the sum of all the exponentiated scores; after this, the values add up to 1.


attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)
print(attn_weights.sum())

# -----------result--------------
tensor([0.1025, 0.1315, 0.1326, 0.0918, 0.1262, 0.0870, 0.0744, 0.1174, 0.1367])
tensor(1.0000)
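For intuition, here is a minimal Softmax written from scratch (the function name naive_softmax is just for illustration). It follows the definition directly: exponentiate every score, then divide by the sum, and it should agree with torch.softmax on these scores.

```python
import torch

attn_scores = torch.tensor([0.4576, 0.7070, 0.7154, 0.3474, 0.6654,
                            0.2935, 0.1374, 0.5934, 0.7454])

def naive_softmax(x):
    # exp() makes every score positive; dividing by the total makes them sum to 1
    exp_x = torch.exp(x)
    return exp_x / exp_x.sum()

weights = naive_softmax(attn_scores)
print(weights.sum())                                                # approximately 1
print(torch.allclose(weights, torch.softmax(attn_scores, dim=-1)))  # True
```

This naive version is fine for small scores like ours, but exp() can overflow for large inputs; library implementations typically subtract the maximum score first, which leaves the result mathematically unchanged.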

Let’s intuitively understand what these numbers are and why we need them. As you can see, we have 9 weights, and each weight represents one word from the sentence above.

Since we are interested in our query he, we can read these weights as how much he depends on each word: 10.25% on Boy, 13.15% on is, 13.26% on crying, and so on. Keep in mind that, because our embeddings are random, these particular percentages do not carry any real meaning yet.

Step 5:

So far, we have calculated attention weights, i.e. relative dependencies based on our query, but we still need to fold this information into a new vector for the query.

We can achieve this by matrix-multiplying the attention weights with all our input vectors to produce a context vector, as below.

attn_weights @ inp

# -----------result--------------
tensor([0.4756, 0.5429, 0.5594])

But let’s unpack what’s happening here.

# Our input matrix
tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500],
        [0.0200, 0.3000, 0.4700],
        [0.4700, 0.6700, 0.6400],
        [0.7700, 0.3300, 0.7000]])
        
# This is our weight vector
tensor([0.1025, 0.1315, 0.1326, 0.0918, 0.1262, 0.0870, 0.0744, 0.1174, 0.1367])
        
# Since our input has three dimensions, let's take the first column,
# which represents each word with one number

# Now we can run our calculation as below 
(0.4300*0.1025) + (0.5500*0.1315) + (0.5700*0.1326) + (0.2200*0.0918) + (0.7700*0.1262) + 
(0.0500*0.0870) + (0.0200*0.0744) + (0.4700*0.1174)+ (0.7700*0.1367)  = 0.4756

# The result, 0.4756, matches the first element of the context vector above

We can think of this as taking information from each word in proportion to the attention weight we calculated before.

If we run the same calculation for the other two dimensions, we get the same results as the Matrix Multiplication.
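The column-by-column arithmetic above is equivalent to a weighted sum of the word vectors. This sketch builds the context vector for he with an explicit loop, scaling each input row by its attention weight and adding the rows up, and compares it with attn_weights @ inp.

```python
import torch

inp = torch.tensor([[0.43, 0.15, 0.89],  # Boy
                    [0.55, 0.87, 0.66],  # is
                    [0.57, 0.85, 0.64],  # crying
                    [0.22, 0.58, 0.33],  # because
                    [0.77, 0.25, 0.10],  # he
                    [0.05, 0.80, 0.55],  # wants
                    [0.02, 0.30, 0.47],  # an
                    [0.47, 0.67, 0.64],  # ice
                    [0.77, 0.33, 0.70]]) # cream
query = inp[4]  # he
attn_weights = torch.softmax(query @ inp.T, dim=-1)

# Weighted sum: each word vector contributes in proportion to its attention weight
context_loop = torch.zeros(inp.shape[1])
for weight, word_vec in zip(attn_weights, inp):
    context_loop += weight * word_vec

context_matmul = attn_weights @ inp
print(torch.allclose(context_loop, context_matmul))  # True
```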

As you can see, our output is a three-dimensional vector with the same shape as our query vector.

Intuitively, we can think of it as an enhanced representation of he that carries contextual information from the whole sentence.

Computing context vectors for all words

So far, we have computed an enhanced vector for only one word, our query, but we need a similar representation for every other word. We can achieve this with the steps below, which replace the single query vector with the full inp matrix.

# Step 1 ---> Computing attention scores

attn_scores = inp @ inp.T
print(attn_scores.shape)
print(attn_scores)

# -----------result--------------

torch.Size([9, 9])

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310, 0.4719, 0.8722, 1.0036],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865, 0.5822, 1.2638, 1.1726],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605, 0.5672, 1.2470, 1.1674],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565, 0.3335, 0.7032, 0.5918],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935, 0.1374, 0.5934, 0.7454],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450, 0.4995, 0.9115, 0.6875],
        [0.4719, 0.5822, 0.5672, 0.3335, 0.1374, 0.4995, 0.3113, 0.5112, 0.4434],
        [0.8722, 1.2638, 1.2470, 0.7032, 0.5934, 0.9115, 0.5112, 1.0794, 1.0310],
        [1.0036, 1.1726, 1.1674, 0.5918, 0.7454, 0.6875, 0.4434, 1.0310, 1.1918]])


# Step 2 ---> Computing attention weights

attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights.shape)
print(attn_weights)

# -----------result--------------

torch.Size([9, 9])

tensor([[0.1381, 0.1320, 0.1304, 0.0818, 0.0803, 0.0955, 0.0815, 0.1216, 0.1387],
        [0.0951, 0.1633, 0.1601, 0.0851, 0.0743, 0.1085, 0.0656, 0.1296, 0.1183],
        [0.0953, 0.1625, 0.1595, 0.0852, 0.0760, 0.1073, 0.0655, 0.1293, 0.1194],
        [0.0979, 0.1415, 0.1395, 0.0997, 0.0861, 0.1173, 0.0850, 0.1230, 0.1100],
        [0.1025, 0.1315, 0.1326, 0.0918, 0.1262, 0.0870, 0.0744, 0.1174, 0.1367],
        [0.0954, 0.1505, 0.1466, 0.0979, 0.0681, 0.1306, 0.0837, 0.1263, 0.1010],
        [0.1150, 0.1284, 0.1265, 0.1001, 0.0823, 0.1182, 0.0979, 0.1196, 0.1118],
        [0.1034, 0.1529, 0.1504, 0.0873, 0.0782, 0.1075, 0.0720, 0.1272, 0.1212],
        [0.1200, 0.1421, 0.1414, 0.0795, 0.0927, 0.0875, 0.0685, 0.1234, 0.1449]])

# Step 3 ---> Computing Context vector

context_vec = attn_weights @ inp
print(context_vec.shape)
print(context_vec)

# -----------result--------------

torch.Size([9, 3])

tensor([[0.4565, 0.5421, 0.5943],
        [0.4567, 0.5928, 0.5867],
        [0.4579, 0.5912, 0.5860],
        [0.4378, 0.5738, 0.5715],
        [0.4756, 0.5429, 0.5594],
        [0.4266, 0.5912, 0.5798],
        [0.4278, 0.5562, 0.5752],
        [0.4536, 0.5793, 0.5849],
        [0.4745, 0.5521, 0.5873]])

Notice that the final result has the same shape as our original input matrix, [9, 3]: 9 words and 3 embedding dimensions per word.

Also, the 5th row (index 4, since Python indexing starts at 0) matches the context vector we calculated manually for he.
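Two quick sanity checks are worth running on the full version. Because Softmax is applied along the last dimension (dim=-1), every row of attn_weights should sum to 1 (with dim=0, the columns would sum to 1 instead, which is not what we want). And row 4 should equal the context vector computed for the single query he.

```python
import torch

inp = torch.tensor([[0.43, 0.15, 0.89],  # Boy
                    [0.55, 0.87, 0.66],  # is
                    [0.57, 0.85, 0.64],  # crying
                    [0.22, 0.58, 0.33],  # because
                    [0.77, 0.25, 0.10],  # he
                    [0.05, 0.80, 0.55],  # wants
                    [0.02, 0.30, 0.47],  # an
                    [0.47, 0.67, 0.64],  # ice
                    [0.77, 0.33, 0.70]]) # cream

attn_weights = torch.softmax(inp @ inp.T, dim=-1)
context_vec = attn_weights @ inp

# Softmax over dim=-1: each row (one query word) is a probability distribution
print(attn_weights.sum(dim=-1))  # nine values, all approximately 1

# Row 4 ("he") equals the single-query result from earlier
single = torch.softmax(inp[4] @ inp.T, dim=-1) @ inp
print(torch.allclose(context_vec[4], single))  # True
```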

Attention with Query, Key & Value parameters

So far, we have worked with a simplified version of the Attention mechanism, but in reality it is a bit more complicated, so let’s now build a version closer to the one used for training models: one with trainable weights.

This time let’s take a different approach: instead of building each step one by one, I’ll present the fully coded Attention mechanism below and then unpack it one step at a time.

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, d_in, qkv_bias=False):
        super().__init__()
        # one trainable projection each for queries, keys and values
        self.W_query = nn.Linear(d_in, d_in, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_in, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_in, bias=qkv_bias)

    def forward(self, x):
        # x: [num_tokens, d_in]
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # [num_tokens, num_tokens] scores, scaled by sqrt(d) before softmax
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In the Python class above, we have created three separate projections, one each for the query, key and value. Let’s understand the intuition behind them.

When I first started learning about Attention, I was totally confused about why we need these variables: are they arbitrary, or do they have a specific meaning? I asked ChatGPT the same question, and the answer was very interesting.

Let’s say you are in a library, looking for a book titled The Origin of Species. You ask the librarian about this book, and he starts searching by its title to see which rack it is on.

  • Your request for a particular book is the query.

  • The book title is the key; it acts as an identifier that the query is matched against.

  • Finally, he walks over to the rack and retrieves the book; the book itself is the value, the thing we are actually looking for.

These concepts are used very often in information retrieval systems. For example, say you are searching for a particular customer by mobile number: the Query is the number you search for, the Key is the mobile number stored with each customer record, and the Value is the customer information that comes back.
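The analogy can be pushed one step further. A dictionary performs a hard lookup: the query must match exactly one key, and you get exactly one value back. Attention performs a soft lookup: the query is compared with every key, the similarity scores become weights, and the result is a blend of all the values. The sketch below is purely illustrative; the phone numbers, names and vectors are made up.

```python
import torch

# Hard lookup: an exact key match returns exactly one value (made-up data)
customers = {"555-0101": "Alice", "555-0102": "Bob"}
print(customers["555-0101"])  # Alice

# Soft lookup, the idea behind attention (made-up vectors):
keys = torch.tensor([[1.0, 0.0],
                     [0.0, 1.0]])
values = torch.tensor([[10.0, 0.0],
                       [0.0, 10.0]])
query = torch.tensor([0.9, 0.1])  # mostly similar to the first key

scores = keys @ query                    # one similarity score per key
weights = torch.softmax(scores, dim=-1)  # scores -> weights that sum to 1
result = weights @ values                # weighted blend of all the values
print(weights)  # larger weight on the first key
print(result)   # leans toward the first value, but not exclusively
```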

Now let’s unpack the above code.

Step 1:

Since we need to transform our input Embeddings into three variables, we just need a simple Linear layer per variable (we can also do this step without a Linear layer by manually initializing the weight matrices).

Here we create three Linear layers that take the user-defined input dimension d_in and use the same value as their output dimension; the bias is disabled by default (qkv_bias=False).

        self.W_query = nn.Linear(d_in, d_in, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_in, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_in, bias=qkv_bias)
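As mentioned above, the Linear layers can be swapped for manually initialized weight matrices. Here is a minimal sketch assuming nn.Parameter with torch.rand initialization; the class name SelfAttentionManual is hypothetical. One detail worth knowing: nn.Linear stores its weight with shape [out_features, in_features] and computes x @ weight.T, so the two versions perform the same kind of projection but produce different numbers for the same seed.

```python
import torch
import torch.nn as nn

class SelfAttentionManual(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, d_in):
        super().__init__()
        # trainable projection matrices, randomly initialized in [0, 1)
        self.W_query = nn.Parameter(torch.rand(d_in, d_in))
        self.W_key = nn.Parameter(torch.rand(d_in, d_in))
        self.W_value = nn.Parameter(torch.rand(d_in, d_in))

    def forward(self, x):
        # plain matrix multiplications replace the Linear layers
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values

torch.manual_seed(123)
sa_manual = SelfAttentionManual(3)
x = torch.rand(9, 3)  # random stand-in for inp: 9 tokens, 3 dimensions
out = sa_manual(x)
print(out.shape)  # torch.Size([9, 3])

# What nn.Linear (without bias) computes under the hood
linear = nn.Linear(3, 3, bias=False)
print(torch.allclose(linear(x), x @ linear.weight.T))  # True
```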

Step 2:

In this step, each Linear layer multiplies the input embeddings by its own weight matrix, transforming the inputs into three distinct representations: keys, queries and values.

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

Step 3:

In the third step, we multiply the queries by the transposed keys; this is where we measure the relevance of every word to every other word in the sequence. Conceptually, this is the same as multiplying the input embeddings with themselves, inp @ inp.T, except that both sides are first passed through their own trainable projection.

attn_scores = queries @ keys.T
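To check the "conceptually the same" claim, here is a small sketch: if we overwrite the query and key projection weights with identity matrices (done by hand under torch.no_grad(), purely for illustration), queries @ keys.T becomes exactly inp @ inp.T. With random or trained weights, the projections reshape the embeddings before the comparison, which is what lets the model learn which similarities matter.

```python
import torch
import torch.nn as nn

inp = torch.tensor([[0.43, 0.15, 0.89],  # Boy
                    [0.55, 0.87, 0.66],  # is
                    [0.57, 0.85, 0.64],  # crying
                    [0.22, 0.58, 0.33],  # because
                    [0.77, 0.25, 0.10],  # he
                    [0.05, 0.80, 0.55],  # wants
                    [0.02, 0.30, 0.47],  # an
                    [0.47, 0.67, 0.64],  # ice
                    [0.77, 0.33, 0.70]]) # cream

W_query = nn.Linear(3, 3, bias=False)
W_key = nn.Linear(3, 3, bias=False)

# Illustration only: force both projections to be the identity matrix
with torch.no_grad():
    W_query.weight.copy_(torch.eye(3))
    W_key.weight.copy_(torch.eye(3))

queries = W_query(inp)
keys = W_key(inp)
attn_scores = queries @ keys.T

print(torch.allclose(attn_scores, inp @ inp.T))  # True
```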

Step 4:

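In this step, we convert the attention scores into probabilities using the Softmax function. Notice, though, that we first divide the scores by the square root of the keys’ last dimension, keys.shape[-1] ** 0.5. Without going into too much detail, think of it as a normalization: dot products tend to grow with the embedding dimension, and very large scores would push Softmax toward putting almost all the weight on a single word, which makes training harder.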

attn_weights = torch.softmax(attn_scores/keys.shape[-1] ** 0.5, dim=-1)
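The effect of the scaling is easier to see with a larger embedding size than our toy 3. This sketch assumes d = 512 and random query and key vectors (both are illustrative choices, not values from a real model) and compares the largest Softmax weight with and without dividing by sqrt(d).

```python
import torch

torch.manual_seed(0)
d = 512                   # assumed embedding size, much larger than our toy example
query = torch.randn(d)
keys = torch.randn(8, d)  # 8 random keys

scores = keys @ query     # spread grows with d (standard deviation roughly sqrt(d))
unscaled = torch.softmax(scores, dim=-1)
scaled = torch.softmax(scores / d ** 0.5, dim=-1)

print(scores.std())    # large, on the order of sqrt(512), about 22.6
print(unscaled.max())  # typically close to 1: nearly all weight on one key
print(scaled.max())    # smaller: the weight is spread across more keys
```

Dividing scores by a positive constant can only flatten the Softmax output, never sharpen it, so the scaled maximum is never larger than the unscaled one.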

Step 5:

In the last step, we multiply the attention weights by the values. This mirrors what we did before with attn_weights @ inp, except that we now use the projected values.

The final result will be an enhanced context matrix.

context_vec = attn_weights @ values

Let’s try our Attention Class


torch.manual_seed(123)
SA = SelfAttention(3) # input Dimension = 3 
result = SA(inp)

print(result)
print(result.shape)

# -----------result--------------

tensor([[ 0.2654,  0.4334, -0.1496],
        [ 0.2663,  0.4363, -0.1505],
        [ 0.2663,  0.4363, -0.1505],
        [ 0.2666,  0.4380, -0.1534],
        [ 0.2661,  0.4364, -0.1522],
        [ 0.2667,  0.4382, -0.1530],
        [ 0.2662,  0.4371, -0.1535],
        [ 0.2662,  0.4360, -0.1508],
        [ 0.2656,  0.4338, -0.1492]], grad_fn=<MmBackward0>)

torch.Size([9, 3])

As you can see, the output matrix has the same shape as our input matrix, [9, 3]: 9 tokens with 3 dimensions each.

Attention for batched inputs

The class above works fine when we feed in a single sequence, but in practice we process several sequences at once, so it is very common to group inputs into batches for training.

Let’s create a simple batch by stacking two copies of our input, as below.

inp_batch = torch.stack([inp,inp],dim=0)
print(inp_batch)
print(inp_batch.shape)

# -----------result--------------

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500],
         [0.0200, 0.3000, 0.4700],
         [0.4700, 0.6700, 0.6400],
         [0.7700, 0.3300, 0.7000]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500],
         [0.0200, 0.3000, 0.4700],
         [0.4700, 0.6700, 0.6400],
         [0.7700, 0.3300, 0.7000]]])
         
torch.Size([2, 9, 3])

The shape of this batch is [2, 9, 3]: the batch contains two sequences, each sequence has 9 tokens/words, and each token is a 3-dimensional vector.

We cannot pass this batch directly to our previous Attention class, because keys.T on a 3-D tensor reverses all of its dimensions instead of swapping just the last two, so the multiplication fails. Let’s modify the class as below.


class SelfAttentionFinal(nn.Module):
    def __init__(self, d_in, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_in, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_in, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_in, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1,2)
        attn_weights = torch.softmax(attn_scores/keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

The main difference from the previous class is how we transpose the keys. This time we use the transpose method: attn_scores = queries @ keys.transpose(1,2)

Looking back at the shape of the input, the first dimension is the batch, which we don’t want to touch; we only want to swap the second dimension (tokens) with the third (token embeddings).

Let’s unpack how transpose works here.

As you can see below, the transpose method converts the input batch from batch x tokens x token_dim into batch x token_dim x tokens.

# This is how our input batch looks before the transpose.
print(inp_batch)
print(inp_batch.shape)

# -----------result--------------
tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500],
         [0.0200, 0.3000, 0.4700],
         [0.4700, 0.6700, 0.6400],
         [0.7700, 0.3300, 0.7000]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500],
         [0.0200, 0.3000, 0.4700],
         [0.4700, 0.6700, 0.6400],
         [0.7700, 0.3300, 0.7000]]])

torch.Size([2, 9, 3])

# Let's apply Transpose on last two dimensions
transposed_batch = inp_batch.transpose(1,2)
print(transposed_batch)
print(transposed_batch.shape)

# -----------result--------------
tensor([[[0.4300, 0.5500, 0.5700, 0.2200, 0.7700, 0.0500, 0.0200, 0.4700, 0.7700],
         [0.1500, 0.8700, 0.8500, 0.5800, 0.2500, 0.8000, 0.3000, 0.6700, 0.3300],
         [0.8900, 0.6600, 0.6400, 0.3300, 0.1000, 0.5500, 0.4700, 0.6400, 0.7000]],

        [[0.4300, 0.5500, 0.5700, 0.2200, 0.7700, 0.0500, 0.0200, 0.4700, 0.7700],
         [0.1500, 0.8700, 0.8500, 0.5800, 0.2500, 0.8000, 0.3000, 0.6700, 0.3300],
         [0.8900, 0.6600, 0.6400, 0.3300, 0.1000, 0.5500, 0.4700, 0.6400, 0.7000]]])

torch.Size([2, 3, 9])
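Putting the pieces together, this sketch shows what the transposed keys are used for: a batched matrix multiplication that produces one 9 x 9 score matrix per sequence. It uses a random stand-in for inp_batch, and also shows that transpose(-2, -1) is an equivalent way to say "swap the last two dimensions". By contrast, .T on a 3-D tensor reverses all three dimensions (and recent PyTorch versions warn that this use of .T is deprecated).

```python
import torch

torch.manual_seed(123)
batch = torch.rand(2, 9, 3)  # random stand-in for inp_batch: [batch, tokens, dims]

# transpose(1, 2) swaps only tokens and dims; the batch dimension stays first
print(batch.transpose(1, 2).shape)  # torch.Size([2, 3, 9])
print(torch.equal(batch.transpose(1, 2), batch.transpose(-2, -1)))  # True

# Batched matrix multiplication: [2, 9, 3] @ [2, 3, 9] -> [2, 9, 9]
scores = batch @ batch.transpose(1, 2)
print(scores.shape)  # torch.Size([2, 9, 9])

# Each slice equals the unbatched computation for that sequence
print(torch.allclose(scores[0], batch[0] @ batch[0].T))  # True
```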

Let’s try it out

Now if we use this class with the same manual seed, as below, each sequence in the output matches the previous result, indicating that it is working as expected.

The output shape also matches the input shape.


torch.manual_seed(123)
SA = SelfAttentionFinal(3)
result = SA(inp_batch)
print(result)
print(result.shape)

# -----------result--------------

tensor([[[ 0.2654,  0.4334, -0.1496],
         [ 0.2663,  0.4363, -0.1505],
         [ 0.2663,  0.4363, -0.1505],
         [ 0.2666,  0.4380, -0.1534],
         [ 0.2661,  0.4364, -0.1522],
         [ 0.2667,  0.4382, -0.1530],
         [ 0.2662,  0.4371, -0.1535],
         [ 0.2662,  0.4360, -0.1508],
         [ 0.2656,  0.4338, -0.1492]],

        [[ 0.2654,  0.4334, -0.1496],
         [ 0.2663,  0.4363, -0.1505],
         [ 0.2663,  0.4363, -0.1505],
         [ 0.2666,  0.4380, -0.1534],
         [ 0.2661,  0.4364, -0.1522],
         [ 0.2667,  0.4382, -0.1530],
         [ 0.2662,  0.4371, -0.1535],
         [ 0.2662,  0.4360, -0.1508],
         [ 0.2656,  0.4338, -0.1492]]], grad_fn=<UnsafeViewBackward0>)

torch.Size([2, 9, 3])
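Because both sequences in inp_batch are identical, the matching outputs above cannot reveal whether sequences in a batch influence each other. A slightly stronger check, sketched below with two different random sequences (an assumption for testing purposes), runs the batch through the module once and then runs each sequence on its own; the results should be identical, confirming that attention is computed independently per sequence.

```python
import torch
import torch.nn as nn

class SelfAttentionFinal(nn.Module):
    def __init__(self, d_in, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_in, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_in, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_in, bias=qkv_bias)

    def forward(self, x):
        # x: [batch, num_tokens, d_in]
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values

torch.manual_seed(123)
sa = SelfAttentionFinal(3)

batch = torch.rand(2, 9, 3)  # two different random sequences (test assumption)
batched_out = sa(batch)

# Same module, one sequence at a time (unsqueeze adds a batch dimension of 1)
per_sequence = torch.cat([sa(seq.unsqueeze(0)) for seq in batch], dim=0)

print(torch.allclose(batched_out, per_sequence))  # True
```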

Conclusion

  • We implemented a simple Attention mechanism as a starting point and explored its inner workings.
  • We then extended it into a more realistic version with three trainable weight matrices for the query, key and value.
  • Lastly, we created a modified version of Self Attention that handles batched inputs by transposing only the last two dimensions of the keys.

As a closing note, I have deliberately omitted masking to keep things simple, since the main idea behind this post is to explore the inner workings of the Self Attention mechanism.