Complete Mistral Transformer Architecture in PyTorch
Tech3Space · 03 Jun 2026
✅ Complete Mistral Transformer Architecture in PyTorch
Here's a clean, well-commented implementation of the Mistral model family (inspired by the Mistral-7B architecture).
mistral_transformer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# ============================
# 1. RMSNorm (Used in Mistral instead of LayerNorm)
# ============================
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Unlike LayerNorm there is no mean subtraction and no bias: the input is
    divided by sqrt(mean(x**2) + eps) and scaled by a learned per-channel
    weight (initialised to ones).

    Args:
        dim: Size of the last (normalized) dimension.
        eps: Constant added to the mean square before the square root, to
            avoid division by zero.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Do the reduction in float32: squaring fp16/bf16 activations can
        # overflow (e.g. 300**2 > fp16 max) or lose precision. Cast back so
        # the normalized activations keep the caller's dtype.
        x_f32 = x.float()
        inv_rms = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x_f32 * inv_rms).to(x.dtype) * self.weight
# ============================
# 2. Rotary Embeddings (RoPE)
# ============================
class RotaryEmbedding(nn.Module):
    """Produce rotary position embedding (RoPE) cos/sin tables.

    Args:
        dim: Per-head feature size the rotation is applied to (even).
        max_seq_len: Stored on the module; the tables are computed on the
            fly for whatever length is requested, so it is not a bound.
        base: Frequency base; inv_freq[i] = base ** (-2i / dim).
    """

    def __init__(self, dim: int, max_seq_len: int = 4096, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        # One inverse frequency per feature pair: shape (dim // 2,)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)``, each of shape (1, 1, seq_len, dim).

        ``x`` supplies the device and output dtype, and (when ``seq_len`` is
        None) the sequence length via ``x.shape[-2]``.
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        # Build positions and angles in float32 even if the module was cast
        # to half precision: fp16 cannot represent every integer position
        # above 2048, which would corrupt the angles.
        inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)
        positions = torch.arange(seq_len, device=x.device, dtype=torch.float32)
        freqs = torch.outer(positions, inv_freq)  # (seq_len, dim/2)
        emb = torch.cat((freqs, freqs), dim=-1)   # (seq_len, dim)
        # Return tables in x's dtype so that q * cos does not promote
        # half-precision q/k to float32 (which then clashes with v in matmul).
        cos = emb.cos().to(x.dtype)[None, None, :, :]
        sin = emb.sin().to(x.dtype)[None, None, :, :]
        return cos, sin
def rotate_half(x):
    """Split the last dimension into (front, back) and return (-back, front).

    For odd sizes the front part is the shorter one (length size // 2).
    """
    half = x.shape[-1] // 2
    front, back = x.split([half, x.shape[-1] - half], dim=-1)
    return torch.cat([back.neg(), front], dim=-1)
def apply_rotary_emb(q, k, cos, sin):
    """Apply the rotary rotation ``t * cos + rotate_half(t) * sin`` to q and k.

    ``cos`` and ``sin`` must broadcast against ``q`` and ``k``.
    Returns the rotated ``(q, k)`` pair.
    """
    def _rotate(t):
        return t * cos + rotate_half(t) * sin

    return _rotate(q), _rotate(k)
