
Complete Module: Parallelism, Divide & Conquer, Multi-Head from Scratch

"Attention Is All You Need": Multi-Head & Self-Attention



Module Objective

Master Multi-Head Self-Attention — the core parallel computation engine of Transformers — with math, code, intuition, and divide-and-conquer parallelism.


1. Self-Attention: One Token Talks to All

Self-Attention: $ Q $, $ K $, and $ V $ all come from the same input $ X $ ($ Q = XW^Q $, $ K = XW^K $, $ V = XW^V $)
Every token attends to every other token in the same sequence.

# Input: [batch, seq_len, d_model]
X → Linear → Q, K, V → Attention(Q, K, V)
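The pipeline above can be sketched as a single attention head in a few lines. This is a minimal illustration, assuming random weights and small sizes (B=2, N=6, d=32) chosen only for readability; it shows that Q, K, and V are three different projections of one X, and that every token gets a softmax-normalized weight over every token.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, N, d = 2, 6, 32
X = torch.randn(B, N, d)                        # [batch, seq_len, d_model]

# Three separate learned projections of the SAME input X
W_q, W_k, W_v = (nn.Linear(d, d, bias=False) for _ in range(3))

with torch.no_grad():
    Q, K, V = W_q(X), W_k(X), W_v(X)            # each (B, N, d)
    scores = Q @ K.transpose(-2, -1) / d ** 0.5 # (B, N, N): every token vs. every token
    weights = F.softmax(scores, dim=-1)         # each row sums to 1
    out = weights @ V                           # (B, N, d)

print(weights.shape, out.shape)
print("Q equals K:", torch.equal(Q, K))         # same source X, different projections
```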

2. Why Multi-Head? Divide & Conquer

| Problem | Solution |
|---|---|
| One attention head = one perspective | Multiple heads = multiple subspaces |
| Risk of missing relations | Parallel views → richer representation |

"Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions."
Vaswani et al., 2017


3. Multi-Head Attention — The Formula

$$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O $$

$$ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) $$

Where:

  • $ h $ = number of heads
  • $ d_k = d_v = d_{\text{model}} / h $ (below, $ d = d_{\text{model}} $)
  • $ W_i^Q, W_i^K \in \mathbb{R}^{d \times d_k} $, $ W_i^V \in \mathbb{R}^{d \times d_v} $
  • $ W^O \in \mathbb{R}^{h d_v \times d} $
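In code, the $ h $ per-head matrices $ W_i^Q $ are usually not stored separately: one $ d \times d $ linear layer holds all of them side by side, which is what the `MultiHeadAttention` class in section 5 does. The sketch below checks that equivalence numerically, assuming $ d = 512 $, $ h = 8 $ and PyTorch's `nn.Linear` convention (`y = x @ weight.T`).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, h = 512, 8
d_k = d_model // h                               # 64

W_q = nn.Linear(d_model, d_model, bias=False)    # one fused projection for all heads
x = torch.randn(2, 10, d_model)                  # (B, N, d)

with torch.no_grad():
    fused = W_q(x)                               # (B, N, d): all heads' queries side by side
    # nn.Linear computes x @ weight.T, so W_i^Q is column block i of weight.T: shape (d, d_k)
    per_head = [W_q.weight.T[:, i * d_k:(i + 1) * d_k] for i in range(h)]
    separate = torch.cat([x @ W_i for W_i in per_head], dim=-1)

print(torch.allclose(fused, separate, atol=1e-5))
print(W_q.weight.numel() == h * d_model * d_k)   # same parameter count either way
```

Fusing the heads into one matrix means one large matmul instead of $ h $ small ones, which is friendlier to GPUs.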

4. Step-by-Step: From Single to Multi-Head

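| Step | Operation | Shape |
|---|---|---|
| 1 | Project $ X $ → $ Q, K, V $ | $ (B, N, d) $ |
| 2 | Split into $ h $ heads | $ (B, h, N, d/h) $ |
| 3 | Parallel attention over all $ h $ heads | $ (B, h, N, d/h) $ |
| 4 | Concat + linear $ W^O $ | $ (B, N, d) $ |

Here $ B $ = batch size, $ N $ = sequence length, $ d = d_{\text{model}} $.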
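The shape changes in steps 2 and 4 are pure reshapes, so they can be traced without any attention at all. This sketch assumes $ B=2 $, $ N=10 $, $ d=512 $, $ h=8 $ and uses `.contiguous()` on the stand-in for the attention output, since a real attention output is a freshly allocated $ (B, h, N, d_k) $ tensor. It also shows why `.contiguous()` is needed before `.view()` when merging heads.

```python
import torch

B, N, d, h = 2, 10, 512, 8
d_k = d // h

x = torch.randn(B, N, d)                          # step 1 output, e.g. Q: (B, N, d)
heads = x.view(B, N, h, d_k).transpose(1, 2)      # step 2: (B, h, N, d_k)
split_shape = tuple(heads.shape)

# Step 3 keeps (B, h, N, d_k); a real attention output is a fresh contiguous tensor
attn_out = heads.contiguous()

merged = attn_out.transpose(1, 2).contiguous().view(B, N, d)  # step 4 concat: (B, N, d)
merged_shape = tuple(merged.shape)
lossless = torch.equal(merged, x)                 # split + merge loses nothing

try:
    attn_out.transpose(1, 2).view(B, N, d)        # same merge without .contiguous()
    view_failed = False
except RuntimeError:
    view_failed = True                            # strides no longer allow a view

print(split_shape, merged_shape, lossless, view_failed)
```

`.reshape(B, N, d)` is an alternative that copies only when it has to.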

5. Multi-Head Attention — From Scratch (PyTorch)

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # d_v = d_k
        
        # Learnable projections
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        
        self.dropout = nn.Dropout(dropout)
        self.scale = (self.d_k) ** 0.5
        
    def split_heads(self, x):
        """Split last dim into (num_heads, d_k)"""
        batch, seq_len, _ = x.shape
        x = x.view(batch, seq_len, self.num_heads, self.d_k)
        return x.transpose(1, 2)  # (B, h, N, d_k)
    
    def combine_heads(self, x):
        """Combine heads back to (B, N, d_model)"""
        batch, _, seq_len, _ = x.shape
        x = x.transpose(1, 2).contiguous()
        return x.view(batch, seq_len, self.d_model)
    
    def scaled_dot_product(self, Q, K, V, mask=None):
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # (B, h, N, N)
        
        if mask is not None:
            # mask must broadcast to (B, h, N, N); positions where mask == 0 are blocked
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        return torch.matmul(attn, V), attn  # (B, h, N, d_k), (B, h, N, N)
    
    def forward(self, Q, K, V, mask=None):
        # Q, K, V here are the query/key/value inputs, each (B, N, d_model);
        # for self-attention pass the same x three times.

        # 1. Linear projections
        Q = self.W_q(Q)  # (B, N, d)
        K = self.W_k(K)
        V = self.W_v(V)
        
        # 2. Split into heads
        Q = self.split_heads(Q)  # (B, h, N, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)
        
        # 3. Apply attention in parallel
        attn_output, attn_weights = self.scaled_dot_product(Q, K, V, mask)
        # attn_output: (B, h, N, d_k)
        
        # 4. Combine heads
        output = self.combine_heads(attn_output)  # (B, N, d)
        
        # 5. Final linear
        output = self.W_o(output)
        
        return output, attn_weights
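One way to sanity-check a from-scratch implementation is to compare it against PyTorch's built-in `nn.MultiheadAttention`. The sketch below uses a hypothetical functional helper, `mha_functional`, that mirrors the class above (no biases, no dropout) and borrows the built-in module's weights. It assumes `bias=False` and `dropout=0.0` on the built-in module and relies on its documented storage of the stacked `[W_q; W_k; W_v]` in `in_proj_weight`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def mha_functional(x, W_q, W_k, W_v, W_o, h):
    """Hypothetical helper: multi-head self-attention on plain tensors.
    W_* are (d, d) weights in nn.Linear layout (y = x @ W.T); no bias, no dropout."""
    B, N, d = x.shape
    d_k = d // h
    def split(t):                                            # (B, N, d) -> (B, h, N, d_k)
        return t.view(B, N, h, d_k).transpose(1, 2)
    Q, K, V = split(x @ W_q.T), split(x @ W_k.T), split(x @ W_v.T)
    weights = F.softmax(Q @ K.transpose(-2, -1) / d_k ** 0.5, dim=-1)  # (B, h, N, N)
    out = (weights @ V).transpose(1, 2).reshape(B, N, d)     # concat heads
    return out @ W_o.T, weights

torch.manual_seed(0)
d, h = 64, 8
x = torch.randn(2, 5, d)

ref = nn.MultiheadAttention(d, h, dropout=0.0, bias=False, batch_first=True)
W_q, W_k, W_v = ref.in_proj_weight.chunk(3, dim=0)          # stacked as [W_q; W_k; W_v]

ours, our_w = mha_functional(x, W_q, W_k, W_v, ref.out_proj.weight, h)
theirs, their_w = ref(x, x, x, need_weights=True, average_attn_weights=False)

print(torch.allclose(ours, theirs, atol=1e-5), torch.allclose(our_w, their_w, atol=1e-5))
```

`average_attn_weights=False` keeps one weight map per head, matching the `(B, h, N, N)` tensor returned by the class above.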

6. Self-Attention = Multi-Head(Q=X, K=X, V=X)

# Self-Attention
x = torch.randn(2, 10, 512)  # (batch, seq_len, d_model)
mha = MultiHeadAttention(d_model=512, num_heads=8)
output, attn = mha(x, x, x)  # Q=K=V=x
print(output.shape)  # torch.Size([2, 10, 512])
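The `mask` argument is never exercised above. A common case is a causal (decoder-style) mask, where token $ i $ may only attend to tokens $ \le i $. This standalone sketch uses the same convention as the class (entries equal to 0 are blocked) and a `(1, 1, N, N)` mask that broadcasts over batch and heads; the same tensor could be passed as `mha(x, x, x, mask=causal)`. The helper name `causal_attention_weights` is hypothetical.

```python
import torch
import torch.nn.functional as F

def causal_attention_weights(Q, K, mask):
    """Hypothetical helper: masked softmax weights, same convention as the class above."""
    scores = Q @ K.transpose(-2, -1) / Q.size(-1) ** 0.5   # (B, h, N, N)
    scores = scores.masked_fill(mask == 0, float("-inf"))  # blocked -> weight exactly 0
    return F.softmax(scores, dim=-1)

torch.manual_seed(0)
B, h, N, d_k = 2, 8, 10, 64
Q, K = torch.randn(B, h, N, d_k), torch.randn(B, h, N, d_k)

causal = torch.tril(torch.ones(N, N)).view(1, 1, N, N)    # 1 = allowed, 0 = future token
attn = causal_attention_weights(Q, K, causal)

print(attn[0, 0].triu(1).abs().max().item())               # weight on future tokens
print(torch.allclose(attn.sum(-1), torch.ones(B, h, N)))   # rows still sum to 1
```

Every row keeps at least one allowed position (the diagonal), so no row becomes all `-inf`, which would produce NaNs after softmax.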

7. Parallelism: Divide & Conquer

Hardware View (GPU)

Input X → [Linear Q] → Split → [Head 1] → Attention
          [Linear K] → Split → [Head 2] → Attention → Concat → W^O
          [Linear V] → Split → [Head 3] → Attention
                             ...
                             [Head 8] → Attention

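All 8 heads run in parallel on the GPU: the head axis is just another batch dimension, so one batched matmul computes every head at once.
Memory: the attention weights take $ O(h \cdot N^2) $, which is still quadratic in $ N $ for a fixed $ h $, in exchange for richer features.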
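The divide-and-conquer view can be checked directly: attending one head at a time (divide) gives exactly the same result as one batched call over the head axis (conquer in parallel). The sizes below are arbitrary assumptions for the demo.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, h, N, d_k = 2, 8, 16, 64
Q, K, V = (torch.randn(B, h, N, d_k) for _ in range(3))

def attend(q, k, v):
    """Scaled dot-product attention over the last two dims; any leading dims are batch dims."""
    w = F.softmax(q @ k.transpose(-2, -1) / q.size(-1) ** 0.5, dim=-1)
    return w @ v

# Divide: solve each head as its own small problem, then stack along the head axis
looped = torch.stack([attend(Q[:, i], K[:, i], V[:, i]) for i in range(h)], dim=1)

# Conquer in parallel: the head axis is treated as one more batch dimension
batched = attend(Q, K, V)

print(looped.shape, torch.allclose(looped, batched, atol=1e-5))
```

The batched form launches a few large kernels instead of $ h $ small ones, which is where the GPU speedup comes from.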


8. Visualization: Multi-Head Attention Maps

import matplotlib.pyplot as plt
import seaborn as sns

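# Dummy input (random, untrained weights)
x = torch.randn(1, 5, 64)
mha = MultiHeadAttention(d_model=64, num_heads=8).eval()  # eval(): no dropout on the plotted weights
with torch.no_grad():
    _, attn_weights = mha(x, x, x)  # (B, h, N, N)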

# Plot all heads
fig, axes = plt.subplots(2, 4, figsize=(16, 6))
axes = axes.flatten()

for i in range(8):
    sns.heatmap(
        attn_weights[0, i].detach().cpu().numpy(),
        ax=axes[i],
        cmap="viridis",
        cbar=False
    )
    axes[i].set_title(f"Head {i+1}")
    axes[i].set_xticks([])
    axes[i].set_yticks([])

plt.suptitle("Multi-Head Attention Weights (8 Heads)", fontsize=16)
plt.tight_layout()
plt.show()

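With the random, untrained weights above, these maps are essentially arbitrary. After training, heads often (though not always) specialize in different patterns, for example:

  • Local: attending to neighboring tokens
  • Global: spreading attention broadly across the sequence
  • Syntactic: linking tokens in grammatical relations
  • etc.

Which head learns which pattern is not fixed in advance; the head numbering is arbitrary.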

9. Efficiency: Memory & Compute

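| Operation | Time | Memory |
|---|---|---|
| Linear projections ($ W^Q, W^K, W^V, W^O $) | $ O(N d^2) $ | $ O(N d) $ |
| Split / combine heads | $ O(N d) $ (copy in `.contiguous()`) | $ O(N d) $ |
| Attention, one head | $ O(N^2 d/h) $ | $ O(N^2) $ |
| Attention, all $ h $ heads | $ O(N^2 d) $ | $ O(h N^2) $ |
| Total (per sequence) | $ O(N d^2 + N^2 d) $ | $ O(h N^2 + N d) $ |

Same asymptotic time as a single head of width $ d $: each of the $ h $ heads works in $ d/h $ dimensions, so the attention work still sums to $ O(N^2 d) $. The attention-weight matrices, however, take $ h $ times as much memory.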
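The table's arithmetic can be checked with a rough multiply-add count. This sketch assumes $ N = 1024 $ and $ d = 512 $ and ignores softmax, scaling, and constant factors; the function names are hypothetical. It shows the attention matmul cost staying flat as $ h $ grows while the stored attention weights grow linearly in $ h $.

```python
def attention_matmul_cost(N, d, h):
    """Multiply-adds for Q K^T and weights @ V, summed over all h heads (rough count)."""
    d_k = d // h
    per_head = 2 * N * N * d_k          # (N x d_k)(d_k x N)  +  (N x N)(N x d_k)
    return h * per_head

def projection_cost(N, d):
    """Multiply-adds for the four d x d projections W^Q, W^K, W^V, W^O."""
    return 4 * N * d * d

def attention_weight_elems(N, h):
    """Entries in the stored attention-weight tensor for one sequence."""
    return h * N * N

N, d = 1024, 512
for h in (1, 8, 16):
    print(f"h={h:2d}  attention={attention_matmul_cost(N, d, h):,}  "
          f"projections={projection_cost(N, d):,}  "
          f"weight entries={attention_weight_elems(N, h):,}")
```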


10. Divide & Conquer Intuition

Single Head (64-dim):
"the cat sat on the mat"
       └────┬────┘
           One view

Multi-Head (8 × 8-dim, illustrative):
"the cat sat on the mat"
 ├─> "the" ↔ pronouns
 ├─> "cat" ↔ animals
 ├─> "sat" ↔ verbs
 └─> "on" ↔ prepositions

Each head can specialize; the specialization is emergent behavior learned from data, not assigned by design.


11. Full Transformer Block with Self-Attention

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),  # the original paper used ReLU; GELU follows BERT/GPT
            nn.Linear(d_ff, d_model),
        )  # no Dropout here: forward() applies residual dropout once
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # Self-Attention + Residual
        attn_out, attn_weights = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_out))
        
        # Feed Forward + Residual
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        
        return x, attn_weights
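For comparison, PyTorch ships a built-in post-LN encoder layer that can be stacked with `nn.TransformerEncoder`. It is roughly comparable to the block above but not identical (for example, it also applies dropout inside the feed-forward sublayer). The sketch assumes $ d_{\text{model}} = 512 $, 8 heads, $ d_{ff} = 2048 $ and 6 layers. Note the mask convention differs from the class above: here a boolean `True` means "blocked".

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(
    d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1,
    activation="gelu", batch_first=True, norm_first=False,  # post-LN, like the block above
)
encoder = nn.TransformerEncoder(layer, num_layers=6)          # deep-copies `layer` 6 times

x = torch.randn(2, 10, 512)                                   # (B, N, d_model)
N = x.size(1)
causal = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)  # True = future, blocked

y = encoder(x, mask=causal)
print(y.shape)
```

Setting `norm_first=True` switches to the pre-LN arrangement used by many later models.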

12. Test: Multi-Head vs Single Head

x = torch.randn(1, 32, 512)

mha_8 = MultiHeadAttention(512, 8)
mha_1 = MultiHeadAttention(512, 1)

out_8, _ = mha_8(x, x, x)
out_1, _ = mha_1(x, x, x)

print("8 heads output norm:", out_8.norm().item())
print("1 head output norm:", out_1.norm().item())

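Caveat: with random, untrained weights these norms say little; the check mainly confirms that both configurations run and return (1, 32, 512). Any advantage of 8 heads (richer representations) only shows up after training, as in exercise 1.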
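The comparison is fair in one important sense: the number of heads does not change the parameter count, because all heads share the same four $ d \times d $ projections. This sketch uses PyTorch's built-in `nn.MultiheadAttention` with `bias=False` as a stand-in for the class above, which has the same four bias-free $ d \times d $ weights.

```python
import torch.nn as nn

d_model = 512
counts = {}
for h in (1, 8):
    mha = nn.MultiheadAttention(d_model, h, bias=False, batch_first=True)
    counts[h] = sum(p.numel() for p in mha.parameters())  # in_proj (3d x d) + out_proj (d x d)

print(counts, 4 * d_model * d_model)
```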


13. Summary Cheat Sheet

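| Concept | Value |
|---|---|
| Self-attention | $ Q, K, V $ all projected from the same $ X $ |
| Multi-head | $ h $ parallel attention heads |
| Head dim | $ d_k = d_{\text{model}} / h $ |
| Split | `.view(B, N, h, d_k).transpose(1, 2)` |
| Combine | `.transpose(1, 2).contiguous().view(B, N, d)` |
| Parallelism | all heads computed in one batched matmul |
| Complexity | $ O(N d^2 + N^2 d) $ time, same as a single full-width head |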

14. Practice Exercises

  1. Ablate: Train with 1 vs 8 heads → compare performance on copy task.
  2. Visualize: Plot attention for each head on real sentences.
  3. Efficiency: Measure time for num_heads=1, 8, 16.
  4. Custom: Implement multi-query attention (MQA) or its generalization, grouped-query attention (GQA).
  5. Debug: Add print(Q.shape) (and similar) at each step of forward() to trace tensor dims.
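For exercise 3, a simple timing harness is enough to get started. This is a rough CPU sketch using a hypothetical helper `time_heads`, with PyTorch's `nn.MultiheadAttention` standing in for the class above (swap in `MultiHeadAttention` to time your own version); the default sizes are arbitrary assumptions, and wall-clock numbers will vary by machine.

```python
import time
import torch
import torch.nn as nn

def time_heads(num_heads_list=(1, 8, 16), d_model=512, N=256, B=8, repeats=10):
    """Hypothetical helper: average forward time in seconds for each head count."""
    x = torch.randn(B, N, d_model)
    results = {}
    for h in num_heads_list:
        mha = nn.MultiheadAttention(d_model, h, batch_first=True).eval()
        with torch.no_grad():
            mha(x, x, x)                           # warm-up run, not timed
            start = time.perf_counter()
            for _ in range(repeats):
                mha(x, x, x)
            results[h] = (time.perf_counter() - start) / repeats
    return results

print(time_heads())
```

On a GPU, call `torch.cuda.synchronize()` before reading the clock, since kernels run asynchronously.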

15. Key Takeaways

  • Self-Attention = intra-sequence communication
  • Multi-Head = parallel feature extractors
  • Divide & Conquer = split the embedding space into $ h $ subspaces
  • Same asymptotic cost as one full-width head; usually better quality in practice
  • Enables head specialization (emergent, not assigned)

Final Words

You just built the core building block of modern Transformer-based LLMs.
Multi-Head Self-Attention = parallel, rich, scalable context.


Full Copy-Paste Code

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model, self.h, self.d_k = d_model, num_heads, d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)  # bias-free, matching the class in section 5
        self.dropout = nn.Dropout(dropout)
        self.scale = self.d_k ** 0.5

    def forward(self, Q, K, V, mask=None):
        B = Q.shape[0]
        Q, K, V = self.W_q(Q), self.W_k(K), self.W_v(V)
        Q = Q.view(B, -1, self.h, self.d_k).transpose(1, 2)
        K = K.view(B, -1, self.h, self.d_k).transpose(1, 2)
        V = V.view(B, -1, self.h, self.d_k).transpose(1, 2)
        
        scores = (Q @ K.transpose(-2, -1)) / self.scale
        if mask is not None: scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = self.dropout(F.softmax(scores, dim=-1))
        out = (attn @ V).transpose(1, 2).contiguous().view(B, -1, self.d_model)
        return self.W_o(out), attn

End of Module
You now control parallel attention — the heart of GPT, BERT, and beyond.
Go stack 100 layers.