Transformer Architecture: A Detailed Note
The Transformer is the core architecture of modern NLP and large language models (LLMs). This note explains its structure, key equations, training objective, and common engineering optimizations.
1. Why Transformer
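RNN/LSTM models have inherent bottlenecks in modeling long-range dependencies and in parallel training:
- Long gradient propagation paths: a signal between two distant tokens must pass through every intermediate step, so it can vanish or explode over long distances
- Sequential computation: the hidden state at step t depends on step t-1, so the time steps of a sequence cannot be processed in parallel
The Transformer addresses both. Self-attention connects any two positions directly within a single layer, and all positions of a training sequence are processed in parallel, which maps well onto GPU/TPU matrix hardware.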
2. Macro Architecture
Common forms:
- Encoder-only (e.g., BERT) for understanding tasks
- Decoder-only (e.g., GPT) for autoregressive generation
- Encoder-Decoder (e.g., original Transformer, T5) for conditional generation
For modern autoregressive LLMs, Decoder-only is the dominant choice.
```mermaid
flowchart LR
    A[Token IDs] --> B[Embedding + Positional Info]
    B --> C[Transformer Block x N]
    C --> D[LayerNorm]
    D --> E[Linear Head]
    E --> F[Next-token logits]
```
A typical block contains:
- Multi-Head Self-Attention
- Feed-Forward Network (FFN)
- Residual connections
- Normalization (LayerNorm or RMSNorm)
3. Input Representation: Embedding + Position
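Token embeddings map token IDs to vectors through a learned lookup table $E \in \mathbb{R}^{|V| \times d_{model}}$, where $|V|$ is the vocabulary size. For the token ID $t_i$ at position $i$:

$$x_i = E[t_i]$$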
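Self-attention itself has no notion of order: without a mask it is permutation-equivariant, so shuffling the input tokens only shuffles the outputs. Without positional information, the model therefore cannot use sequence order. Common positional mechanisms: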
- Sinusoidal absolute position encoding
- Learnable position embedding
- RoPE (widely used in modern LLMs)
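Sinusoidal form (original Transformer), for position $pos$ and dimension-pair index $i$:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

The encoding is added to the token embedding before the first block.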
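As a minimal sketch of the input pipeline, the pure-Python code below performs an embedding lookup and then adds the sinusoidal table. The vocabulary size, `d_model`, the random untrained embedding table, and the helper names `sinusoidal_pe` and `embed` are illustrative assumptions. Real models use a trained table and a tensor library.

```python
import math
import random


def sinusoidal_pe(n_positions, d_model):
    """Return an n_positions x d_model table of sinusoidal position encodings."""
    pe = [[0.0] * d_model for _ in range(n_positions)]
    for pos in range(n_positions):
        # `two_i` plays the role of 2i in the formula: even dims get sin, odd dims cos.
        for two_i in range(0, d_model, 2):
            angle = pos / (10000 ** (two_i / d_model))
            pe[pos][two_i] = math.sin(angle)
            if two_i + 1 < d_model:
                pe[pos][two_i + 1] = math.cos(angle)
    return pe


def embed(token_ids, table, pe):
    """x_i = E[t_i], then add PE_i: one d_model-dimensional vector per position."""
    return [[e + p for e, p in zip(table[t], pe[i])] for i, t in enumerate(token_ids)]


# Hypothetical sizes and a random (untrained) embedding table E.
random.seed(0)
VOCAB, D_MODEL = 10, 8
E = [[random.gauss(0.0, 0.02) for _ in range(D_MODEL)] for _ in range(VOCAB)]
PE = sinusoidal_pe(16, D_MODEL)
h0 = embed([3, 1, 4], E, PE)
print(len(h0), len(h0[0]))  # 3 8
```

At position 0, every angle is 0, so the encoding is simply alternating 0s (sin) and 1s (cos). Lower dimensions oscillate fastest, and higher dimensions vary slowly with position.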
4. Self-Attention
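From the input $X \in \mathbb{R}^{n \times d_{model}}$ ($n$ tokens of width $d_{model}$), learned projections produce queries, keys, and values:

$$Q = XW_Q, \qquad K = XW_K, \qquad V = XW_V$$

Scaled dot-product attention:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V$$

Here $d_k$ is the key dimension, the softmax is taken row-wise, and $M$ is an additive mask with $0$ for allowed positions and $-\infty$ for blocked ones. Dividing by $\sqrt{d_k}$ keeps the variance of the dot products from growing with $d_k$. Without it, the softmax would be pushed into saturated regions with very small gradients. The mask is built from: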
- Padding mask ignores padded positions
- Causal mask blocks future tokens (position $i$ may attend only to positions $j \le i$), both in training and in autoregressive decoding
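The pure-Python sketch below implements the attention formula with an additive mask that combines padding and causal masking. The helper names are hypothetical, and using $Q = K = V = X$ (identity projections) is a simplifying assumption made to keep the example small.

```python
import math

NEG_INF = float("-inf")


def matmul(a, b):
    """Plain list-of-lists matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def softmax(row):
    m = max(row)
    if m == NEG_INF:  # every key masked: return zeros rather than NaN
        return [0.0] * len(row)
    exps = [math.exp(v - m) for v in row]  # exp(-inf) == 0.0 for masked entries
    total = sum(exps)
    return [e / total for e in exps]


def build_mask(n, valid_len, causal=True):
    """Additive mask M: 0.0 where query i may attend to key j, -inf otherwise.

    Keys at positions >= valid_len are padding; with causal=True, keys j > i
    (future tokens) are blocked as well.
    """
    return [[0.0 if j < valid_len and (not causal or j <= i) else NEG_INF
             for j in range(n)] for i in range(n)]


def attention(Q, K, V, M):
    """softmax(Q K^T / sqrt(d_k) + M) V, returning (output, weights)."""
    d_k = len(K[0])
    scores = matmul(Q, transpose(K))
    weights = [softmax([s / math.sqrt(d_k) + m for s, m in zip(s_row, m_row)])
               for s_row, m_row in zip(scores, M)]
    return matmul(weights, V), weights


# Toy input: 3 positions, the last one is padding; Q = K = V = X.
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out, w = attention(X, X, X, build_mask(3, valid_len=2))
for row in w:
    print([round(p, 3) for p in row])
```

Because masked scores become $-\infty$ before the softmax, their weights are exactly zero. The first query can therefore only copy its own value vector, and no query ever attends to the padded third key.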
5. Multi-Head Attention
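Multiple heads capture different relational subspaces. With $h$ heads and $d_k = d_{model}/h$, each head $\mathrm{head}_i$, for $i = 1, \dots, h$, has its own projections:

$$\mathrm{head}_i = \mathrm{Attention}\big(XW_Q^{(i)},\, XW_K^{(i)},\, XW_V^{(i)}\big), \qquad W_Q^{(i)}, W_K^{(i)}, W_V^{(i)} \in \mathbb{R}^{d_{model} \times d_k}$$

$$\mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W_O, \qquad W_O \in \mathbb{R}^{d_{model} \times d_{model}}$$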
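Analyses of trained models suggest that individual heads often specialize, for example in syntactic relations, entity references, local context, or long-range semantic links. Such roles are neither guaranteed nor always interpretable.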
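Here is a minimal multi-head sketch, assuming decoder-style causal attention and random untrained weights. It computes full-width projections and gives each head a contiguous block of $d_k$ columns. This is equivalent to giving each head its own $W_Q^{(i)}, W_K^{(i)}, W_V^{(i)}$. The function names and sizes are hypothetical.

```python
import math
import random


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention(Q, K, V):
    """Single-head causal attention: query i only sees keys 0..i."""
    d_k = len(Q[0])
    out = []
    for i, q in enumerate(Q):
        # Restricting to keys 0..i is equivalent to adding -inf for j > i.
        scores = [sum(a * b for a, b in zip(q, K[j])) / math.sqrt(d_k) for j in range(i + 1)]
        w = softmax(scores)
        out.append([sum(w[j] * V[j][c] for j in range(i + 1)) for c in range(len(V[0]))])
    return out


def multi_head_attention(X, W_Q, W_K, W_V, W_O, n_heads):
    """Concat(head_1..head_h) W_O, each head using a column block of width d_model / h."""
    d_model = len(X[0])
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    d_k = d_model // n_heads
    Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)  # each n x d_model
    heads = []
    for h in range(n_heads):
        cols = slice(h * d_k, (h + 1) * d_k)  # columns belonging to head h
        heads.append(causal_attention([r[cols] for r in Q],
                                      [r[cols] for r in K],
                                      [r[cols] for r in V]))
    # Concatenate head outputs position by position, then mix them with W_O.
    concat = [[v for head in heads for v in head[i]] for i in range(len(X))]
    return matmul(concat, W_O)


def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 0.3) for _ in range(cols)] for _ in range(rows)]


# Hypothetical sizes with random, untrained weights.
random.seed(0)
N, D_MODEL, N_HEADS = 4, 8, 2
X = rand_matrix(N, D_MODEL)
Y = multi_head_attention(X, rand_matrix(D_MODEL, D_MODEL), rand_matrix(D_MODEL, D_MODEL),
                         rand_matrix(D_MODEL, D_MODEL), rand_matrix(D_MODEL, D_MODEL), N_HEADS)
print(len(Y), len(Y[0]))
```

Each head attends within its own $d_k$-dimensional subspace, so the total cost stays close to that of one full-width head. Only $W_O$ mixes information across heads.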
6. Feed-Forward Network
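Attention mixes information across tokens. The FFN then applies the same non-linear transformation to each position independently:

$$\mathrm{FFN}(x) = \phi(xW_1 + b_1)\,W_2 + b_2$$

Here $\phi$ is ReLU in the original Transformer, and the hidden width is typically $4d_{model}$.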
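Modern variants replace ReLU with GELU or use a gated form such as SwiGLU. The gated form tends to give better quality at a comparable parameter and compute budget:

$$\mathrm{FFN}_{\mathrm{SwiGLU}}(x) = \big(\mathrm{SiLU}(xW_1) \odot xW_3\big)\,W_2$$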
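The sketch below shows a position-wise FFN with GELU next to a SwiGLU variant, which computes $(\mathrm{SiLU}(xW_1) \odot xW_3)W_2$ without biases. All names, sizes, and random weights are hypothetical, and reusing $W_1$ and $W_2$ across both variants is only for brevity.

```python
import math
import random


def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def silu(x):
    """SiLU (Swish): x * sigmoid(x), written to avoid overflow in exp."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def vec_mat(x, W):
    """Row vector x (length d_in) times matrix W (d_in x d_out)."""
    return [sum(xi * W[i][j] for i, xi in enumerate(x)) for j in range(len(W[0]))]


def ffn(x, W1, b1, W2, b2, act=gelu):
    """FFN(x) = act(x W1 + b1) W2 + b2 for a single position."""
    hidden = [act(v + b) for v, b in zip(vec_mat(x, W1), b1)]
    return [v + b for v, b in zip(vec_mat(hidden, W2), b2)]


def ffn_swiglu(x, W1, W3, W2):
    """(SiLU(x W1) * x W3) W2: a SiLU gate multiplied elementwise with a linear branch."""
    gate = [silu(v) for v in vec_mat(x, W1)]
    up = vec_mat(x, W3)
    return vec_mat([g * u for g, u in zip(gate, up)], W2)


def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


# Hypothetical sizes: d_model = 4, hidden width 4 * d_model = 16.
random.seed(0)
D_MODEL, D_FF = 4, 16
X = rand_matrix(3, D_MODEL)  # 3 positions
W1, W2, W3 = rand_matrix(D_MODEL, D_FF), rand_matrix(D_FF, D_MODEL), rand_matrix(D_MODEL, D_FF)
b1, b2 = [0.0] * D_FF, [0.0] * D_MODEL
# Position-wise: the same weights are applied to every row independently.
Y_gelu = [ffn(x, W1, b1, W2, b2) for x in X]
Y_swiglu = [ffn_swiglu(x, W1, W3, W2) for x in X]
```

SwiGLU uses three weight matrices instead of two. For that reason, implementations such as LLaMA shrink the hidden width (to about $\tfrac{2}{3} \cdot 4d_{model}$) to keep the parameter count comparable.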
7. Residual and Normalization
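Pre-Norm block form, where each $\mathrm{Norm}$ has its own parameters:

$$h = x + \mathrm{MHA}(\mathrm{Norm}(x)), \qquad y = h + \mathrm{FFN}(\mathrm{Norm}(h))$$

The original Transformer used Post-Norm, which normalizes after each residual addition: $y = \mathrm{Norm}(x + \mathrm{Sublayer}(x))$.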
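Pre-Norm is typically more stable than Post-Norm in very deep stacks. Its residual path is a pure identity from block input to block output, so gradients reach early layers without passing through any normalization, and training often needs less careful learning-rate warmup. Because the residual stream is never normalized inside the blocks, Pre-Norm models apply one final LayerNorm before the output head, as in the diagram in Section 2.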
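To make the residual-path argument concrete, the sketch below implements LayerNorm, RMSNorm, and both block orderings. The attention and FFN sublayers are hypothetical stand-ins that output zeros. With zero sublayers, a Pre-Norm block returns its input unchanged (the identity path), while a Post-Norm block still normalizes the stream. All function names are illustrative.

```python
import math
from functools import partial


def layer_norm(x, gamma, beta, eps=1e-5):
    """Subtract the mean, divide by the std, then apply learned scale and shift."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [g * (v - mu) / math.sqrt(var + eps) + b for v, g, b in zip(x, gamma, beta)]


def rms_norm(x, gain, eps=1e-6):
    """RMSNorm: no mean subtraction and no bias, only rescaling by the root mean square."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for v, g in zip(x, gain)]


def add_rows(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]


def pre_norm_block(X, attn, ffn, norm1, norm2):
    """h = x + MHA(Norm(x)); y = h + FFN(Norm(h)). `attn` mixes rows, `ffn` maps one row."""
    H = add_rows(X, attn([norm1(x) for x in X]))
    return add_rows(H, [ffn(norm2(h)) for h in H])


def post_norm_block(X, attn, ffn, norm1, norm2):
    """Original ordering: Norm(x + Sublayer(x)) after each sublayer."""
    H = [norm1(r) for r in add_rows(X, attn(X))]
    return [norm2(r) for r in add_rows(H, [ffn(h) for h in H])]


def zero_attn(rows):
    # Hypothetical stand-in for multi-head attention.
    return [[0.0] * len(r) for r in rows]


def zero_ffn(row):
    # Hypothetical stand-in for the position-wise FFN.
    return [0.0] * len(row)


D = 4
ln = partial(layer_norm, gamma=[1.0] * D, beta=[0.0] * D)
rms = partial(rms_norm, gain=[1.0] * D)

X = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 0.0]]
pre = pre_norm_block(X, zero_attn, zero_ffn, rms, rms)
post = post_norm_block(X, zero_attn, zero_ffn, ln, ln)
print(pre == X)   # True: the residual stream passes through untouched
print(post == X)  # False: Post-Norm rescales the stream itself
```

RMSNorm drops LayerNorm's mean subtraction and bias, which makes it slightly cheaper. This is one reason it appears in the upgrade list below.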
8. Training and Inference
8.1 Autoregressive Objective
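For decoder-only LMs, the model is trained to predict each token from its prefix, i.e. to minimize the average negative log-likelihood over a sequence of $n$ tokens:

$$\mathcal{L}(\theta) = -\frac{1}{n}\sum_{i=1}^{n} \log p_\theta(t_i \mid t_{<i})$$

Thanks to the causal mask, all predictions for a sequence come from a single parallel forward pass (teacher forcing).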
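A minimal sketch of the objective, assuming the model has already produced one logit vector per position. The logits at position $i$ are scored against the next token, so the first token serves only as context (in practice, often a BOS token). The vocabulary, sequence, and logit values are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax via the log-sum-exp trick."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


def causal_lm_loss(logits, token_ids):
    """Mean next-token negative log-likelihood for one sequence.

    logits[i] is the score vector produced at position i (after reading
    token_ids[0..i]); it is scored against the next token, token_ids[i + 1].
    The last position has no target, so n tokens give n - 1 predictions.
    """
    assert len(logits) == len(token_ids)
    nll = [-log_softmax(logits[i])[token_ids[i + 1]] for i in range(len(token_ids) - 1)]
    return sum(nll) / len(nll)


# Hypothetical 4-token vocabulary and a 4-token sequence.
tokens = [2, 0, 3, 1]
uniform = [[0.0] * 4 for _ in tokens]        # model with no preference
confident = [[0.0] * 4 for _ in tokens]
for i, nxt in enumerate(tokens[1:]):
    confident[i][nxt] = 10.0                 # large logit on the correct next token
print(causal_lm_loss(uniform, tokens), math.log(4))  # both about 1.386
print(causal_lm_loss(confident, tokens))             # close to 0
```

A uniform model pays $\log |V|$ per token, which is a useful sanity check for the loss at initialization.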
8.2 KV Cache in Decoding
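At generation time, each step appends one token. Because of the causal mask, the keys and values of earlier tokens never change, so they can be cached and reused. Each new step then computes Q, K, and V only for the new token and attends over the cached K/V:
- Avoids recomputing projections and attention for the whole prefix at every step
- Improves decoding throughput, especially for long contexts, at the cost of cache memory that grows linearly with sequence length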
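A simplified single-layer, single-head sketch of KV caching follows. In a real model, each new input row would be that layer's hidden state for the token just sampled. Here, random vectors stand in for those hidden states (an assumption). The `KVCache` class and function names are hypothetical. The check at the end confirms that incremental decoding reproduces full causal attention.

```python
import math
import random


def vec_mat(x, W):
    """Row vector x (length d_in) times matrix W (d_in x d_out)."""
    return [sum(xi * W[i][j] for i, xi in enumerate(x)) for j in range(len(W[0]))]


def attend(q, keys, values):
    """softmax(q K^T / sqrt(d_k)) V for a single query over the given keys/values."""
    d_k = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, values)) / z for c in range(len(values[0]))]


class KVCache:
    """Keys and values of every token processed so far (one layer, one head)."""

    def __init__(self):
        self.keys = []
        self.values = []


def decode_step(x_new, cache, W_Q, W_K, W_V):
    """Process one new token: project it once, append its K/V, attend over the cache."""
    cache.keys.append(vec_mat(x_new, W_K))
    cache.values.append(vec_mat(x_new, W_V))
    q = vec_mat(x_new, W_Q)
    # The cache holds only past tokens plus the current one, so this is causal by construction.
    return attend(q, cache.keys, cache.values)


def causal_attention_full(X, W_Q, W_K, W_V):
    """Reference without a cache: project every token and run causal attention for all positions."""
    K = [vec_mat(x, W_K) for x in X]
    V = [vec_mat(x, W_V) for x in X]
    return [attend(vec_mat(x, W_Q), K[: i + 1], V[: i + 1]) for i, x in enumerate(X)]


def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


# Hypothetical sizes; X stands in for per-step hidden states.
random.seed(0)
D = 4
W_Q, W_K, W_V = rand_matrix(D, D), rand_matrix(D, D), rand_matrix(D, D)
X = rand_matrix(5, D)

cache = KVCache()
incremental = [decode_step(x, cache, W_Q, W_K, W_V) for x in X]
reference = causal_attention_full(X, W_Q, W_K, W_V)
```

Each step adds one K row and one V row per layer and per K/V head, so cache memory grows linearly with the number of generated tokens. MQA/GQA shrink it by letting several query heads share the same K/V head.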
9. Common Engineering Upgrades
- Position handling: RoPE / ALiBi
- Normalization: LayerNorm -> RMSNorm
- FFN activation: ReLU/GELU -> SwiGLU
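- Attention kernel: FlashAttention (exact attention computed in tiles, without materializing the full attention matrix in GPU memory)
- KV-cache efficiency: MQA / GQA (several query heads share one K/V head, shrinking the cache)
- Training stability: learning-rate warmup, gradient clipping, weight decay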
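RoPE, listed above as a position-handling upgrade, rotates the query and key vectors inside each attention layer instead of adding a position vector to the embeddings. The sketch below assumes the adjacent-pair convention with base 10000. The function name `rope` and the example vectors are hypothetical. It checks the key property: the attention score depends only on the relative offset between the two positions.

```python
import math


def rope(x, pos, base=10000.0):
    """Rotate consecutive pairs (x[2i], x[2i+1]) by the angle pos * base**(-2i/d).

    Adjacent-pair convention; some implementations instead pair dimension i with
    i + d/2, which is the same rotation up to a fixed permutation of dimensions.
    """
    d = len(x)
    assert d % 2 == 0, "RoPE needs an even dimension"
    out = list(x)
    for two_i in range(0, d, 2):
        theta = pos * base ** (-two_i / d)
        c, s = math.cos(theta), math.sin(theta)
        out[two_i] = x[two_i] * c - x[two_i + 1] * s
        out[two_i + 1] = x[two_i] * s + x[two_i + 1] * c
    return out


def dot(a, b):
    return sum(u * v for u, v in zip(a, b))


# Hypothetical query/key vectors (after the W_Q / W_K projections).
q = [0.3, -1.2, 0.5, 0.8]
k = [1.0, 0.4, -0.7, 0.2]

# The score depends only on the offset between positions: both pairs below are 3 apart.
print(dot(rope(q, 5), rope(k, 2)))
print(dot(rope(q, 13), rope(k, 10)))
```

Rotating $q$ by angle $m\theta$ and $k$ by $n\theta$ leaves their dot product dependent only on $(n - m)\theta$. This lets relative position enter the attention score directly, while rotations preserve vector norms.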