
Part 1: The Attention Mechanism - Building the Core of the Transformer

May 10, 2024


We're on a quest to build something truly remarkable – an architecture capable of mirroring the astonishing flexibility and learning capacity of the human brain. Forget the black box; I'm talking about a system built on fundamental, interpretable components, eventually leveraging the power of Mixture-of-Experts (MoE) to simulate specialized cognitive functions. But before we run, we walk. And before we compose, we master the primitives.

This series kicks off with what I consider one of the most elegant and crucial inventions in modern AI: The Attention Mechanism. If you've ever found yourself slogging through an academic paper, only to focus intensely on a few key sentences that unlock its meaning, you've experienced human attention. The machine learning counterpart, the Transformer's self-attention, operates with a similar, albeit mathematical, intuition. This isn't about some bloated framework abstracting away the magic; this is about getting our hands dirty with raw tensor operations, understanding exactly how the model 'focuses'.

Why Attention? The Bottleneck Problem and Beyond

Traditional sequential models like RNNs and LSTMs were groundbreaking, but encoder-decoder architectures built on them had a glaring limitation: the "bottleneck" problem. Before the decoder could generate a single prediction, the entire input sequence had to be compressed into one fixed-size vector. Imagine trying to summarize an entire novel into a single sentence before deciding on its genre. Information inevitably gets lost, especially over long sequences.

This is where attention shines. Instead of squashing everything into one vector, attention allows the model to dynamically weigh the importance of different parts of the input sequence when processing each element. For me, this immediately clicked with the concept of a 'working memory' in human cognition – a dynamic spotlight that can shift focus. This isn't just about better NLP; it's about building a fundamental primitive for any intelligent agent to selectively process information, a cornerstone for future MoE architectures where experts need to 'attend' to relevant input segments.

The Intuition: Query, Key, Value – A Database Analogy

Let's demystify the core components: Query, Key, and Value (Q, K, V).

Imagine you're searching a database.

  • Query (Q): This is your search term. "What am I looking for?"
  • Key (K): These are the labels or identifiers attached to items in the database. "What do I have that might match the query?"
  • Value (V): These are the actual items or information associated with the keys. "If there's a match, what's the information I get back?"

In the context of self-attention for a sequence of words (or tokens), each word in the sequence will play all three roles. When we're processing a specific word (our 'query' word), we want to compare it against all other words (their 'keys') in the sequence to determine how relevant they are. Once we know the relevance, we take a weighted sum of their 'values' to create a new, context-aware representation for our query word.

This dynamic weighting is achieved through matrix multiplication, leveraging the incredible efficiency of modern GPUs. Forget slow, iterative operations; we're dealing with vectorized power.
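Written out, the whole mechanism is a single line of linear algebra, the scaled dot-product attention formula from "Attention Is All You Need". Everything we build below is just this expression computed one step at a time:

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V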

Building Scaled Dot-Product Attention from Scratch

Let's dive into the code. We'll start with NumPy for maximum transparency into the matrix operations, then transition to PyTorch for a more typical deep learning setup. We'll operate on token embeddings, which are dense vector representations of our words.

Assume an input sequence of length seq_len, where each token is represented by an embedding of dimension d_model.

import numpy as np
import torch
import torch.nn as nn
import math
 
# --- Hyperparameters ---
seq_len = 5       # Example: "The quick brown fox jumps"
d_model = 8       # Embedding dimension for each token
d_k = d_model     # Often d_k = d_v = d_model / num_heads, but for simplicity here, d_k = d_model
d_v = d_model     # The dimension of the Value vectors
 
# --- 1. Synthetic Input Embeddings ---
# batch_size=1 for now, but attention scales perfectly for batches.
# Shape: (batch_size, seq_len, d_model)
input_embeddings = np.random.rand(1, seq_len, d_model)
print(f"Input Embeddings Shape: {input_embeddings.shape}\n")

Step 1: Linear Projections for Q, K, V

The first step is to project our input embeddings into three distinct spaces: Query, Key, and Value. This is done using three separate linear layers (weight matrices), allowing the model to learn different transformations for each role.

# --- Weights for Linear Projections ---
# These would be learned during training. For now, random initialization.
# Shape: (d_model, d_k) for Q and K, (d_model, d_v) for V
W_q = np.random.rand(d_model, d_k)
W_k = np.random.rand(d_model, d_k)
W_v = np.random.rand(d_model, d_v)
 
# --- Compute Q, K, V Matrices ---
# (batch_size, seq_len, d_model) @ (d_model, d_k) -> (batch_size, seq_len, d_k)
Q = input_embeddings @ W_q
K = input_embeddings @ W_k
V = input_embeddings @ W_v
 
print(f"Q Matrix Shape: {Q.shape}")
print(f"K Matrix Shape: {K.shape}")
print(f"V Matrix Shape: {V.shape}\n")

Each row in Q, K, V corresponds to a token's Query, Key, or Value vector. These are not just copies; they're transformed representations, allowing the model to focus differently for each role.

Step 2: Calculate Attention Scores (Query dot Key Transpose)

This is where the magic happens. We compute the dot product between the Query matrix and the transpose of the Key matrix. For each token's query vector, we calculate its similarity with every other token's key vector. A higher dot product means higher similarity.

# --- Attention Scores ---
# (batch_size, seq_len, d_k) @ (batch_size, d_k, seq_len) -> (batch_size, seq_len, seq_len)
# NumPy's K.T would reverse *every* axis of the 3-D array, so we permute explicitly
# with transpose(0, 2, 1) to swap only the last two axes of K.
attention_scores = Q @ K.transpose(0, 2, 1)
# For a single example this reduces to Q[0] @ K[0].T.
# np.einsum('bqd,bkd->bqk', Q, K) expresses the same batched product more explicitly.
 
print(f"Attention Scores Shape: {attention_scores.shape}")
print(f"Sample Attention Scores (first token's scores):\n{attention_scores[0, 0, :]}\n")

The resulting attention_scores matrix has shape (batch_size, seq_len, seq_len). For a given batch and a given query token i, attention_scores[batch_idx, i, j] represents how much token i "attends" to token j.

Step 3: Scaling

The dot products can grow quite large as d_k increases (for roughly unit-variance components, their variance grows in proportion to d_k). Large scores push the softmax function into regions where its gradient is extremely small, hindering learning. We therefore scale the scores by \frac{1}{\sqrt{d_k}}. This is a critical detail, often overlooked when just reading the formula, but vital for stable training.

# --- Scaling ---
scaled_attention_scores = attention_scores / np.sqrt(d_k)
print(f"Sample Scaled Attention Scores:\n{scaled_attention_scores[0, 0, :]}\n")

Step 4: Softmax

We apply the softmax function row-wise (over the last dimension) to the scaled scores. This converts them into probability distributions, ensuring that the attention weights for each query sum to 1. Now, attention_weights[batch_idx, i, j] truly represents the probability that token i focuses on token j.

# --- Softmax Function ---
def softmax(x, axis=-1):
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True)) # Subtract max for numerical stability
    return e_x / e_x.sum(axis=axis, keepdims=True)
 
attention_weights = softmax(scaled_attention_scores, axis=-1)
print(f"Attention Weights Shape: {attention_weights.shape}")
print(f"Sample Attention Weights (first token's):\n{attention_weights[0, 0, :]}")
print(f"Sum of weights for first token: {np.sum(attention_weights[0, 0, :]):.4f}\n")

Notice how the weights sum to approximately 1.0 for each query token. This confirms they are probability distributions.
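If you want to eyeball who attends to whom, a quick, purely illustrative pretty-print of the weight matrix against the placeholder tokens from our example makes the structure visible (the numbers themselves are meaningless here, since the weights are untrained):

# Illustrative: label rows and columns with the example tokens.
tokens = ["The", "quick", "brown", "fox", "jumps"]  # placeholder tokens, seq_len = 5
print("        " + "  ".join(f"{t:>6}" for t in tokens))
for i, t in enumerate(tokens):
    row = "  ".join(f"{w:6.3f}" for w in attention_weights[0, i])
    print(f"{t:>6}  {row}")

Each row is a query token; reading across it shows how that token distributes its focus over the sequence.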

Step 5: Weighted Sum of Values

Finally, we multiply these attention weights by the Value matrix. Each query token's new representation is a weighted sum of all Value vectors in the sequence, where the weights are precisely the attention probabilities we just calculated. This means words highly relevant to our query word (high attention weight) will contribute more to its new, context-aware representation.

# --- Weighted Sum of Values ---
# (batch_size, seq_len, seq_len) @ (batch_size, seq_len, d_v) -> (batch_size, seq_len, d_v)
output = attention_weights @ V
print(f"Output Shape: {output.shape}\n")
print("Self-Attention Mechanism Complete (NumPy)!\n")

The output matrix (batch_size, seq_len, d_v) contains the new, contextually enriched representations for each token in the input sequence. Each token now "knows" about the relevant parts of its surrounding context.

PyTorch Implementation: The Real-World Scenario

While NumPy is great for understanding, PyTorch is what we use for accelerated computing. The operations remain identical, just expressed with PyTorch tensors.

# --- PyTorch Implementation ---
print("--- PyTorch Implementation ---")
 
# Convert NumPy arrays to PyTorch tensors
input_embeddings_torch = torch.tensor(input_embeddings, dtype=torch.float32)
 
# Linear projection layers (the learnable equivalents of W_q, W_k, W_v).
# These are the standard `nn.Linear(d_model, d_k)` layers you'd use in a real
# Transformer; below we copy in our NumPy weights so both implementations
# produce identical numbers.
linear_q = nn.Linear(d_model, d_k, bias=False)
linear_k = nn.Linear(d_model, d_k, bias=False)
linear_v = nn.Linear(d_model, d_v, bias=False)
 
# Copy the NumPy weights so the two implementations produce the same output
with torch.no_grad():
    linear_q.weight.copy_(torch.tensor(W_q.T, dtype=torch.float32)) # PyTorch linear expects (out_features, in_features)
    linear_k.weight.copy_(torch.tensor(W_k.T, dtype=torch.float32))
    linear_v.weight.copy_(torch.tensor(W_v.T, dtype=torch.float32))
 
# 1. Compute Q, K, V
Q_torch = linear_q(input_embeddings_torch)
K_torch = linear_k(input_embeddings_torch)
V_torch = linear_v(input_embeddings_torch)
 
print(f"Q_torch Shape: {Q_torch.shape}")
print(f"K_torch Shape: {K_torch.shape}")
print(f"V_torch Shape: {V_torch.shape}\n")
 
# 2. Calculate Attention Scores
# Q @ K.transpose(-2, -1) performs (batch, seq_len, d_k) @ (batch, d_k, seq_len)
attention_scores_torch = torch.matmul(Q_torch, K_torch.transpose(-2, -1))
print(f"Attention Scores_torch Shape: {attention_scores_torch.shape}\n")
 
# 3. Scaling
scaled_attention_scores_torch = attention_scores_torch / math.sqrt(d_k)
 
# 4. Softmax
attention_weights_torch = torch.softmax(scaled_attention_scores_torch, dim=-1)
print(f"Attention Weights_torch Shape: {attention_weights_torch.shape}")
print(f"Sum of weights for first token (PyTorch): {torch.sum(attention_weights_torch[0, 0, :]):.4f}\n")
 
 
# 5. Weighted Sum of Values
output_torch = torch.matmul(attention_weights_torch, V_torch)
print(f"Output_torch Shape: {output_torch.shape}\n")
 
print("Self-Attention Mechanism Complete (PyTorch)!")
 
# Verify results are close between NumPy and PyTorch
assert np.allclose(output, output_torch.detach().numpy(), atol=1e-6)  # detach: the tensor carries grad history
print("NumPy and PyTorch outputs are consistent!\n")

What I Learned and Why This Matters

Building attention from the ground up, directly manipulating tensors, solidifies its mechanics in a way no high-level API ever could.

  1. The Power of Matrix Multiplication: This entire mechanism hinges on efficient matrix operations. This is why GPUs are so critical for deep learning; they parallelize these computations beautifully. As a competitive programmer, optimizing for vectorized operations over explicit loops is second nature, and attention is a prime example of this principle (see the short loop-vs-vectorized sketch after this list).
  2. Dynamic Context: The genius lies in creating context-aware representations dynamically. Each word's new vector isn't static; it's a unique blend of itself and other words, weighted by their learned relevance. This is a massive leap over fixed-context models.
  3. No Bloated Frameworks Needed: You don't need LangChain's layers of abstraction to understand or implement this. The core is elegant, mathematical, and surprisingly simple at its heart. Direct tensor manipulation using PyTorch or NumPy gives you full control and clarity, which is invaluable for debugging and optimization.
  4. A Stepping Stone to Intelligence: This attention mechanism is more than just a trick for language models. It's a general-purpose primitive for selective information processing. In my pursuit of brain-like architectures, this concept of dynamically focusing on relevant parts of an input is fundamental. It's how an "expert" in a future MoE system could decide which features of the input are most pertinent to its specialization. This is the bedrock upon which more complex cognitive abilities will be built.
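
To make point 1 concrete, here is a tiny, purely illustrative comparison using the NumPy matrices from earlier: the explicit double loop computes exactly the same score matrix as the single matrix product, it just does so one query-key pair at a time.

# Illustrative: the explicit-loop view of the score computation vs. the vectorized one.
scores_loop = np.zeros((seq_len, seq_len))
for i in range(seq_len):
    for j in range(seq_len):
        scores_loop[i, j] = np.dot(Q[0, i], K[0, j])  # one query-key pair at a time

scores_vectorized = Q[0] @ K[0].T  # the whole matrix in one call
assert np.allclose(scores_loop, scores_vectorized)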

Next up, we'll expand this single-head attention into Multi-Head Attention, adding more parallelism and richer contextual understanding, laying the final piece for the full Transformer block. Stay tuned.