Transformer Architecture: Attention Is All You Need
Definition
The Transformer architecture, introduced by Vaswani et al. in the 2017 paper 'Attention Is All You Need', replaced recurrence and convolution entirely with attention mechanisms. Transformers process all positions of a sequence in parallel rather than one step at a time, which enables efficient training on massive datasets and led to models such as GPT, BERT, and T5. The original architecture consists of an encoder (for understanding the input) and a decoder (for generating the output), each built from a stack of identical layers. The core components are multi-head self-attention (allowing each position to attend to all positions), position-wise feed-forward networks, residual connections, and layer normalization. Modern Large Language Models (LLMs) such as GPT-4, Claude, and Llama are based on decoder-only Transformer variants, while encoder-only models (BERT) excel at understanding tasks and encoder-decoder models (T5) handle sequence-to-sequence tasks.
Intuition
Imagine a group of experts in a meeting where everyone can simultaneously listen to everyone else, weighting each person's contribution by relevance. Traditional RNNs are like a single-file line where each person only talks to the person in front of them - slow and limited context. Transformers are like that meeting: every token (word/subword) can directly connect to every other token, with attention scores determining 'how much should I listen to you?' Multi-head attention is like having multiple meetings in parallel, each focusing on different aspects - one meeting tracks syntax, another tracks semantics, another tracks pronoun references. The feed-forward layers are like each expert privately processing what they learned from the meeting before the next round. Position encodings give each participant a name tag saying 'I am word #5', because unlike RNNs, Transformers process words simultaneously and need to know their order.
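To make the 'how much should I listen to you?' idea concrete, here is a minimal sketch of single-head self-attention on three toy token vectors. The vector values are made up for illustration, and using the inputs directly as queries, keys, and values is a simplifying assumption; a real layer would apply learned projections first.

```python
import math
import torch
import torch.nn.functional as F

# Three toy token vectors (assumed values, chosen only for illustration).
x = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                  [0.0, 1.0, 0.0, 1.0],
                  [1.0, 1.0, 0.0, 0.0]])

d_k = x.size(-1)

# Simplification: use x directly as queries, keys, and values.
scores = x @ x.T / math.sqrt(d_k)     # (3, 3) pairwise relevance scores
weights = F.softmax(scores, dim=-1)   # each row sums to 1
output = weights @ x                  # each token becomes a weighted mix of all tokens

print(weights)
print(output.shape)
```

Row i of `weights` is how strongly token i listens to every token in the sequence, including itself.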
Mathematical Formula
Step-by-Step Explanation:
- Scaled Dot-Product: Query-Key dot products are scaled by \(1/\sqrt{d_k}\) to keep the softmax stable; the resulting softmax weights are then used to take a weighted sum of the Values
- Multi-Head: Projects Q,K,V into h subspaces, applies attention in parallel, concatenates results
- Feed-Forward: Two-layer MLP with ReLU activation applied position-wise (same network for each position)
- LayerNorm: Normalizes across feature dimension, then learns scale \(\gamma\) and shift \(\beta\) parameters
- Positional Encoding: Adds position information through sinusoidal functions of varying frequencies
- Complexity: Self-attention costs \(O(n^2 \cdot d)\) per layer, quadratic in sequence length n and linear in model dimension d, which makes it the main computational bottleneck for long sequences
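For reference, the components listed above correspond to the following standard formulas, written in the notation of Vaswani et al. (2017). Here \(W_i^Q, W_i^K, W_i^V, W^O, W_1, W_2, b_1, b_2\) denote learned parameters, and \(\mu\) and \(\sigma^2\) are the mean and variance over the feature dimension at each position.

```latex
\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V
\]
\[
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O,
\qquad
\mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)
\]
\[
\mathrm{FFN}(x) = \max(0,\; x W_1 + b_1)\, W_2 + b_2
\]
\[
\mathrm{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
\]
\[
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right),
\qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
\]
\[
\text{Self-attention cost per layer: } O(n^2 \cdot d)
\]
```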
Real-World Use Cases
GPT-4, Claude, Llama using decoder-only Transformers for text generation
Google Translate using Transformer encoder-decoder for 100+ languages
GitHub Copilot generating code completions using GPT-style models
Vision Transformers (ViT) achieving SOTA on ImageNet without convolutions
AlphaFold2 using Transformers for protein structure prediction
CLIP and GPT-4V handling both text and images with Transformer-based models
Implementation
Manual Implementation (PyTorch Primitives Only)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        output = torch.matmul(attn_weights, V)
        return output, attn_weights
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # Linear projections and reshape for multi-head
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        # Apply attention
        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)
        # Concatenate heads and apply final linear
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(attn_output)
        return output, attn_weights
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        x = self.linear1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, mask=None):
        # Self-attention with residual connection (post-norm, as in the original paper)
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Shape (1, max_seq_length, d_model) so it broadcasts over batch-first inputs
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    def forward(self, x):
        # x: (batch, seq_len, d_model); slice along the sequence dimension, not the batch dimension
        return x + self.pe[:, :x.size(1)]
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length=5000, dropout=0.1):
        super(TransformerEncoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_length)
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(d_model)
    def forward(self, x, mask=None):
        x = self.embedding(x) * self.scale
        x = self.pos_encoding(x)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, mask)
        return x
# Test
vocab_size = 10000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
encoder = TransformerEncoder(vocab_size, d_model, num_heads, num_layers, d_ff)
src = torch.randint(0, vocab_size, (2, 20)) # batch_size=2, seq_len=20
output = encoder(src)
print(f'Output shape: {output.shape}') # [2, 20, 512]
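The `mask` argument in the implementation above blocks every position where `mask == 0`, so the mask has to broadcast against attention scores of shape `(batch, heads, seq, seq)`. The following is a minimal sketch of building such a mask from padded token IDs. The helper name `make_padding_mask` and the choice of `pad_id=0` are assumptions made for illustration; they are not part of the implementation above.

```python
import torch
import torch.nn.functional as F

def make_padding_mask(token_ids, pad_id=0):
    """Return a (batch, 1, 1, seq) mask: 1 for real tokens, 0 for padding."""
    return (token_ids != pad_id).unsqueeze(1).unsqueeze(2).int()

# Two sequences of length 5; the second one ends with two padding tokens.
token_ids = torch.tensor([[5, 8, 2, 9, 4],
                          [7, 3, 6, 0, 0]])
mask = make_padding_mask(token_ids)               # (2, 1, 1, 5)

# Random scores standing in for Q @ K^T / sqrt(d_k) with 2 heads.
torch.manual_seed(0)
scores = torch.randn(2, 2, 5, 5)                  # (batch, heads, seq, seq)
scores = scores.masked_fill(mask == 0, -1e9)      # same rule as the attention code above
weights = F.softmax(scores, dim=-1)

# Padded key positions receive numerically zero attention.
print(weights[1, 0, 0])
```

A mask built this way could be passed as the second argument of the encoder's `forward`, since the two singleton dimensions broadcast over heads and query positions.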
Using Libraries (torch, torch.nn, transformers, tensorflow, keras)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# Using PyTorch's built-in Transformer
class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, num_classes, dropout=0.1):
        super(TransformerClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = nn.TransformerEncoderLayer(d_model, num_heads, dim_feedforward=4*d_model,
                                                    dropout=dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc = nn.Linear(d_model, num_classes)
        self.d_model = d_model
    def forward(self, src, src_mask=None):
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        # Global average pooling
        output = output.mean(dim=1)
        output = self.fc(output)
        return output
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)
    def forward(self, x):
        # x: (batch, seq_len, d_model); the slice is transposed to (1, seq_len, d_model)
        x = x + self.pe[:x.size(1), :].transpose(0, 1)
        return self.dropout(x)
# GPT-style Decoder-only Transformer
class GPTDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, max_seq_len=512, dropout=0.1):
        super(GPTDecoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        # A decoder-only block is self-attention + feed-forward under a causal mask,
        # so encoder layers are used here; nn.TransformerDecoderLayer would add
        # cross-attention over an encoder memory that does not exist in this setup.
        block = nn.TransformerEncoderLayer(d_model, num_heads, dim_feedforward=4*d_model,
                                           dropout=dropout, batch_first=True)
        self.transformer_decoder = nn.TransformerEncoder(block, num_layers)
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.max_seq_len = max_seq_len
        self.d_model = d_model
    def forward(self, idx, targets=None):
        b, t = idx.size()
        assert t <= self.max_seq_len, "sequence longer than max_seq_len"
        # Token + positional embeddings
        tok_emb = self.embedding(idx)
        pos_emb = self.pos_embedding(torch.arange(t, device=idx.device))
        x = tok_emb + pos_emb
        # Causal mask: True marks future positions that must NOT be attended to
        # (PyTorch's convention for boolean masks), so it is passed without inversion
        causal_mask = torch.triu(torch.ones(t, t, device=idx.device), diagonal=1).bool()
        x = self.transformer_decoder(x, mask=causal_mask)
        x = self.ln_f(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
# Hugging Face Transformers
from transformers import GPT2LMHeadModel, GPT2Tokenizer, BertModel, BertTokenizer
# Load pre-trained GPT-2
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
# Generate text
text = 'The future of AI is'
inputs = tokenizer(text, return_tensors='pt')
outputs = model.generate(**inputs, max_length=50, num_return_sequences=1)
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
# BERT for encoding
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')
encoded = bert_tokenizer('Hello world', return_tensors='pt')
bert_output = bert_model(**encoded)
print(f'Last hidden state shape: {bert_output.last_hidden_state.shape}')
# TensorFlow/Keras
import tensorflow as tf
class PositionEmbedding(tf.keras.layers.Layer):
    # Learned position embedding, wrapped in a Layer so the model tracks and trains its weights
    def __init__(self, seq_len, d_model, **kwargs):
        super().__init__(**kwargs)
        self.pos_emb = tf.keras.layers.Embedding(seq_len, d_model)
    def call(self, x):
        positions = tf.range(tf.shape(x)[1])
        return x + self.pos_emb(positions)
def create_transformer_encoder(vocab_size, d_model, num_heads, num_layers, seq_len):
    inputs = tf.keras.Input(shape=(seq_len,), dtype='int32')
    # Embedding + positional encoding
    embedding = tf.keras.layers.Embedding(vocab_size, d_model)(inputs)
    x = PositionEmbedding(seq_len, d_model)(embedding)
    # Transformer encoder layers
    for _ in range(num_layers):
        # Multi-head attention
        attn_output = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model//num_heads)(x, x)
        x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x + attn_output)
        # Feed-forward
        ff_output = tf.keras.layers.Dense(d_model * 4, activation='relu')(x)
        ff_output = tf.keras.layers.Dense(d_model)(ff_output)
        x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x + ff_output)
    return tf.keras.Model(inputs, x)
model_tf = create_transformer_encoder(10000, 512, 8, 6, 128)
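For the PyTorch modules used above, a quick way to check that a causal mask is oriented correctly is to perturb a later token and confirm that the outputs at earlier positions do not change. The sketch below does this with `nn.MultiheadAttention`, relying on PyTorch's convention that `True` in a boolean `attn_mask` means 'do not attend'. The tensor sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, d_model, num_heads = 6, 16, 4

attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
attn.eval()  # no dropout, so repeated calls are deterministic

# Boolean mask: True above the diagonal = future positions are blocked.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

x = torch.randn(1, seq_len, d_model)
x_changed = x.clone()
x_changed[0, -1] += 1.0  # perturb only the last position

with torch.no_grad():
    out, _ = attn(x, x, x, attn_mask=causal_mask)
    out_changed, _ = attn(x_changed, x_changed, x_changed, attn_mask=causal_mask)

# Earlier positions must be unaffected; the perturbed position itself should change.
earlier_same = torch.allclose(out[0, :-1], out_changed[0, :-1], atol=1e-5)
last_differs = not torch.allclose(out[0, -1], out_changed[0, -1], atol=1e-5)
print(earlier_same, last_differs)
```

If the mask were inverted by mistake, the first check would fail, which makes this a cheap regression test for autoregressive models.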
When to Use
✅ Appropriate Use Cases:
- Sequential data where long-range dependencies matter
- Tasks requiring parallel processing of sequences (faster than RNNs)
- Large-scale language modeling and text generation
- When you have sufficient compute for quadratic attention complexity
- Multi-modal tasks combining text, images, audio
- Transfer learning with pre-trained models (BERT, GPT, T5)
❌ Avoid When:
- Very long sequences (>10k tokens) where O(n²) attention is prohibitive
- Resource-constrained environments (use RNNs or linear attention variants)
- When strict causality must be enforced but the masking cannot be verified (a wrong or missing causal mask silently leaks future information)
- Small datasets where RNNs with strong inductive bias perform better
- Real-time low-latency applications (attention computation overhead)
- When interpretability requires understanding of local feature hierarchies
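To put the quadratic cost mentioned above into numbers, the following back-of-the-envelope sketch estimates the memory needed to hold one layer's attention-score tensor. It assumes float32 values, 8 heads, batch size 1, and a naive implementation that materializes the full `(batch, heads, seq, seq)` tensor; memory-efficient attention kernels avoid storing it. The function name is hypothetical.

```python
def attention_matrix_bytes(seq_len, num_heads, batch_size=1, bytes_per_value=4):
    """Bytes needed for one layer's (batch, heads, seq, seq) score tensor."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_value

for n in (512, 2048, 10_000):
    gib = attention_matrix_bytes(n, num_heads=8) / 2**30
    print(f"n={n:>6}: {gib:.3f} GiB")
```

Under these assumptions the tensor grows from 8 MiB at 512 tokens to about 3 GiB at 10,000 tokens, per layer and per sequence in the batch.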
Common Pitfalls
- Forgetting causal masking in autoregressive decoders (leaks future information)
- Not scaling attention by \(1/\sqrt{d_k}\) causing gradient instability
- Using absolute position embeddings when relative positions matter more
- Insufficient gradient clipping causing training divergence in deep models
- Not handling variable-length sequences with proper padding and masking
- Incorrect key/query/value dimensions breaking multi-head attention
- Layer norm placement (pre-norm vs post-norm) affecting training stability
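- Attention dropout is applied after softmax, so the surviving weights no longer sum to 1; this is standard practice during training, but dropout must be disabled at inference (e.g. `model.eval()` in PyTorch)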
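The scaling pitfall above can be illustrated with a small experiment: for unit-variance queries and keys, the raw dot products have a standard deviation of roughly \(\sqrt{d_k}\), which pushes the softmax toward a near one-hot distribution with tiny gradients. The sketch below compares the two cases; the random vectors are an assumption standing in for real activations.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k, num_keys = 512, 16

# Unit-variance random query and keys (assumed, not taken from a trained model).
q = torch.randn(d_k)
K = torch.randn(num_keys, d_k)

raw_scores = K @ q                            # spread grows like sqrt(d_k)
scaled_scores = raw_scores / math.sqrt(d_k)   # spread stays close to 1

raw_weights = F.softmax(raw_scores, dim=-1)
scaled_weights = F.softmax(scaled_scores, dim=-1)

print(f"max weight without scaling: {raw_weights.max().item():.3f}")
print(f"max weight with scaling:    {scaled_weights.max().item():.3f}")
```

Without scaling, one key typically takes almost all of the attention mass, while the scaled version spreads it across several keys.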