
Complete Module: Big-O, Parallelism, FlashAttention, LoRA

Scaling Laws & Optimization



Module Objective

Master LLM scaling: Chinchilla laws, Big-O complexity, GPU parallelism, FlashAttention, and LoRA, with the math and code behind each efficiency gain.


1. Chinchilla Scaling Laws (2022)

"Optimal training: balance model size and data"

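```python
import math

# Chinchilla rule of thumb: training compute C ≈ 6·N·D FLOPs, and at the
# compute-optimal point D ≈ 20·N (about 20 tokens per parameter).
# The paper's fits give N_opt and D_opt both growing roughly as C^0.5.
def chinchilla_optimal_params(compute):
    # C = 6 · N · (20 · N) = 120 · N²  →  N = sqrt(C / 120)
    return math.sqrt(compute / 120)  # ~70B params for ~5.9e23 FLOPs

def chinchilla_optimal_tokens(compute):
    return 20 * chinchilla_optimal_params(compute)  # ~1.4T tokens for ~5.9e23 FLOPs
```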
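| Model | Params | Tokens | Compute ($ \approx 6ND $ FLOPs) | Undertrained? |
|---|---|---|---|---|
| GPT-3 | 175B | 300B | ~3.1e23 | Yes |
| Gopher | 280B | 300B | ~5.0e23 | Yes |
| Chinchilla | 70B | 1.4T | ~5.9e23 | No (compute-optimal) |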

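Result: with a similar compute budget to Gopher (280B), Chinchilla (70B) performs better, and it also outperforms the larger GPT-3 (175B). Spending compute on more tokens instead of more parameters wins.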
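To check the compute column and the "undertrained" verdicts, here is a small sketch using the standard $ C \approx 6ND $ approximation. It is an estimate, not the papers' exact FLOP accounting; the model sizes and token counts are the ones in the table above.

```python
def training_flops(params, tokens):
    # Common approximation: ~6 FLOPs per parameter per training token
    # (about 2 for the forward pass and 4 for the backward pass).
    return 6 * params * tokens

models = {
    "GPT-3": (175e9, 300e9),
    "Gopher": (280e9, 300e9),
    "Chinchilla": (70e9, 1.4e12),
}

for name, (n, d) in models.items():
    # Chinchilla's rule of thumb is ~20 tokens per parameter
    print(f"{name:<10} {training_flops(n, d):.2e} FLOPs, {d / n:5.1f} tokens/param")
```

GPT-3 and Gopher come out near 1–2 tokens per parameter, an order of magnitude below the ~20 that Chinchilla uses.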


2. Big-O Complexity of Transformers

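Here $ N $ is the sequence length and $ d $ the model width.

| Operation | Time | Memory |
|---|---|---|
| Attention scores and weighted sum ($ QK^\top $, softmax, $ \cdot V $) | $ O(N^2 d) $ | $ O(N^2) $ |
| Q/K/V/output projections + FFN | $ O(N d^2) $ | $ O(N d) $ |
| Total per layer | $ O(N^2 d + N d^2) $ | $ O(N^2 + N d) $ |
| $ L $ layers | $ O(L(N^2 d + N d^2)) $ | $ O(L(N^2 + N d)) $ |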

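Bottleneck: the $ N \times N $ attention matrix. Its $ O(N^2 d) $ time overtakes the $ O(N d^2) $ projections once $ N > d $, and its $ O(N^2) $ memory grows quadratically with context length.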
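To make the quadratic memory concrete, a quick back-of-the-envelope sketch. It assumes 32 attention heads, batch size 1, and 2 bytes per element (fp16/bf16), and counts only the score matrices of a single layer; real models also store many other activations.

```python
def attention_matrix_bytes(seq_len, n_heads, bytes_per_elem=2):
    # One N×N score matrix per head for a single sequence
    return seq_len * seq_len * n_heads * bytes_per_elem

for n in (2048, 8192, 32768):
    gib = attention_matrix_bytes(n, n_heads=32) / 2**30
    print(f"N={n:>6}: {gib:8.2f} GiB per layer")
```

Each 4x increase in $ N $ multiplies the memory by 16x: 0.25 GiB at 2K tokens becomes 64 GiB at 32K, which is why materializing this matrix does not scale.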


3. GPU Parallelism: Data, Tensor, Pipeline

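Data Parallel (DP): 
  8 GPUs each hold a full copy of the model and process 1/8 of the batch; gradients are averaged (all-reduce) every step

Tensor Parallel (TP): 
  Each layer's weight matrices are split across 4 GPUs (e.g. each GPU holds the W_q, W_k, W_v columns for a quarter of the attention heads); partial results are combined with collective communication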

Pipeline Parallel (PP): 
  Layers 1–4 on GPU0, 5–8 on GPU1

Megatron-LM: TP + PP (combined with DP) → scaled to 1T params
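Why data parallelism gives the same update as one big batch: a minimal NumPy simulation, assuming a toy linear model with mean-squared-error loss and 8 equal-size shards standing in for 8 GPUs (no real devices or communication). Averaging the per-shard gradients is exactly what the DP all-reduce computes.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 10))   # global batch of 64 examples
y = rng.normal(size=64)
w = rng.normal(size=10)         # the replicated model weights

def mse_grad(w, X, y):
    # Gradient of mean squared error for a linear model y_hat = X @ w
    return 2 * X.T @ (X @ w - y) / len(y)

full = mse_grad(w, X, y)

# Data parallel: each "GPU" computes a gradient on its own 8-example shard
shards = zip(np.array_split(X, 8), np.array_split(y, 8))
local = [mse_grad(w, Xs, ys) for Xs, ys in shards]

# All-reduce (average) across replicas
averaged = np.mean(local, axis=0)
print(np.allclose(full, averaged))
```

The equality relies on equal shard sizes; with uneven shards the average would need to be weighted by shard size.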


4. FlashAttention: O(N) Memory, 2–4x Faster

Problem: Standard Attention

```python
import math, torch
S = Q @ K.transpose(-2, -1) / math.sqrt(d)  # materializes the N×N score matrix → O(N²) memory
attn = torch.softmax(S, dim=-1) @ V
```

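FlashAttention: never materializes the N×N matrix. It walks over K and V in tiles and keeps a running (online) softmax for each query row; the actual kernel performs these steps in fused CUDA code using on-chip SRAM.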

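```python
import math, torch
def tiled_attention(Q, K, V, B=64):  # online softmax + tiling; Q, K, V: (N, d), one head
    O = torch.empty_like(Q)
    for i in range(0, Q.shape[0], B):
        q = Q[i:i+B] / math.sqrt(Q.shape[-1])
        m = torch.full((q.shape[0], 1), float("-inf"))  # running row max
        l = torch.zeros(q.shape[0], 1)                  # running softmax denominator
        acc = torch.zeros_like(q)                       # running unnormalized output
        for j in range(0, K.shape[0], B):
            s = q @ K[j:j+B].T                          # only a B×B tile of scores exists
            m_new = torch.maximum(m, s.amax(dim=-1, keepdim=True))
            p, scale = torch.exp(s - m_new), torch.exp(m - m_new)  # rescale old sums to new max
            l = l * scale + p.sum(dim=-1, keepdim=True)
            acc, m = acc * scale + p @ V[j:j+B], m_new
        O[i:i+B] = acc / l
    return O
```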

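Extra memory: $ O(N) $ (a running max and sum per query row) instead of $ O(N^2) $
Speed: typically 2–4x faster than standard attention, because far less data moves between GPU memory and on-chip SRAM; the memory savings grow with sequence length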

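```python
from flash_attn import flash_attn_func  # flash-attn 2.x package

# q, k, v: (batch, seq_len, n_heads, head_dim), fp16/bf16, on a CUDA GPU
attn_output = flash_attn_func(q, k, v, causal=True)
```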
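Without the `flash-attn` package, PyTorch 2.x exposes a fused attention op, `torch.nn.functional.scaled_dot_product_attention`, which selects a backend automatically (a FlashAttention kernel on supported GPUs and dtypes, other implementations elsewhere). The sketch below, with arbitrary sizes, checks it against explicit causal attention on CPU. Note the different layout: SDPA takes `(batch, heads, seq_len, head_dim)`, whereas `flash_attn_func` takes `(batch, seq_len, heads, head_dim)`.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, N, d = 2, 4, 128, 32
q, k, v = (torch.randn(batch, heads, N, d) for _ in range(3))

# Reference: materialize the N×N scores and apply the causal mask explicitly
scores = q @ k.transpose(-2, -1) / math.sqrt(d)
mask = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)  # True above diagonal = future
reference = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1) @ v

# Fused op: PyTorch chooses the attention backend for this device and dtype
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(reference, fused, atol=1e-5))
```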

5. LoRA: Train 0.1% of Parameters

"Freeze weights, train low-rank adapters"

W = W₀ + ΔW    # W₀: (d, k), frozen
ΔW = B A       # B: (d, r), A: (r, k), r << min(d, k)

LoRA Injection

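```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, linear: nn.Linear, rank=8, alpha=16):
        super().__init__()
        self.linear = linear
        for p in self.linear.parameters():
            p.requires_grad = False  # freeze W₀ (and bias)
        self.A = nn.Parameter(torch.randn(rank, linear.in_features) * 0.01)  # (r, k)
        self.B = nn.Parameter(torch.zeros(linear.out_features, rank))        # (d, r): zero init, so ΔW = 0 at start
        self.scale = alpha / rank  # LoRA's α/r scaling of the update

    def forward(self, x):
        return self.linear(x) + (x @ self.A.T @ self.B.T) * self.scale
```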

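Params: $ r(d + k) $ trainable vs $ d k $ frozen, per adapted matrix
Example: $ d = k = 4096,\ r = 8 $ gives $ 65{,}536 $ vs $ \approx 16.8\text{M} $, about 0.4% of the matrix. LoRA usually adapts only some matrices (e.g. the attention projections), so the trainable share of the whole model is smaller still.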
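The parameter count above, plus one property worth seeing in numbers: after training, $ BA $ can be folded back into $ W_0 $, so the adapted layer costs no more at inference than the original. A NumPy sketch, assuming small random matrices for the merge check and omitting the $ \alpha/r $ scale for brevity (it would multiply $ BA $ in both paths).

```python
import numpy as np

# Trainable vs frozen weights for one 4096×4096 matrix at rank 8
d, k, r = 4096, 4096, 8
full, lora = d * k, r * (d + k)
print(f"LoRA trains {lora:,} of {full:,} weights ({100 * lora / full:.2f}%)")

# Merging: W₀ + B A gives the same output as running the adapter path separately
rng = np.random.default_rng(0)
d, k, r = 64, 32, 4
W0, A, B = rng.normal(size=(d, k)), rng.normal(size=(r, k)), rng.normal(size=(d, r))
x = rng.normal(size=(5, k))
adapter_out = x @ W0.T + x @ A.T @ B.T
merged_out = x @ (W0 + B @ A).T
print(np.allclose(adapter_out, merged_out))
```

The first line prints `LoRA trains 65,536 of 16,777,216 weights (0.39%)`.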


6. Full LoRA + FlashAttention Training

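```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Load the base model with a fused attention backend instead of monkey-patching forward().
# "sdpa" uses torch.nn.functional.scaled_dot_product_attention, which can dispatch to FlashAttention kernels;
# "flash_attention_2" needs the flash-attn package, a supported CUDA GPU, and fp16/bf16 weights.
# Which backends a model accepts depends on the architecture and the transformers version.
model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="sdpa")

# Add LoRA to GPT-2's fused QKV projection (c_attn) and every c_proj (attention and MLP outputs)
lora_config = LoraConfig(
    r=8, lora_alpha=32, target_modules=["c_attn", "c_proj"], lora_dropout=0.1, task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```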

7. Compute Scaling: FLOPs

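```python
def transformer_forward_flops(batch, seq_len, d_model, layers, vocab):
    # Counts 2 FLOPs per multiply-add; ignores softmax, layer norm, and the embedding lookup.
    tokens = batch * seq_len
    proj = 8 * tokens * d_model**2               # Q, K, V, and output projections
    attn = 4 * batch * seq_len**2 * d_model      # Q·Kᵀ scores and attention·V
    ffn = 16 * tokens * d_model**2               # two matmuls with hidden size 4·d_model
    lm_head = 2 * tokens * d_model * vocab       # final projection to vocabulary logits
    return layers * (proj + attn + ffn) + lm_head

# Training ≈ 3× forward (the backward pass costs about 2× the forward pass)
per_seq = 3 * transformer_forward_flops(1, 2048, 12288, 96, 50257)
n_seqs = 300e9 / 2048                            # GPT-3 was trained on ~300B tokens
print(f"GPT-3 175B training: {per_seq * n_seqs:.1e} FLOPs")
# → GPT-3 175B training: 3.2e+23 FLOPs (close to 6·N·D = 6 × 175e9 × 300e9 ≈ 3.15e23)
```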

8. Parallelism in Code

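```python
import torch
import torch.nn as nn
import torch.distributed as dist

# Column-parallel linear (simplified, forward only): each rank stores 1/world_size of the output rows
class TensorParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        assert out_features % world_size == 0, "out_features must divide evenly across ranks"
        self.weight = nn.Parameter(torch.randn(out_features // world_size, in_features) * in_features**-0.5)
        self.world_size = world_size

    def forward(self, x):
        out = x @ self.weight.t()  # this rank's slice: (..., out_features // world_size)
        # All-gather every rank's slice (requires dist.init_process_group to have been called).
        # dist.all_gather does not backpropagate; real TP (e.g. Megatron-LM) wraps it in autograd functions.
        gathered = [torch.empty_like(out) for _ in range(self.world_size)]
        dist.all_gather(gathered, out.contiguous())
        return torch.cat(gathered, dim=-1)
```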
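The math behind tensor parallelism can be checked on one machine. This NumPy sketch, assuming 4 simulated ranks and arbitrary small sizes, shows the two standard ways to shard a linear layer: split the output rows of $ W $ and concatenate the results (what the all-gather does), or split the input columns and sum the partial products (what an all-reduce does). Megatron-LM pairs the two so that each MLP or attention block needs only one collective in the forward pass.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 16))    # activations: (batch, in_features)
W = rng.normal(size=(32, 16))   # full weight: (out_features, in_features)
world_size = 4

# Column parallel: each rank owns a block of output features; concatenation = all-gather
shards = np.split(W, world_size, axis=0)
col_out = np.concatenate([x @ w.T for w in shards], axis=-1)

# Row parallel: each rank owns a block of input features; summation = all-reduce
x_parts = np.split(x, world_size, axis=-1)
w_parts = np.split(W, world_size, axis=1)
row_out = sum(xp @ wp.T for xp, wp in zip(x_parts, w_parts))

print(np.allclose(col_out, x @ W.T), np.allclose(row_out, x @ W.T))
```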

9. Optimization: AdamW + Gradient Checkpointing

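```python
import torch

# Only pass trainable parameters (e.g. the LoRA adapters) to the optimizer
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=3e-4, weight_decay=0.1)

# Gradient checkpointing: trade compute for memory (Hugging Face models expose a one-line switch)
model.gradient_checkpointing_enable()
```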
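For a plain `nn.Sequential` stack, `torch.utils.checkpoint.checkpoint_sequential` applies the same idea, but it is called inside the forward pass with the input and returns the output rather than a wrapped model. A minimal sketch with a hypothetical 8-block MLP (sizes are arbitrary), checking that gradients match a normal backward pass:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
net = nn.Sequential(*[nn.Sequential(nn.Linear(256, 256), nn.ReLU()) for _ in range(8)])
x = torch.randn(32, 256, requires_grad=True)

# 8 blocks in 4 segments: only activations at segment boundaries are kept;
# the activations inside each segment are recomputed during the backward pass.
out = checkpoint_sequential(net, 4, x, use_reentrant=False)
out.sum().backward()
ckpt_grad = x.grad.clone()

# Same gradients as an ordinary forward/backward, with less activation memory
x.grad = None
net(x).sum().backward()
print(torch.allclose(ckpt_grad, x.grad))
```

The cost is one extra forward pass per segment during backward, which is where the slowdown in the table that follows comes from.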

10. Summary Table

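| Technique | Speed | Memory | Params trained |
|---|---|---|---|
| FlashAttention | 2–4x faster attention | Extra attention memory $ O(N) $ instead of $ O(N^2) $ | 100% |
| LoRA | ~1x per step | No gradients or optimizer state for frozen weights (up to ~2/3 less VRAM reported for GPT-3 175B) | ~0.1% (depends on $ r $ and target modules) |
| Tensor Parallel (8 GPUs) | Layer compute split 8 ways, minus communication overhead | Weights sharded ~8 ways per GPU | 100% |
| Gradient Checkpointing | ~0.7x (extra forward pass) | Activations stored only at checkpoints | 100% |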

11. Practice Exercises

  1. Train LoRA on TinyShakespeare
  2. Benchmark FlashAttention vs standard
  3. Plot Chinchilla curve
  4. Implement pipeline parallelism
  5. Combine LoRA + FlashAttention

12. Key Takeaways

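  • Chinchilla: a compute-optimal 70B model (~20 tokens per parameter) beats undertrained 175B+ models
  • Attention costs $ O(N^2) $ time and memory in sequence length
  • FlashAttention cuts attention's extra memory to $ O(N) $
  • LoRA trains ~0.1% of parameters
  • Scale efficiently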

Full Copy-Paste: LoRA + FlashAttention

```python
!pip install -U transformers peft
# Optional FlashAttention-2 kernels (CUDA GPU only; install after torch):
# !pip install flash-attn --no-build-isolation

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Load model; SDPA can use FlashAttention kernels on supported GPUs and dtypes
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    attn_implementation="sdpa",  # or "flash_attention_2" where the model and setup support it
)

# LoRA on the fused QKV projection (c_attn) and the c_proj output projections
config = LoraConfig(r=8, lora_alpha=32, target_modules=["c_attn", "c_proj"], task_type="CAUSAL_LM")
model = get_peft_model(model, config)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable:,} of {total:,} ({100 * trainable / total:.2f}%)")
```

Final Words

You now know the core techniques DeepMind, Meta, and Google use to train LLMs:

  • Chinchilla-optimal
  • FlashAttention-fast
  • LoRA-efficient
  • Scalable to 1T

End of Module
You scale like the pros: efficient, fast, optimal.
Next: Build a 7B model.