# ============================
# 3. Grouped Query Attention (GQA) - Core of Mistral
# ============================
class MistralAttention(nn.Module):
    """Self-attention with grouped-query K/V heads and rotary embeddings.

    ``num_attention_heads`` query heads share ``num_key_value_heads`` K/V
    heads; each K/V head is repeated ``num_attention_heads //
    num_key_value_heads`` times before the scores are computed.

    Config keys read: ``hidden_size``, ``num_attention_heads``,
    ``num_key_value_heads``, ``head_dim`` and, optionally, ``max_seq_len``
    (default 4096) and ``sliding_window`` (default 4096; ``None`` disables
    the window in the default mask).

    Raises:
        ValueError: if the query head count is not a multiple of the K/V
            head count.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config['hidden_size']
        self.num_heads = config['num_attention_heads']
        self.num_kv_heads = config['num_key_value_heads']  # GQA
        self.head_dim = config['head_dim']
        self.max_seq_len = config.get('max_seq_len', 4096)
        self.sliding_window = config.get('sliding_window', 4096)
        if self.num_heads % self.num_kv_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_heads}) must be divisible by "
                f"num_key_value_heads ({self.num_kv_heads})"
            )
        # Number of query heads that share each K/V head.
        self.num_kv_groups = self.num_heads // self.num_kv_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_emb = RotaryEmbedding(self.head_dim, self.max_seq_len)

    @staticmethod
    def _default_mask(seq_len, sliding_window, device, dtype):
        """Build an additive (seq_len, seq_len) causal mask, banded to the window.

        Query i may attend to key j iff j <= i and, when a window is set,
        i - j < sliding_window. Blocked entries hold the most negative
        finite value of ``dtype``, so a fully blocked row cannot produce NaN.
        """
        idx = torch.arange(seq_len, device=device)
        distance = idx[:, None] - idx[None, :]  # query index minus key index
        allowed = distance >= 0
        if sliding_window is not None:
            allowed = allowed & (distance < sliding_window)
        mask = torch.zeros(seq_len, seq_len, device=device, dtype=dtype)
        return mask.masked_fill(~allowed, torch.finfo(dtype).min)

    def forward(self, x, mask=None):
        """Run attention over ``x``.

        Args:
            x: Hidden states of shape (batch, seq_len, hidden_size).
            mask: Optional additive mask added to the attention scores; it
                must broadcast to (batch, num_heads, seq_len, seq_len). 0
                keeps a position; a large negative value (or -inf) blocks it.
                When None, a causal mask limited to ``self.sliding_window``
                keys is built.

        Returns:
            Tensor of shape (batch, seq_len, hidden_size).
        """
        batch_size, seq_len, _ = x.shape

        # Project and split heads: (batch, heads, seq_len, head_dim)
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Cast the rotary tables to the activation dtype so half-precision
        # q/k are not promoted to float32 (which would then clash with v).
        cos, sin = self.rotary_emb(q, seq_len)
        q, k = apply_rotary_emb(q, k, cos.to(q.dtype), sin.to(q.dtype))

        # Repeat K/V heads so every query head has a matching K/V head.
        if self.num_kv_groups > 1:
            k = k.repeat_interleave(self.num_kv_groups, dim=1)
            v = v.repeat_interleave(self.num_kv_groups, dim=1)

        scale = 1.0 / math.sqrt(self.head_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale

        if mask is None:
            mask = self._default_mask(seq_len, self.sliding_window, x.device, scores.dtype)
        scores = scores + mask.to(scores.dtype)

        # Softmax in float32 for stability, then back to the activation dtype.
        attn_weights = F.softmax(scores, dim=-1, dtype=torch.float32).to(v.dtype)
        attn_output = torch.matmul(attn_weights, v)

        # Merge heads. head_dim is configured independently of hidden_size,
        # so the merged width is num_heads * head_dim (o_proj's input size).
        attn_output = attn_output.transpose(1, 2).reshape(
            batch_size, seq_len, self.num_heads * self.head_dim
        )
        return self.o_proj(attn_output)
# ============================
# 4. SwiGLU Feed Forward
# ============================
class MistralMLP(nn.Module):
    """SwiGLU feed-forward block: down(silu(gate(x)) * up(x)).

    All three projections are bias-free; the hidden width expands to
    ``intermediate_size`` and is projected back to ``hidden_size``.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        activated = F.silu(self.gate_proj(x))
        return self.down_proj(activated * self.up_proj(x))
