Transformers in PyTorch: Practical Guide with Code Examples & Deep Understanding
Learn Transformers from scratch with clean PyTorch code, intuitive explanations, and visual diagrams.
Transformers are the foundation of modern AI: they power models like GPT, BERT, and LLaMA, and they appear inside Stable Diffusion (its text encoder and attention blocks). This blog builds each component step by step in PyTorch, pairing the concept with a short, readable implementation.
Why Learn Transformers?
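- Parallel processing: all positions in a sequence are processed at once, so training is much faster than with RNNs/LSTMs, which step through tokens one at a time
- Long-range dependencies: attention connects any two positions directly, no matter how far apart they are
- Scalable architecture: the same layer is stacked repeatedly, so depth and width are easy to grow
- Versatile: works for text, images (ViT), audio, and more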
Understanding Query, Key, Value (QKV) is the key to mastering Transformers.
1. Scaled Dot-Product Attention: The Core Idea
Analogy: Imagine you're reading a sentence. When processing the word "it", you need to find which previous word it refers to ("the cat" or "the dog").
- Query (Q) → What the current token is searching for
- Key (K) → What each token "advertises" about itself
- Value (V) → The actual information/content of each token
Formula: $$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$
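The division by $\sqrt{d_k}$ in the formula is easier to appreciate with numbers. The sketch below is an illustration, not part of the attention module itself: it assumes query and key components are independent with unit variance (an assumption made only for this demo) and compares the spread of raw and scaled dot products.

```python
import math
import torch

torch.manual_seed(0)  # fixed seed so the run is reproducible

d_k = 64
num_samples = 10_000

# Random queries and keys with unit-variance components (illustrative assumption)
q = torch.randn(num_samples, d_k)
k = torch.randn(num_samples, d_k)

raw_scores = (q * k).sum(dim=-1)              # one dot product per (q, k) pair
scaled_scores = raw_scores / math.sqrt(d_k)   # the scaling used in the formula

raw_var = raw_scores.var().item()
scaled_var = scaled_scores.var().item()
print(f"variance of raw scores:    {raw_var:.1f}")     # close to d_k
print(f"variance of scaled scores: {scaled_var:.2f}")  # close to 1
```

Without the scaling, scores grow with $d_k$, which pushes the softmax toward near one-hot outputs and very small gradients. Dividing by $\sqrt{d_k}$ keeps the scores in a range where softmax stays responsive.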
[Figure: Visual Explanation of QKV]
PyTorch Code: Basic Attention Module
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k: int):
        super().__init__()
        self.d_k = d_k

    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, mask=None):
        # Q, K, V shape: (batch_size, num_heads, seq_len, d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)  # Shape: (batch, heads, seq, seq)
        output = torch.matmul(attn_weights, V)
        return output, attn_weights


# Test
Q = torch.randn(32, 8, 10, 64)  # batch, heads, seq_len, d_k
K = torch.randn(32, 8, 10, 64)
V = torch.randn(32, 8, 10, 64)
attn = ScaledDotProductAttention(d_k=64)
output, weights = attn(Q, K, V)
print(output.shape)  # torch.Size([32, 8, 10, 64])
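The test above never exercises the `mask` argument. As a minimal sketch of how it might be used, the example below builds a causal (lower-triangular) mask so that each position attends only to itself and earlier positions. The helper names `causal_mask` and `masked_attention` are hypothetical, introduced here for illustration; the function repeats the module's computation so the snippet runs on its own.

```python
import math
import torch
import torch.nn.functional as F

def causal_mask(seq_len: int) -> torch.Tensor:
    """Lower-triangular mask: 1 = may attend, 0 = blocked (future position)."""
    return torch.tril(torch.ones(seq_len, seq_len))

def masked_attention(Q, K, V, mask=None):
    # Same computation as the module above, written as a plain function
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights

torch.manual_seed(0)
Q = torch.randn(2, 4, 5, 16)   # batch, heads, seq_len, d_k
K = torch.randn(2, 4, 5, 16)
V = torch.randn(2, 4, 5, 16)

mask = causal_mask(5)          # shape (5, 5), broadcasts over batch and heads
output, weights = masked_attention(Q, K, V, mask)

print(weights[0, 0])           # entries above the diagonal are exactly 0
print(output.shape)            # torch.Size([2, 4, 5, 16])
```

One caveat with the `-inf` convention: a row in which every position is masked yields NaN after softmax, so a mask should leave at least one position open per query.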
2. Multi-Head Attention
Running multiple attention heads in parallel allows the model to focus on different types of relationships (syntax, semantics, etc.).
PyTorch Code: Multi-Head Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttention(self.d_k)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        # Linear projections
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        # Apply attention
        attn_output, attn_weights = self.attention(Q, K, V, mask)
        # Concatenate heads
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # Final linear layer
        output = self.W_o(attn_output)
        return output, attn_weights
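The `view(...).transpose(1, 2)` calls in `forward` are the least obvious part of the module. The sketch below isolates that reshaping on a random tensor (no learned weights involved) to show that splitting into heads and merging back is lossless, and which features each head receives. The sizes are arbitrary and chosen only for illustration.

```python
import torch

batch_size, seq_len, d_model, num_heads = 2, 5, 32, 4
d_k = d_model // num_heads

x = torch.randn(batch_size, seq_len, d_model)

# Split: (batch, seq, d_model) -> (batch, heads, seq, d_k)
heads = x.view(batch_size, -1, num_heads, d_k).transpose(1, 2)
print(heads.shape)   # torch.Size([2, 4, 5, 8])

# Head h sees features [h * d_k, (h + 1) * d_k) of every token
assert torch.equal(heads[:, 1], x[:, :, d_k:2 * d_k])

# Merge: transpose back, then flatten the last two dimensions.
# .contiguous() is needed because transpose returns a non-contiguous view.
merged = heads.transpose(1, 2).contiguous().view(batch_size, -1, d_model)
print(torch.equal(merged, x))   # True
```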
3. Positional Encoding
Self-attention has no recurrence or convolution, so on its own it cannot tell where a token sits in the sequence. We therefore add positional information to the token embeddings.
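For reference, the sinusoidal scheme that the class below implements can be written as follows. Here $pos$ is the token position and $i$ indexes pairs of embedding dimensions; this notation is an assumption added for readability, since the original gives only the code.

$$ PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) $$

Small values of $i$ oscillate quickly with position and large values oscillate slowly, so each position receives a distinct pattern across the embedding dimensions.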
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # Slice div_term so an odd d_model (one fewer cosine column) also works
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        # Buffer: saved and moved with the module, but not a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model); add the first seq_len encodings
        return x + self.pe[:, :x.size(1)]
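A quick way to build intuition is to inspect the encoding table directly. The sketch below recomputes the same table in a standalone helper (`sinusoidal_table` is a hypothetical name, and an even `d_model` is assumed) and checks a few of its properties.

```python
import math
import torch

def sinusoidal_table(max_len: int, d_model: int) -> torch.Tensor:
    """Return a (max_len, d_model) table of sinusoidal encodings (even d_model assumed)."""
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

pe = sinusoidal_table(max_len=50, d_model=16)

print(pe[0])                            # position 0: sin(0) = 0 and cos(0) = 1 alternate
print(pe.abs().max().item() <= 1.0)     # True: every entry is a sine or a cosine
print(torch.unique(pe, dim=0).size(0))  # 50: each position gets a distinct row
```

Because every entry stays in [-1, 1], the encoding adds a bounded offset to each embedding rather than overwhelming it.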
4. Feed-Forward Network
Applied after attention in each layer, independently at every position: two linear maps with a ReLU (and dropout) in between.
class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        return self.net(x)
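The feed-forward network mixes features within a token but never across tokens. As a small check of that property, the sketch below rebuilds the same `nn.Sequential` stack with arbitrary small sizes and compares running it on a whole sequence against running it on one position at a time.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model, d_ff = 16, 64
ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(d_ff, d_model),
)
ffn.eval()  # disable dropout so the two computations are comparable

x = torch.randn(2, 5, d_model)   # batch, seq_len, d_model

with torch.no_grad():
    whole = ffn(x)               # all positions at once
    # Apply the same network to each position separately, then restack
    per_position = torch.stack([ffn(x[:, t]) for t in range(x.size(1))], dim=1)

print(whole.shape)                                     # torch.Size([2, 5, 16])
print(torch.allclose(whole, per_position, atol=1e-5))  # True
```

All communication between tokens therefore happens in the attention sublayer; the feed-forward sublayer only transforms each token's representation.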
5. Encoder Layer (Full Component)
class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self Attention + Residual
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        # Feed Forward + Residual
        ff_output = self.ff(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x
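Each sublayer in the encoder layer follows the same pattern: add the sublayer output to its input, then apply LayerNorm. The sketch below shows what that pattern does to the statistics of each token, using a plain `nn.Linear` as a hypothetical stand-in for attention or the feed-forward network and a deliberately badly scaled input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model = 16
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)    # hypothetical stand-in for attention or the FFN

x = torch.randn(2, 5, d_model) * 10 + 3   # deliberately badly scaled input

with torch.no_grad():
    out = norm(x + sublayer(x))           # residual connection, then LayerNorm

# LayerNorm normalizes over the last dimension, i.e. separately for each token
print(out.shape)                                      # torch.Size([2, 5, 16])
print(out.mean(dim=-1).abs().max().item())            # close to 0
print(out.std(dim=-1, unbiased=False).mean().item())  # close to 1
```

The residual path lets gradients bypass the sublayer, and the normalization keeps activations in a consistent range from layer to layer, which is why the summary table lists this pair under stable training.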
6. Complete Transformer (Encoder + Decoder)
The full architecture pairs an encoder with a decoder. For simplicity, this section builds only the encoder stack, the part used in BERT-style models:
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size=10000, d_model=512, num_heads=8, num_layers=6, d_ff=2048, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)  # token id -> vector
        self.pos_encoding = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, mask=None):
        # src: (batch, seq_len) tensor of integer token ids
        x = self.embedding(src)
        x = self.pos_encoding(x)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, mask)
        return x


# Usage
model = TransformerEncoder(d_model=512, num_layers=6)  # default vocab size = 10000
src = torch.randint(0, 10000, (32, 50))  # batch, seq_len
output = model(src)
print(output.shape)  # torch.Size([32, 50, 512])
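The usage above passes no mask, which is fine when every sequence in the batch has the same length. With padded batches, a padding mask keeps attention away from the filler tokens. The sketch below is one way to build such a mask under the `mask == 0` convention used by the attention code. The padding id of 0 and the helper name `make_padding_mask` are assumptions, since the original does not define either.

```python
import torch
import torch.nn.functional as F

PAD_ID = 0  # hypothetical padding token id; the original does not specify one

def make_padding_mask(src: torch.Tensor, pad_id: int = PAD_ID) -> torch.Tensor:
    """Return a (batch, 1, 1, seq_len) mask: 1 = real token, 0 = padding."""
    return (src != pad_id).unsqueeze(1).unsqueeze(2).int()

src = torch.tensor([
    [5, 8, 2, 0, 0],    # 3 real tokens, 2 padding
    [7, 3, 9, 4, 1],    # no padding
])
mask = make_padding_mask(src)
print(mask.shape)       # torch.Size([2, 1, 1, 5])

# The mask broadcasts against attention scores of shape (batch, heads, seq, seq)
scores = torch.zeros(2, 8, 5, 5)
weights = F.softmax(scores.masked_fill(mask == 0, float("-inf")), dim=-1)
print(weights[0, 0, 0])  # tensor([0.3333, 0.3333, 0.3333, 0.0000, 0.0000])
```

A mask of this shape can be passed as the second argument of the encoder's `forward`, where it is forwarded unchanged to every layer.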
Summary: Main Components of Transformer
| Component | Purpose | Key Code Class |
|---|---|---|
| Scaled Dot-Product Attention | Core attention mechanism | ScaledDotProductAttention |
| Multi-Head Attention | Multiple attention subspaces | MultiHeadAttention |
| Positional Encoding | Add position information | PositionalEncoding |
| Feed Forward | Non-linear transformation | FeedForward |
| Encoder Layer | Full encoder block | EncoderLayer |
| LayerNorm + Residual | Stable training | Built into layers |