SwiGLU: Swish-Gated Linear Unit
Definition
SwiGLU (Swish-Gated Linear Unit) is an activation function introduced by Noam Shazeer in the GLU Variants paper ("GLU Variants Improve Transformer", 2020) and popularized by large language models such as PaLM and Llama. It combines the Swish (SiLU) activation with the gating mechanism of Gated Linear Units (GLU), whose multiplicative gates resemble those of LSTMs. SwiGLU is defined as Swish(xW) ⊙ (xV): the input goes through two different linear projections, and the first, after a Swish activation, acts as an element-wise gate on the second. This gating lets the network learn which parts of the input to pass through and which to suppress. In Shazeer's experiments, SwiGLU and related GLU variants achieved better perplexity than ReLU and GELU feed-forward layers at matched parameter count and compute, and PaLM cites this quality gain as its reason for adopting it. SwiGLU is used mainly in the feed-forward layers of Transformers.
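As a quick numerical illustration of Swish(xW) ⊙ (xV), the sketch below evaluates the definition by hand for a 2-dimensional input and 2 hidden units. The weight values are arbitrary placeholders chosen for the example, biases are omitted, and the helper names (swish, matvec, swiglu) are introduced here for illustration only.

```python
import math

def swish(z: float, beta: float = 1.0) -> float:
    """Swish_beta(z) = z * sigmoid(beta * z); beta = 1 gives SiLU."""
    return z / (1.0 + math.exp(-beta * z))

def matvec(x, M):
    """Row vector x (length d) times a d x h matrix M given as nested lists."""
    return [sum(x[i] * M[i][j] for i in range(len(x))) for j in range(len(M[0]))]

def swiglu(x, W, V):
    """SwiGLU(x) = Swish(xW) * (xV) element-wise, biases omitted."""
    gate = [swish(z) for z in matvec(x, W)]  # gating path
    value = matvec(x, V)                      # linear (value) path
    return [g * v for g, v in zip(gate, value)]

# Arbitrary illustrative numbers: d = 2 inputs, h = 2 hidden units
x = [1.0, 2.0]
W = [[1.0, -1.0],
     [0.5, 0.0]]   # xW = [2.0, -1.0]
V = [[1.0, 0.5],
     [1.0, -0.5]]  # xV = [3.0, -0.5]
out = swiglu(x, W, V)
print([round(v, 4) for v in out])  # [5.2848, 0.1345]
```

The first unit's gate, Swish(2) ≈ 1.76, exceeds 1, and the second, Swish(-1) ≈ -0.27, is slightly negative, so the gate is better read as a smooth learned scaling than as a probability.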
Intuition
Imagine a security checkpoint with two guards. Guard A (the Swish gate) decides how far to open the gate based on what they see. Guard B (the linear projection) presents the actual items to pass through. The final output is what Guard B presents, scaled by how far Guard A opens the gate. If Guard A sees something important, they open the gate wide; if they see noise, they close it almost completely (a multiplier near 0). This is SwiGLU: the Swish activation acts as a learned gate that selectively passes information from the linear projection. Unlike ReLU, which is either on or off like a binary gate, Swish gates smoothly: it can be partially open, allowing more nuanced information flow. The analogy is loose in one respect: Swish is not capped at 1, so a strongly activated gate can amplify the value rather than merely pass it. The two projections (W and V) learn different transformations: one decides what to gate and the other what to pass. This dual-path design gives the network more expressive power than a simple element-wise activation.
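To make the "smooth vs. binary gate" contrast concrete, the sketch below evaluates a ReLU gate and a Swish gate on the same pre-activation values. It also shows the detail the analogy glosses over: the Swish gate dips slightly below zero for negative inputs and keeps growing for large positive ones, rather than staying inside [0, 1].

```python
import math

def relu(z: float) -> float:
    """Hard gate: passes z unchanged when positive, blocks it completely otherwise."""
    return max(0.0, z)

def swish(z: float) -> float:
    """Smooth gate: z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))

# Compare the two gates on the same pre-activation values
for z in [-4.0, -1.0, -0.5, 0.0, 0.5, 1.0, 4.0]:
    print(f"z={z:+.1f}  ReLU gate={relu(z):.3f}  Swish gate={swish(z):+.3f}")
```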
Mathematical Formula
Step-by-Step Explanation:
- Swish: Smooth activation that is x · sigmoid(x), providing non-linearity with unbounded positive range
- GLU: Gating mechanism using sigmoid gate to modulate linear projection
- SwiGLU: Combines Swish gating with a linear projection; in the Transformer FFN form (Shazeer 2020, Llama) biases are typically omitted from the projections
- FFN SwiGLU: Transformer feed-forward using SwiGLU as activation (replaces ReLU/GELU)
- Swish Derivative: Combines sigmoid contribution with input-scaled contribution for smooth gradients
- General Gating: Learned gate g multiplies filtered input x̃, allowing selective information flow
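Written out in the notation of Shazeer's GLU Variants paper, the formulas these bullets describe are as follows (σ is the logistic sigmoid, ⊙ the element-wise product, W, V, W_2 weight matrices, b, c biases, and g, x̃ the generic gate and filtered input):

```latex
\begin{aligned}
\mathrm{Swish}_\beta(x) &= x\,\sigma(\beta x) \\
\mathrm{GLU}(x) &= \sigma(xW + b) \odot (xV + c) \\
\mathrm{SwiGLU}(x) &= \mathrm{Swish}_\beta(xW + b) \odot (xV + c) \\
\mathrm{FFN}_{\mathrm{SwiGLU}}(x) &= \big(\mathrm{Swish}_1(xW) \odot xV\big)\,W_2 \\
\frac{d}{dx}\,\mathrm{Swish}_1(x) &= \sigma(x) + x\,\sigma(x)\big(1 - \sigma(x)\big) \\
y &= g \odot \tilde{x}
\end{aligned}
```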
Real-World Use Cases
PaLM (540B parameters) using SwiGLU in its feed-forward layers for improved model quality
Llama 2 and Llama 3 adopting SwiGLU as the default activation
Mixtral 8x7B using SwiGLU feed-forward blocks as its experts (the expert router is a separate gating mechanism)
EVA-02 and the EVA-02-based EVA-CLIP models using SwiGLU in their vision Transformer feed-forward layers
CodeLlama inheriting SwiGLU from its Llama 2 base model for code understanding and generation
Mistral 7B using the same Llama-style SwiGLU feed-forward block
Implementation
Manual Implementation (PyTorch Modules and Plain NumPy)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

class Swish(nn.Module):
    def __init__(self, beta=1.0):
        super().__init__()
        self.beta = beta

    def forward(self, x):
        return x * torch.sigmoid(self.beta * x)

class SwiGLU(nn.Module):
    def __init__(self, dim_input, dim_hidden, bias=True):
        super().__init__()
        # Two linear projections: one for gating, one for value
        self.W_gate = nn.Linear(dim_input, dim_hidden, bias=bias)
        self.W_value = nn.Linear(dim_input, dim_hidden, bias=False)  # No bias on value in modern implementations

    def forward(self, x):
        # Gate pathway with Swish activation
        gate = self.W_gate(x)
        gate = F.silu(gate)  # SiLU is equivalent to Swish with beta=1
        # Value pathway (linear)
        value = self.W_value(x)
        # Gating: element-wise multiplication
        output = gate * value
        return output
class FeedForwardSwiGLU(nn.Module):
    def __init__(self, dim, hidden_dim=None, multiple_of=256, dropout=0.0):
        super().__init__()
        # A GELU FFN typically uses hidden_dim = 4 * dim. SwiGLU has three weight
        # matrices instead of two, so it scales that by 2/3 (~8/3 * dim, about 2.67x)
        # to keep the parameter count comparable.
        # hidden_dim is interpreted as the GELU-equivalent width (4 * dim); the 2/3
        # scaling is always applied, as in the Llama reference code.
        if hidden_dim is None:
            hidden_dim = 4 * dim
        hidden_dim = int(2 * hidden_dim / 3)
        # Round up to a multiple of multiple_of for hardware efficiency
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # Gate projection
        self.w2 = nn.Linear(dim, hidden_dim, bias=False)  # Value projection
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)  # Output projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # SwiGLU: Swish(w1(x)) * w2(x)
        # (Llama's own naming differs: w1 = gate, w3 = value, w2 = down projection)
        gate = F.silu(self.w1(x))  # SiLU/Swish activation
        hidden = gate * self.w2(x)  # Gating
        output = self.w3(hidden)
        output = self.dropout(output)
        return output
# NumPy implementation for understanding gradients
class SwiGLUNumPy:
    def __init__(self, input_dim, hidden_dim):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # He (Kaiming) normal initialization: std = sqrt(2 / fan_in)
        self.W_gate = np.random.randn(input_dim, hidden_dim) * np.sqrt(2.0 / input_dim)
        self.W_value = np.random.randn(input_dim, hidden_dim) * np.sqrt(2.0 / input_dim)
        self.b_gate = np.zeros(hidden_dim)

    def swish(self, x):
        # Clip to avoid overflow in exp for very large |x|
        return x * (1 / (1 + np.exp(-np.clip(x, -500, 500))))

    def swish_derivative(self, x):
        sigmoid = 1 / (1 + np.exp(-np.clip(x, -500, 500)))
        return sigmoid + x * sigmoid * (1 - sigmoid)

    def forward(self, x):
        # Linear projections
        z_gate = np.dot(x, self.W_gate) + self.b_gate
        z_value = np.dot(x, self.W_value)
        # Swish activation on gate
        gate = self.swish(z_gate)
        # Gating
        output = gate * z_value
        cache = (x, z_gate, z_value, gate)
        return output, cache

    def backward(self, grad_output, cache, learning_rate=0.01):
        x, z_gate, z_value, gate = cache
        # Gradient through gating (element-wise multiplication)
        grad_gate = grad_output * z_value
        grad_z_value = grad_output * gate
        # Gradient through Swish
        grad_z_gate = grad_gate * self.swish_derivative(z_gate)
        # Gradients w.r.t. weights
        grad_W_gate = np.dot(x.T, grad_z_gate)
        grad_b_gate = np.sum(grad_z_gate, axis=0)
        grad_W_value = np.dot(x.T, grad_z_value)
        # Gradient w.r.t. input for backprop, computed BEFORE the update so it
        # uses the same weights as the forward pass
        grad_x = np.dot(grad_z_gate, self.W_gate.T) + np.dot(grad_z_value, self.W_value.T)
        # Update weights (plain SGD step)
        self.W_gate -= learning_rate * grad_W_gate
        self.b_gate -= learning_rate * grad_b_gate
        self.W_value -= learning_rate * grad_W_value
        return grad_x
# Visualization and comparison
def visualize_activations():
    x = torch.linspace(-5, 5, 1000)
    # Different activations
    relu = F.relu(x)
    gelu = F.gelu(x)
    swish = x * torch.sigmoid(x)
    # Plotting
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    # Activations
    ax = axes[0, 0]
    ax.plot(x.numpy(), relu.numpy(), label='ReLU', linewidth=2)
    ax.plot(x.numpy(), gelu.numpy(), label='GELU', linewidth=2)
    ax.plot(x.numpy(), swish.numpy(), label='Swish', linewidth=2)
    ax.set_title('Activation Functions')
    ax.legend()
    ax.grid(True)
    ax.axhline(y=0, color='k', linestyle='--', alpha=0.3)
    ax.axvline(x=0, color='k', linestyle='--', alpha=0.3)
    # Derivatives via finite differences, plotted at the midpoint of each interval
    ax = axes[0, 1]
    x_grad = x[:-1] + 0.5 * (x[1:] - x[:-1])
    relu_grad = torch.diff(relu) / torch.diff(x)
    gelu_grad = torch.diff(gelu) / torch.diff(x)
    swish_grad = torch.diff(swish) / torch.diff(x)
    ax.plot(x_grad.numpy(), relu_grad.numpy(), label="ReLU'", linewidth=2)
    ax.plot(x_grad.numpy(), gelu_grad.numpy(), label="GELU'", linewidth=2)
    ax.plot(x_grad.numpy(), swish_grad.numpy(), label="Swish'", linewidth=2)
    ax.set_title('Derivatives')
    ax.legend()
    ax.grid(True)
    ax.axhline(y=0, color='k', linestyle='--', alpha=0.3)
    ax.axvline(x=0, color='k', linestyle='--', alpha=0.3)
    # SwiGLU behavior for a scalar input with W = V = 1: output = Swish(x) * x
    ax = axes[1, 0]
    x_2d = torch.linspace(-3, 3, 100).unsqueeze(1)
    with torch.no_grad():
        gate = torch.sigmoid(x_2d) * x_2d  # Swish
        value = x_2d  # Identity for visualization
        output = gate * value
    ax.plot(x_2d.numpy(), x_2d.numpy(), label='Input', alpha=0.5)
    ax.plot(x_2d.numpy(), gate.numpy(), label='Gate (Swish)', alpha=0.7)
    ax.plot(x_2d.numpy(), output.numpy(), label='SwiGLU Output', linewidth=2)
    ax.set_title('SwiGLU Gating Behavior')
    ax.legend()
    ax.grid(True)
    # Smoothness comparison: magnitude of change between neighboring samples
    ax = axes[1, 1]
    ax.plot(x.numpy(), np.abs(np.diff(relu.numpy(), prepend=relu[0].numpy())), label='ReLU |Δ|', alpha=0.7)
    ax.plot(x.numpy(), np.abs(np.diff(gelu.numpy(), prepend=gelu[0].numpy())), label='GELU |Δ|', alpha=0.7)
    ax.plot(x.numpy(), np.abs(np.diff(swish.numpy(), prepend=swish[0].numpy())), label='Swish |Δ|', alpha=0.7)
    ax.set_title('Smoothness (Change Magnitude)')
    ax.legend()
    ax.grid(True)
    plt.tight_layout()
    plt.savefig('swiglu_visualization.png', dpi=150)
    print('Visualization saved to swiglu_visualization.png')
# Test implementations
print('Testing SwiGLU implementations...')
batch_size = 4
seq_len = 128
dim = 512
x = torch.randn(batch_size, seq_len, dim)
# Test SwiGLU
swiglu = SwiGLU(dim, dim * 2)
output = swiglu(x)
print(f'SwiGLU input shape: {x.shape}')
print(f'SwiGLU output shape: {output.shape}')
# Test FeedForward SwiGLU
ffn = FeedForwardSwiGLU(dim)
output_ffn = ffn(x)
print(f'\nFeedForward SwiGLU input shape: {x.shape}')
print(f'FeedForward SwiGLU output shape: {output_ffn.shape}')
# NumPy version
swiglu_np = SwiGLUNumPy(dim, dim * 2)
x_np = x[0, 0].numpy().reshape(1, -1)
output_np, cache = swiglu_np.forward(x_np)
print(f'\nNumPy SwiGLU output shape: {output_np.shape}')
# Gradient test
grad_output = np.ones_like(output_np)
grad_x = swiglu_np.backward(grad_output, cache)
print(f'Gradient shape: {grad_x.shape}')
# Parameter count comparison
ffn_relu = nn.Sequential(
    nn.Linear(dim, 4 * dim),
    nn.ReLU(),
    nn.Linear(4 * dim, dim)
)
ffn_gelu = nn.Sequential(
    nn.Linear(dim, 4 * dim),
    nn.GELU(),
    nn.Linear(4 * dim, dim)
)
# FeedForwardSwiGLU takes the GELU-equivalent width (4 * dim) and applies the 2/3 scaling itself
ffn_swiglu = FeedForwardSwiGLU(dim, hidden_dim=4 * dim)
print('\nParameter comparison:')
print(f'FFN ReLU: {sum(p.numel() for p in ffn_relu.parameters())/1e6:.2f}M')
print(f'FFN GELU: {sum(p.numel() for p in ffn_gelu.parameters())/1e6:.2f}M')
print(f'FFN SwiGLU: {sum(p.numel() for p in ffn_swiglu.parameters())/1e6:.2f}M')
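A quick way to validate a hand-written backward pass like the one in SwiGLUNumPy is a finite-difference gradient check. The sketch below re-derives the same chain rule as standalone functions (swiglu_forward and swiglu_grad_x are names introduced here for illustration, not part of the code above) and compares the analytic input gradient against central differences of the scalar loss L = sum(grad_out * output). It only checks dL/dx; the weight gradients could be checked the same way. Note that the input gradient must use the same weights as the forward pass, i.e. it has to be taken before any SGD update.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def swiglu_forward(x, W_gate, b_gate, W_value):
    z_gate = x @ W_gate + b_gate
    z_value = x @ W_value
    return (z_gate * sigmoid(z_gate)) * z_value  # Swish(z_gate) * z_value

def swiglu_grad_x(x, W_gate, b_gate, W_value, grad_out):
    """Analytic dL/dx for L = sum(grad_out * output)."""
    z_gate = x @ W_gate + b_gate
    z_value = x @ W_value
    s = sigmoid(z_gate)
    gate = z_gate * s
    d_swish = s + z_gate * s * (1.0 - s)      # Swish'(z)
    grad_z_gate = grad_out * z_value * d_swish
    grad_z_value = grad_out * gate
    return grad_z_gate @ W_gate.T + grad_z_value @ W_value.T

rng = np.random.default_rng(0)
d, h = 5, 7
x = rng.standard_normal((3, d))
W_gate = rng.standard_normal((d, h))
b_gate = rng.standard_normal(h)
W_value = rng.standard_normal((d, h))
grad_out = rng.standard_normal((3, h))

analytic = swiglu_grad_x(x, W_gate, b_gate, W_value, grad_out)

# Central finite differences, one input element at a time
eps = 1e-6
numeric = np.zeros_like(x)
for idx in np.ndindex(x.shape):
    x_plus, x_minus = x.copy(), x.copy()
    x_plus[idx] += eps
    x_minus[idx] -= eps
    L_plus = np.sum(grad_out * swiglu_forward(x_plus, W_gate, b_gate, W_value))
    L_minus = np.sum(grad_out * swiglu_forward(x_minus, W_gate, b_gate, W_value))
    numeric[idx] = (L_plus - L_minus) / (2 * eps)

print("max abs error:", np.max(np.abs(analytic - numeric)))
```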
Using Libraries (torch, torch.nn, transformers, tensorflow, jax, flax)
import torch
import torch.nn as nn
import torch.nn.functional as F

# Standard PyTorch SwiGLU (used in Llama, PaLM)
class LlamaSwiGLU(nn.Module):
    def __init__(self, dim, hidden_dim=None, multiple_of=256, dropout=0.0):
        super().__init__()
        # Llama uses 2/3 of 4*dim for SwiGLU to match parameter count with standard FFN
        if hidden_dim is None:
            hidden_dim = 4 * dim
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # Gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # Down projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # Value projection

    def forward(self, x):
        # Llama FFN: w2(silu(w1(x)) * w3(x))
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
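# Hugging Face Transformers integration
# Assumes a recent transformers release in which LlamaMLP takes a LlamaConfig;
# older releases used a different constructor signature.
try:
    from transformers import LlamaConfig
    from transformers.models.llama.modeling_llama import LlamaMLP
    # Llama MLP with SwiGLU: down_proj(act_fn(gate_proj(x)) * up_proj(x)), act_fn = SiLU
    config = LlamaConfig(hidden_size=512, intermediate_size=1376, hidden_act='silu')
    llama_mlp = LlamaMLP(config)
    print('HuggingFace Llama MLP loaded')
except (ImportError, TypeError) as e:
    print(f'HuggingFace Llama MLP not available: {e}')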
# Custom SwiGLU variants
class SwiGLUVariant(nn.Module):
    def __init__(self, variant='swiglu', dim=512, hidden_dim=1376):
        super().__init__()
        self.variant = variant
        self.w_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.w_value = nn.Linear(dim, hidden_dim, bias=False)
        self.w_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        if self.variant == 'swiglu':
            # Swish(xW) * xV
            gate = F.silu(self.w_gate(x))
        elif self.variant == 'geglu':
            # GELU(xW) * xV
            gate = F.gelu(self.w_gate(x))
        elif self.variant == 'reglu':
            # ReLU(xW) * xV
            gate = F.relu(self.w_gate(x))
        elif self.variant == 'glu':
            # sigmoid(xW) * xV
            gate = torch.sigmoid(self.w_gate(x))
        else:
            raise ValueError(f'Unknown variant: {self.variant}')
        value = self.w_value(x)
        return self.w_out(gate * value)
# Memory-efficient SwiGLU: one fused input projection (a single matmul instead of two)
# plus optional activation checkpointing, which is where the memory saving comes from
from torch.utils.checkpoint import checkpoint

class MemoryEfficientSwiGLU(nn.Module):
    def __init__(self, dim, hidden_dim, use_checkpoint=False):
        super().__init__()
        # Combine gate and value into single projection then split
        self.w_in = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.w_out = nn.Linear(hidden_dim, dim, bias=False)
        self.hidden_dim = hidden_dim
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        if self.use_checkpoint and self.training:
            # Recompute the hidden activations during backward instead of storing them
            return checkpoint(self._forward, x, use_reentrant=False)
        return self._forward(x)

    def _forward(self, x):
        # Single projection for both gate and value
        x_proj = self.w_in(x)
        gate, value = x_proj.split(self.hidden_dim, dim=-1)
        # Apply Swish to gate and multiply
        return self.w_out(F.silu(gate) * value)
# Integration with Transformer blocks
class TransformerBlockWithSwiGLU(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        # nn.RMSNorm exists in recent PyTorch releases; fall back to LayerNorm otherwise
        self.norm1 = nn.RMSNorm(dim) if hasattr(nn, 'RMSNorm') else nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.RMSNorm(dim) if hasattr(nn, 'RMSNorm') else nn.LayerNorm(dim)
        # SwiGLU FFN: LlamaSwiGLU applies the 2/3 scaling internally, so pass the
        # GELU-equivalent width (4 * dim) rather than ~2.67 * dim
        hidden_dim = int(mlp_ratio * dim)
        self.ffn = LlamaSwiGLU(dim, hidden_dim)

    def forward(self, x, mask=None):
        # Attention with pre-normalization (normalize once, reuse for q, k, v)
        h = self.norm1(x)
        x = x + self.attn(h, h, h, attn_mask=mask, need_weights=False)[0]
        # FFN with SwiGLU
        x = x + self.ffn(self.norm2(x))
        return x
# TensorFlow/Keras implementation (skipped if TensorFlow is not installed)
try:
    import tensorflow as tf

    class SwiGLUTF(tf.keras.layers.Layer):
        def __init__(self, hidden_dim, **kwargs):
            super().__init__(**kwargs)
            self.hidden_dim = hidden_dim

        def build(self, input_shape):
            dim = input_shape[-1]
            self.w_gate = self.add_weight(
                shape=(dim, self.hidden_dim),
                initializer='glorot_uniform',
                trainable=True,
                name='w_gate'
            )
            self.w_value = self.add_weight(
                shape=(dim, self.hidden_dim),
                initializer='glorot_uniform',
                trainable=True,
                name='w_value'
            )
            self.w_out = self.add_weight(
                shape=(self.hidden_dim, dim),
                initializer='glorot_uniform',
                trainable=True,
                name='w_out'
            )
            super().build(input_shape)

        def call(self, inputs):
            gate = tf.nn.silu(tf.matmul(inputs, self.w_gate))
            value = tf.matmul(inputs, self.w_value)
            hidden = gate * value
            return tf.matmul(hidden, self.w_out)

    print('TensorFlow SwiGLU available')
except ImportError:
    print('TensorFlow not available')
# JAX/Flax implementation
try:
    import jax
    import jax.numpy as jnp
    # Import linen under a different name so it does not shadow torch.nn (imported as nn above)
    from flax import linen as flax_nn

    class SwiGLUFlax(flax_nn.Module):
        hidden_dim: int

        @flax_nn.compact
        def __call__(self, x):
            dim = x.shape[-1]
            w_gate = self.param('w_gate', flax_nn.initializers.glorot_uniform(), (dim, self.hidden_dim))
            w_value = self.param('w_value', flax_nn.initializers.glorot_uniform(), (dim, self.hidden_dim))
            w_out = self.param('w_out', flax_nn.initializers.glorot_uniform(), (self.hidden_dim, dim))
            gate = jax.nn.silu(jnp.dot(x, w_gate))
            value = jnp.dot(x, w_value)
            hidden = gate * value
            return jnp.dot(hidden, w_out)

    print('JAX/Flax SwiGLU available')
except ImportError:
    print('JAX/Flax not available')
# Test all implementations
print('\nTesting SwiGLU variants...')
dim = 512
batch_size = 2
seq_len = 64
x = torch.randn(batch_size, seq_len, dim)
# Test different GLU variants
variants = ['swiglu', 'geglu', 'reglu', 'glu']
for variant in variants:
    glu = SwiGLUVariant(variant, dim, int(dim * 2.67))
    output = glu(x)
    print(f'{variant.upper()}: output shape {output.shape}, mean {output.mean().item():.4f}')
# Memory efficient version
mem_eff = MemoryEfficientSwiGLU(dim, int(dim * 2.67))
output = mem_eff(x)
print(f'\nMemory Efficient SwiGLU: output shape {output.shape}')
# Compare output statistics at initialization (untrained weights, so this is not a quality comparison)
swiglu_mod = LlamaSwiGLU(dim)
gelu_ffn = nn.Sequential(
    nn.Linear(dim, int(dim * 4)),
    nn.GELU(),
    nn.Linear(int(dim * 4), dim)
)
with torch.no_grad():
    out_swiglu = swiglu_mod(x)
    out_gelu = gelu_ffn(x)
print('\nActivation statistics:')
print(f'SwiGLU: mean={out_swiglu.mean():.4f}, std={out_swiglu.std():.4f}')
print(f'GELU: mean={out_gelu.mean():.4f}, std={out_gelu.std():.4f}')
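MemoryEfficientSwiGLU relies on the fact that one projection to 2 × hidden_dim followed by a split is mathematically the same as two separate gate and value projections. The small NumPy sketch below, with random and purely illustrative weights, makes that explicit. It uses the x @ W convention, so the weights are concatenated along the output (column) axis; in PyTorch, where nn.Linear stores its weight as (out_features, in_features), the equivalent is concatenating the two weight tensors along dim 0.

```python
import numpy as np

rng = np.random.default_rng(1)
dim, hidden = 8, 12
x = rng.standard_normal((4, dim))

# Separate gate/value weights, as in SwiGLUVariant or LlamaSwiGLU
W_gate = rng.standard_normal((dim, hidden))
W_value = rng.standard_normal((dim, hidden))

def silu(z):
    return z / (1.0 + np.exp(-z))

separate = silu(x @ W_gate) * (x @ W_value)

# Fused: concatenate along the output axis, do one matmul, then split
W_in = np.concatenate([W_gate, W_value], axis=1)  # shape (dim, 2 * hidden)
proj = x @ W_in
gate, value = np.split(proj, 2, axis=-1)          # first half = gate, second half = value
fused = silu(gate) * value

print("max abs difference:", np.max(np.abs(separate - fused)))
```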
When to Use
✅ Appropriate Use Cases:
- Large language models, where GLU-style FFNs have outperformed GELU or ReLU FFNs at matched compute
- Transformer feed-forward layers seeking better expressiveness
- When training stability with deep networks is important
- Reproducing Llama, PaLM, or Mistral architectures
- Tasks where smooth gating behavior (vs hard ReLU) is beneficial
- When you want smooth gating with non-zero gradients for negative gate inputs (like GELU, unlike ReLU)
❌ Avoid When:
- Very small models where activation choice matters less
- When parameter budget is tight (SwiGLU needs careful dimension tuning)
- Inference-only deployment where GELU is already optimized on hardware
- If you're fine-tuning a pretrained model that used a different activation
- When you need activations bounded in a specific range (SwiGLU is unbounded)
- Resource-constrained edge devices without SwiGLU kernel optimizations
Common Pitfalls
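- Incorrect hidden dimension (should be ~2.67x input for SwiGLU vs 4x for GELU to keep parameters matched), or applying the 2/3 scaling twice
- Adding biases to the projections (Llama-style SwiGLU uses bias=False on the gate, value, and output projections)
- Forgetting that at the same hidden size as a standard FFN (4x), SwiGLU has 1.5x the parameters and FLOPs because it uses three weight matrices instead of two; shrinking the hidden size to ~2.67x restores parity
- Pairing SwiGLU with post-LN rather than the pre-normalization used by the models that popularized it (PaLM, Llama)
- Not accounting for increased activation memory: both the gate and value tensors must be stored for the backward pass
- Mismatched activation (using ReLU or GELU instead of Swish/SiLU for the gate)
- Initialization issues: the output is the product of two projections, so its scale depends on both, and the gate and value projections need careful init
- Training instability if the learning rate is too high (use warmup)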
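The hidden-dimension and FLOP pitfalls come down to simple arithmetic. The sketch below applies the same 2/3-then-round-up rule that FeedForwardSwiGLU and LlamaSwiGLU use by default; the function names are introduced here for illustration. For dim = 4096 it gives 11008, the intermediate size used by the 7B Llama models, and it shows that leaving the hidden size at 4 × dim gives SwiGLU exactly 1.5 times the weights, and therefore about 1.5 times the matmul FLOPs, of a GELU FFN. Biases and element-wise operations are ignored.

```python
def swiglu_hidden_dim(dim: int, multiple_of: int = 256) -> int:
    """2/3 of the GELU width 4*dim, rounded up to a multiple of multiple_of."""
    hidden = int(2 * (4 * dim) / 3)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

def ffn_weight_count(dim: int, hidden: int, n_matrices: int) -> int:
    """Weights in an FFN made of n_matrices bias-free dim x hidden matrices.
    Matmul FLOPs per token are 2 * this count, so parameters and FLOPs scale together."""
    return n_matrices * dim * hidden

dim = 4096
h = swiglu_hidden_dim(dim)
gelu_ffn = ffn_weight_count(dim, 4 * dim, 2)         # up + down projection
swiglu_ffn = ffn_weight_count(dim, h, 3)             # gate + value + down projection
swiglu_unscaled = ffn_weight_count(dim, 4 * dim, 3)  # 2/3 scaling forgotten

print("SwiGLU hidden dim:", h)  # 11008
print("GELU FFN weights:", gelu_ffn)
print("SwiGLU FFN weights:", swiglu_ffn)
print("Unscaled SwiGLU / GELU:", swiglu_unscaled / gelu_ffn)  # 1.5
```

The rounding to a multiple of 256 is why the matched SwiGLU FFN ends up slightly larger (under 1%) than the GELU one rather than exactly equal.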