# ============================
# 5. Mistral Decoder Layer
# ============================
class MistralDecoderLayer(nn.Module):
    """Pre-norm decoder block: attention and MLP, each inside a residual add.

    Computes h = x + attn(norm1(x)); out = h + mlp(norm2(h)).
    Both RMSNorms use ``config['rms_norm_eps']`` when the key is present,
    otherwise 1e-6.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config['hidden_size']
        # Honour the configured epsilon instead of silently using a default.
        eps = config.get('rms_norm_eps', 1e-6)
        self.self_attn = MistralAttention(config)
        self.mlp = MistralMLP(config['hidden_size'], config['intermediate_size'])
        self.input_layernorm = RMSNorm(config['hidden_size'], eps=eps)
        self.post_attention_layernorm = RMSNorm(config['hidden_size'], eps=eps)

    def forward(self, x, mask=None):
        """Apply the block; ``mask`` is forwarded unchanged to self-attention."""
        h = x + self.self_attn(self.input_layernorm(x), mask)
        return h + self.mlp(self.post_attention_layernorm(h))
# ============================
# 6. Complete Mistral Model
# ============================
class MistralModel(nn.Module):
    """Token embedding, a stack of decoder layers, then a final RMSNorm.

    Returns the final hidden states (there is no LM head here).
    ``rms_norm_eps`` (default 1e-6) and ``sliding_window`` (default 4096;
    ``None`` disables it) are read from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.vocab_size = config['vocab_size']
        self.hidden_size = config['hidden_size']
        self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size)
        self.layers = nn.ModuleList([
            MistralDecoderLayer(config) for _ in range(config['num_hidden_layers'])
        ])
        self.norm = RMSNorm(self.hidden_size, eps=config.get('rms_norm_eps', 1e-6))

    def _build_mask(self, seq_len, attention_mask, device, dtype):
        """Build the additive attention mask shared by all layers.

        Query i may attend to key j iff j <= i (causal), i - j < sliding_window
        (when a window is configured) and, if ``attention_mask`` is given,
        key j is not padding. Blocked entries hold the most negative finite
        value of ``dtype`` rather than -inf. A query whose keys are all
        blocked (e.g. a left-padding position) then gets a uniform softmax
        instead of NaN. A NaN there would otherwise spread to every later
        layer through that position's keys/values.

        Returns:
            (seq_len, seq_len) without ``attention_mask``, else
            (batch, 1, seq_len, seq_len).
        """
        idx = torch.arange(seq_len, device=device)
        distance = idx[:, None] - idx[None, :]  # query index minus key index
        allowed = distance >= 0
        window = self.config.get('sliding_window', 4096)
        if window is not None:
            allowed = allowed & (distance < window)
        if attention_mask is not None:
            key_is_token = attention_mask.to(device=device).bool()  # (batch, seq_len)
            allowed = allowed[None, None, :, :] & key_is_token[:, None, None, :]
        mask = torch.zeros(allowed.shape, device=device, dtype=dtype)
        return mask.masked_fill(~allowed, torch.finfo(dtype).min)

    def forward(self, input_ids, attention_mask=None):
        """Encode token ids.

        Args:
            input_ids: Integer token ids of shape (batch, seq_len).
            attention_mask: Optional (batch, seq_len) tensor; nonzero/True
                marks real tokens, 0/False marks padding keys to ignore.

        Returns:
            Hidden states of shape (batch, seq_len, hidden_size).
        """
        x = self.embed_tokens(input_ids)
        # Build the mask in the activation dtype so it does not upcast
        # half-precision attention scores.
        mask = self._build_mask(input_ids.shape[1], attention_mask, x.device, x.dtype)
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
# ============================
# Configuration for Mistral-7B-like model
# ============================
# Hyper-parameters for a Mistral-7B-like model.
# 32 query heads share 8 K/V heads (grouped-query attention, 4 queries per
# K/V head); 32 heads * 128 head_dim = 4096 = hidden_size.
mistral_config = dict(
    vocab_size=32000,
    hidden_size=4096,
    intermediate_size=14336,
    num_hidden_layers=32,
    num_attention_heads=32,
    num_key_value_heads=8,
    head_dim=128,
    max_seq_len=4096,
    sliding_window=4096,
    rms_norm_eps=1e-5,
)
# ============================
# Test the Model
# ============================
if __name__ == "__main__":
    # Building mistral_config as-is allocates about 7.1B float32 parameters
    # (~28 GB), which exhausts memory on most machines. For a quick smoke
    # test, keep the same structure (32 query heads sharing 8 K/V heads, a
    # 3.5x SwiGLU expansion) but shrink the widths and depth. Pass
    # mistral_config to MistralModel instead to build the full-size model.
    demo_config = {
        **mistral_config,
        "hidden_size": 256,          # 32 heads * 8 head_dim
        "intermediate_size": 896,    # 3.5 * hidden_size, as in mistral_config
        "num_hidden_layers": 2,
        "head_dim": 8,
    }
    model = MistralModel(demo_config)
    num_params = sum(p.numel() for p in model.parameters())

    # Dummy input: batch=2, seq=512
    input_ids = torch.randint(0, demo_config['vocab_size'], (2, 512))
    # Inference only: don't keep activations around for backprop.
    with torch.no_grad():
        output = model(input_ids)

    print("✅ Mistral Model Created Successfully!")
    print(f"Parameters  : {num_params:,}")
    print(f"Input shape : {input_ids.shape}")
    print(f"Output shape: {output.shape}")  # (batch, seq_len, hidden_size)
Key Mistral-family features implemented:
- RMSNorm instead of LayerNorm
- Rotary Embeddings (RoPE)
- Grouped Query Attention (GQA) — 8 KV heads for 32 Q heads
- SwiGLU activation in FFN
- Sliding Window ready (mask can be modified)
- Layer dimensions matching Mistral-7B (see `mistral_config`)
How to run:
python mistral_transformer.py
Possible extensions (not implemented here):
- Generation loop (text generation)
- LoRA / QLoRA support
- Mixtral (MoE) version
- Loading weights from Hugging Face