SwiGLU: Swish-Gated Linear Unit

Definition

SwiGLU (Swish-Gated Linear Unit) is a gated activation introduced by Shazeer in the paper "GLU Variants Improve Transformer" (2020) and popularized by the PaLM and Llama models. It combines the Swish (SiLU) activation with the multiplicative gating of the Gated Linear Unit, a mechanism related to the gates in LSTMs. SwiGLU is defined as Swish(xW) ⊙ (xV): two different linear projections are applied to the input, and the one that passes through Swish acts as a gate for the other. This gating lets the network learn which parts of the input to pass through and which to suppress. In Shazeer's experiments, SwiGLU and the other GLU variants outperformed ReLU and GELU feed-forward layers at matched parameter count, and PaLM adopted SwiGLU citing those quality gains. In practice the activation is used in the feed-forward layers of Transformers.
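To make the definition concrete, here is a minimal sketch that evaluates Swish(xW) ⊙ (xV) by hand on a two-dimensional input. It uses only the standard library; the helper names (`matvec`, `swiglu`) and the toy weights are illustrative assumptions, not taken from any model.

```python
import math

def swish(z):
    """Swish / SiLU: z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))

def matvec(x, M):
    """Row vector x (length n) times matrix M (n x m), returned as a list of length m."""
    return [sum(x[i] * M[i][j] for i in range(len(x))) for j in range(len(M[0]))]

def swiglu(x, W, V):
    """SwiGLU(x) = Swish(xW) * (xV), element-wise."""
    gate = [swish(z) for z in matvec(x, W)]   # gate path: projection followed by Swish
    value = matvec(x, V)                      # value path: projection only
    return [g * v for g, v in zip(gate, value)]

x = [1.0, -2.0]
W = [[0.5, -1.0, 2.0], [0.25, 0.5, -0.5]]   # toy gate projection (2 x 3)
V = [[1.0, 0.0, -1.0], [0.5, 1.0, 0.0]]     # toy value projection (2 x 3)

y = swiglu(x, W, V)
print(y)
```

With these weights the gate pre-activations are 0, -2 and 3, so the first output is suppressed entirely, the second is scaled by a small negative gate, and the third is amplified.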

Intuition

Imagine a security checkpoint with two guards. Guard A (the Swish gate) decides how much to open the gate based on what they see. Guard B (the linear projection) presents the actual items to pass through. The final output is what Guard B presents, scaled by how much Guard A decides to open the gate. If Guard A sees something important, they open the gate wide (a large positive multiplier); if they see noise, they nearly close it (a multiplier near 0). This is SwiGLU: the Swish activation acts as a learned gate that selectively passes information from the linear projection. Unlike the sigmoid gate of the original GLU, the Swish gate is not confined to the range 0 to 1: it grows roughly linearly for large positive inputs and dips slightly below zero (to about -0.28) for moderately negative ones, so it can amplify a value as well as suppress it. Unlike ReLU, which cuts off sharply at zero, Swish is smooth, so the gate can be partially open and allows more nuanced information flow. The two projections (W and V) learn different transformations: one focuses on 'what to gate' and the other on 'what to pass.' This dual-path architecture gives the network more expressive power than simple element-wise activations.
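The behaviour of the gate can be checked numerically. The following sketch, using only the standard library, prints the Swish multiplier for a few pre-activations and searches a grid for its minimum; the grid range and step are arbitrary choices made for illustration.

```python
import math

def swish(z):
    """Swish / SiLU: z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))

# Gate multiplier for a range of pre-activations.
for z in (-6.0, -1.0, 0.0, 1.0, 6.0):
    print(f"z = {z:+.1f} -> gate = {swish(z):+.4f}")

# Search a fine grid on [-5, 5] for the most negative gate value.
grid = [i / 1000.0 for i in range(-5000, 5001)]
z_min = min(grid, key=swish)
print(f"minimum near z = {z_min:.3f}, gate = {swish(z_min):.4f}")
```

The gate is exactly zero at z = 0, exceeds 1 for large positive z, and has a shallow negative minimum, which is why it behaves differently from a sigmoid gate.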

Mathematical Formula

Swish Activation:
\[ \text{Swish}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}} \]
GLU (Gated Linear Unit):
\[ \text{GLU}(x) = \sigma(xW + b) \odot (xV + c) \]
SwiGLU:
\[ \text{SwiGLU}(x) = \text{Swish}(xW) \odot (xV) \]
In Feed-Forward Network:
\[ \text{FFN}_{SwiGLU}(x) = (\text{Swish}(xW_1) \odot xW_2)W_3 \]
Derivative of Swish:
\[ \frac{d}{dx} \text{Swish}(x) = \sigma(x) + x \cdot \sigma(x) \cdot (1 - \sigma(x)) \]
Gating Mechanism:
\[ g = \sigma(W_g x + b_g), \quad \tilde{x} = W_f x + b_f, \quad y = g \odot \tilde{x} \]

Step-by-Step Explanation:

  1. Swish: Smooth activation that is x · sigmoid(x), providing non-linearity with unbounded positive range
  2. GLU: Gating mechanism using sigmoid gate to modulate linear projection
  3. SwiGLU: Replaces GLU's sigmoid gate with Swish; the bias terms are dropped, as in most modern implementations (Llama, for example, uses bias=False on all feed-forward projections)
  4. FFN SwiGLU: Transformer feed-forward using SwiGLU as activation (replaces ReLU/GELU)
  5. Swish Derivative: Combines sigmoid contribution with input-scaled contribution for smooth gradients
  6. General Gating: Learned gate g multiplies filtered input x̃, allowing selective information flow
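The Swish derivative from step 5 can be sanity-checked against a finite-difference approximation. This is a minimal sketch using only the standard library; the helper names (`swish_grad`, `central_difference`) and the sample points are illustrative assumptions.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def swish(x):
    return x * sigmoid(x)

def swish_grad(x):
    """Analytic derivative: sigma(x) + x * sigma(x) * (1 - sigma(x))."""
    s = sigmoid(x)
    return s + x * s * (1.0 - s)

def central_difference(f, x, h=1e-5):
    """Numerical derivative (f(x + h) - f(x - h)) / (2h)."""
    return (f(x + h) - f(x - h)) / (2.0 * h)

points = [-4.0, -1.0, 0.0, 0.5, 3.0]
max_err = max(abs(swish_grad(x) - central_difference(swish, x)) for x in points)
print(f"largest analytic vs numerical gap: {max_err:.2e}")
```

The two agree to within floating-point and truncation error, and the analytic derivative at x = 0 is exactly 0.5 because only the sigmoid term contributes there.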

Real-World Use Cases

Large Language Models

PaLM (540B parameters) using SwiGLU in feed-forward layers for improved training

Open Source LLMs

Llama 2 and Llama 3 adopting SwiGLU as the default activation

Mixture of Experts

Mixtral using a SwiGLU feed-forward network inside each expert (separate from the router that selects experts)

Vision Transformers

EVA-02 and the EVA-02-CLIP models using SwiGLU feed-forward layers for vision and vision-language pretraining

Code Generation

CodeLlama leveraging SwiGLU for programming language understanding

Multilingual Models

PaLM, trained on a multilingual corpus, using SwiGLU in its feed-forward layers (BLOOM, by contrast, uses GELU)

Implementation

Manual Implementation (No Libraries)

The SwiGLU implementation shows the gating mechanism with two linear projections. The gate uses Swish/SiLU activation while the value path remains linear. FeedForwardSwiGLU adapts this for Transformer FFN layers with appropriate hidden dimension scaling (8/3 ratio to maintain parameter count). The NumPy version demonstrates backpropagation through the gating mechanism.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

class Swish(nn.Module):
    """Swish activation: x * sigmoid(beta * x). With beta=1 this is SiLU."""
    
    def __init__(self, beta=1.0):
        super().__init__()
        self.beta = beta
    
    def forward(self, x):
        return x * torch.sigmoid(self.beta * x)

class SwiGLU(nn.Module):
    """SwiGLU activation: Swish(W_gate x) * (W_value x), without an output projection."""
    
    def __init__(self, dim_input, dim_hidden, bias=True):
        super().__init__()
        
        # Two linear projections: one for gating, one for value
        self.W_gate = nn.Linear(dim_input, dim_hidden, bias=bias)
        self.W_value = nn.Linear(dim_input, dim_hidden, bias=False)  # No bias on value in modern implementations
        
    def forward(self, x):
        
        # Gate pathway with Swish activation
        gate = self.W_gate(x)
        gate = F.silu(gate)  # SiLU is equivalent to Swish with beta=1
        
        # Value pathway (linear)
        value = self.W_value(x)
        
        # Gating: element-wise multiplication
        output = gate * value
        
        return output

class FeedForwardSwiGLU(nn.Module):
    """Transformer feed-forward block using SwiGLU: w3(Swish(w1 x) * w2 x)."""
    
    def __init__(self, dim, hidden_dim=None, multiple_of=256, dropout=0.0):
        super().__init__()
        
        # A standard FFN uses hidden_dim = 4 * dim. SwiGLU has three weight
        # matrices instead of two, so scaling by 2/3 (about 8/3 * dim, or 2.67x)
        # keeps the parameter count comparable.
        if hidden_dim is None:
            hidden_dim = 4 * dim
            hidden_dim = int(2 * hidden_dim / 3)
            # Round up to a multiple of multiple_of for hardware efficiency
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # Gate projection
        self.w2 = nn.Linear(dim, hidden_dim, bias=False)  # Value projection
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)  # Output projection
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        
        # SwiGLU: Swish(w1(x)) * w2(x)
        gate = F.silu(self.w1(x))  # SiLU/Swish activation
        hidden = gate * self.w2(x)  # Gating
        output = self.w3(hidden)
        output = self.dropout(output)
        
        return output

# NumPy implementation for understanding gradients
class SwiGLUNumPy:
    """NumPy SwiGLU with a hand-written backward pass and a plain SGD update."""
    
    def __init__(self, input_dim, hidden_dim):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        
        # He (Kaiming) initialization: std = sqrt(2 / fan_in)
        self.W_gate = np.random.randn(input_dim, hidden_dim) * np.sqrt(2.0 / input_dim)
        self.W_value = np.random.randn(input_dim, hidden_dim) * np.sqrt(2.0 / input_dim)
        self.b_gate = np.zeros(hidden_dim)
    
    def swish(self, x):
        return x * (1 / (1 + np.exp(-np.clip(x, -500, 500))))
    
    def swish_derivative(self, x):
        sigmoid = 1 / (1 + np.exp(-np.clip(x, -500, 500)))
        return sigmoid + x * sigmoid * (1 - sigmoid)
    
    def forward(self, x):
        
        # Linear projections
        z_gate = np.dot(x, self.W_gate) + self.b_gate
        z_value = np.dot(x, self.W_value)
        
        # Swish activation on gate
        gate = self.swish(z_gate)
        
        # Gating
        output = gate * z_value
        
        cache = (x, z_gate, z_value, gate)
        return output, cache
    
    def backward(self, grad_output, cache, learning_rate=0.01):
        
        x, z_gate, z_value, gate = cache
        
        # Gradient through gating (element-wise multiplication)
        grad_gate = grad_output * z_value
        grad_z_value = grad_output * gate
        
        # Gradient through Swish
        grad_z_gate = grad_gate * self.swish_derivative(z_gate)
        
        # Gradients w.r.t. weights
        grad_W_gate = np.dot(x.T, grad_z_gate)
        grad_b_gate = np.sum(grad_z_gate, axis=0)
        grad_W_value = np.dot(x.T, grad_z_value)
        
        # Update weights
        self.W_gate -= learning_rate * grad_W_gate
        self.b_gate -= learning_rate * grad_b_gate
        self.W_value -= learning_rate * grad_W_value
        
        # Gradient w.r.t. input for backprop
        grad_x = np.dot(grad_z_gate, self.W_gate.T) + np.dot(grad_z_value, self.W_value.T)
        
        return grad_x

# Visualization and comparison
def visualize_activations():
    x = torch.linspace(-5, 5, 1000)
    
    # Different activations
    relu = F.relu(x)
    gelu = F.gelu(x)
    swish = x * torch.sigmoid(x)
    sigmoid = torch.sigmoid(x)
    
    # Plotting
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    
    # Activations
    ax = axes[0, 0]
    ax.plot(x.numpy(), relu.numpy(), label='ReLU', linewidth=2)
    ax.plot(x.numpy(), gelu.numpy(), label='GELU', linewidth=2)
    ax.plot(x.numpy(), swish.numpy(), label='Swish', linewidth=2)
    ax.set_title('Activation Functions')
    ax.legend()
    ax.grid(True)
    ax.axhline(y=0, color='k', linestyle='--', alpha=0.3)
    ax.axvline(x=0, color='k', linestyle='--', alpha=0.3)
    
    # Derivatives
    ax = axes[0, 1]
    x_grad = 0.5 * (x[:-1] + x[1:])  # midpoints between samples, matching torch.diff
    relu_grad = torch.diff(relu) / torch.diff(x)
    gelu_grad = torch.diff(gelu) / torch.diff(x)
    swish_grad = torch.diff(swish) / torch.diff(x)
    
    ax.plot(x_grad.numpy(), relu_grad.numpy(), label="ReLU'", linewidth=2)
    ax.plot(x_grad.numpy(), gelu_grad.numpy(), label="GELU'", linewidth=2)
    ax.plot(x_grad.numpy(), swish_grad.numpy(), label="Swish'", linewidth=2)
    ax.set_title('Derivatives')
    ax.legend()
    ax.grid(True)
    ax.axhline(y=0, color='k', linestyle='--', alpha=0.3)
    ax.axvline(x=0, color='k', linestyle='--', alpha=0.3)
    
    # SwiGLU behavior visualization
    ax = axes[1, 0]
    swiglu = SwiGLU(1, 1)
    x_2d = torch.linspace(-3, 3, 100).unsqueeze(1)
    with torch.no_grad():
        # Manual SwiGLU computation
        gate = torch.sigmoid(x_2d) * x_2d  # Swish
        value = x_2d  # Identity for visualization
        output = gate * value
    
    ax.plot(x_2d.numpy(), x_2d.numpy(), label='Input', alpha=0.5)
    ax.plot(x_2d.numpy(), gate.numpy(), label='Gate (Swish)', alpha=0.7)
    ax.plot(x_2d.numpy(), output.numpy(), label='SwiGLU Output', linewidth=2)
    ax.set_title('SwiGLU Gating Behavior')
    ax.legend()
    ax.grid(True)
    
    # Smoothness comparison
    ax = axes[1, 1]
    ax.plot(x.numpy(), np.abs(np.diff(relu.numpy(), prepend=relu[0].numpy())), label='ReLU |Δ|', alpha=0.7)
    ax.plot(x.numpy(), np.abs(np.diff(gelu.numpy(), prepend=gelu[0].numpy())), label='GELU |Δ|', alpha=0.7)
    ax.plot(x.numpy(), np.abs(np.diff(swish.numpy(), prepend=swish[0].numpy())), label='Swish |Δ|', alpha=0.7)
    ax.set_title('Smoothness (Change Magnitude)')
    ax.legend()
    ax.grid(True)
    
    plt.tight_layout()
    plt.savefig('swiglu_visualization.png', dpi=150)
    print('Visualization saved to swiglu_visualization.png')

# Test implementations
print('Testing SwiGLU implementations...')

batch_size = 4
seq_len = 128
dim = 512
x = torch.randn(batch_size, seq_len, dim)

# Test SwiGLU
swiglu = SwiGLU(dim, dim * 2)
output = swiglu(x)
print(f'SwiGLU input shape: {x.shape}')
print(f'SwiGLU output shape: {output.shape}')

# Test FeedForward SwiGLU
ffn = FeedForwardSwiGLU(dim)
output_ffn = ffn(x)
print(f'\nFeedForward SwiGLU input shape: {x.shape}')
print(f'FeedForward SwiGLU output shape: {output_ffn.shape}')

# NumPy version
swiglu_np = SwiGLUNumPy(dim, dim * 2)
x_np = x[0, 0].numpy().reshape(1, -1)
output_np, cache = swiglu_np.forward(x_np)
print(f'\nNumPy SwiGLU output shape: {output_np.shape}')

# Gradient test
grad_output = np.ones_like(output_np)
grad_x = swiglu_np.backward(grad_output, cache)
print(f'Gradient shape: {grad_x.shape}')

# Parameter count comparison
ffn_relu = nn.Sequential(
    nn.Linear(dim, 4 * dim),
    nn.ReLU(),
    nn.Linear(4 * dim, dim)
)

ffn_gelu = nn.Sequential(
    nn.Linear(dim, 4 * dim),
    nn.GELU(),
    nn.Linear(4 * dim, dim)
)

ffn_swiglu = FeedForwardSwiGLU(dim, hidden_dim=int(8 * dim / 3))

print('\nParameter comparison:')
print(f'FFN ReLU: {sum(p.numel() for p in ffn_relu.parameters())/1e6:.2f}M')
print(f'FFN GELU: {sum(p.numel() for p in ffn_gelu.parameters())/1e6:.2f}M')
print(f'FFN SwiGLU: {sum(p.numel() for p in ffn_swiglu.parameters())/1e6:.2f}M')

Using Libraries (torch, torch.nn, transformers, tensorflow, jax, flax)

import torch
import torch.nn as nn
import torch.nn.functional as F

# Standard PyTorch SwiGLU (used in Llama, PaLM)
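class LlamaSwiGLU(nn.Module):
    """Llama-style FFN: w2(silu(w1 x) * w3 x). Here w2 is the down projection; dropout is accepted but unused."""
    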
    def __init__(self, dim, hidden_dim=None, multiple_of=256, dropout=0.0):
        super().__init__()
        
        # Llama uses 2/3 of 4*dim for SwiGLU to match parameter count with standard FFN
        if hidden_dim is None:
            hidden_dim = 4 * dim
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # Gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # Down projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # Value projection
        
    def forward(self, x):
        
        # Llama FFN: w2(silu(w1(x)) * w3(x))
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

# Hugging Face Transformers integration
try:
    from transformers import LlamaConfig
    from transformers.models.llama.modeling_llama import LlamaMLP
    
    # Llama MLP with SwiGLU; recent transformers releases build it from a config object
    config = LlamaConfig(hidden_size=512, intermediate_size=1376)
    llama_mlp = LlamaMLP(config)
    print('HuggingFace Llama MLP loaded')
    
except ImportError:
    print('HuggingFace transformers not available')

# Custom SwiGLU variants
class SwiGLUVariant(nn.Module):
    """Gated FFN with a selectable gate activation: 'swiglu', 'geglu', 'reglu', or 'glu'."""
    
    def __init__(self, variant='swiglu', dim=512, hidden_dim=1376):
        super().__init__()
        self.variant = variant
        
        self.w_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.w_value = nn.Linear(dim, hidden_dim, bias=False)
        self.w_out = nn.Linear(hidden_dim, dim, bias=False)
        
    def forward(self, x):
        
        if self.variant == 'swiglu':
            # Swish(xW) * xV
            gate = F.silu(self.w_gate(x))
        elif self.variant == 'geglu':
            # GELU(xW) * xV
            gate = F.gelu(self.w_gate(x))
        elif self.variant == 'reglu':
            # ReLU(xW) * xV
            gate = F.relu(self.w_gate(x))
        elif self.variant == 'glu':
            # sigmoid(xW) * xV
            gate = torch.sigmoid(self.w_gate(x))
        else:
            raise ValueError(f'Unknown variant: {self.variant}')
        
        value = self.w_value(x)
        return self.w_out(gate * value)

# Memory-efficient SwiGLU with fused operations
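class MemoryEfficientSwiGLU(nn.Module):
    """SwiGLU FFN with gate and value fused into one projection, plus optional activation checkpointing."""
    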
    def __init__(self, dim, hidden_dim, use_checkpoint=False):
        super().__init__()
        
        # Combine gate and value into single projection then split
        self.w_in = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.w_out = nn.Linear(hidden_dim, dim, bias=False)
        self.hidden_dim = hidden_dim
        self.use_checkpoint = use_checkpoint
    
    def forward(self, x):
        
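        if self.use_checkpoint and self.training:
            # Import the submodule explicitly; activations are recomputed in backward to save memory
            from torch.utils.checkpoint import checkpoint
            return checkpoint(self._forward, x, use_reentrant=False)
        return self._forward(x)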
    
    def _forward(self, x):
        
        # Single projection for both gate and value
        x_proj = self.w_in(x)
        gate, value = x_proj.split(self.hidden_dim, dim=-1)
        
        # Apply Swish to gate and multiply
        return self.w_out(F.silu(gate) * value)

# Integration with Transformer blocks
class TransformerBlockWithSwiGLU(nn.Module):
    
    
    def __init__(self, dim, num_heads, mlp_ratio=2.67, dropout=0.0):
        super().__init__()
        
        self.norm1 = nn.RMSNorm(dim) if hasattr(nn, 'RMSNorm') else nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.RMSNorm(dim) if hasattr(nn, 'RMSNorm') else nn.LayerNorm(dim)
        
        # SwiGLU FFN
        hidden_dim = int(mlp_ratio * dim)
        self.ffn = LlamaSwiGLU(dim, hidden_dim)
        
    def forward(self, x, mask=None):
        
        # Attention with pre-normalization (normalize once, reuse for query, key, and value)
        h = self.norm1(x)
        x = x + self.attn(h, h, h, attn_mask=mask)[0]
        
        # FFN with SwiGLU
        x = x + self.ffn(self.norm2(x))
        
        return x

# TensorFlow/Keras implementation
import tensorflow as tf

class SwiGLUTF(tf.keras.layers.Layer):
    
    
    def __init__(self, hidden_dim, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
    
    def build(self, input_shape):
        dim = input_shape[-1]
        self.w_gate = self.add_weight(
            shape=(dim, self.hidden_dim),
            initializer='glorot_uniform',
            trainable=True,
            name='w_gate'
        )
        self.w_value = self.add_weight(
            shape=(dim, self.hidden_dim),
            initializer='glorot_uniform',
            trainable=True,
            name='w_value'
        )
        self.w_out = self.add_weight(
            shape=(self.hidden_dim, dim),
            initializer='glorot_uniform',
            trainable=True,
            name='w_out'
        )
        super().build(input_shape)
    
    def call(self, inputs):
        gate = tf.nn.silu(tf.matmul(inputs, self.w_gate))
        value = tf.matmul(inputs, self.w_value)
        hidden = gate * value
        return tf.matmul(hidden, self.w_out)

# JAX/Flax implementation
try:
    import jax
    import jax.numpy as jnp
    from flax import linen as fnn  # aliased so it does not shadow torch.nn, which later code still uses
    
    class SwiGLUFlax(fnn.Module):
        hidden_dim: int
        
        @fnn.compact
        def __call__(self, x):
            dim = x.shape[-1]
            
            w_gate = self.param('w_gate', fnn.initializers.glorot_uniform(), (dim, self.hidden_dim))
            w_value = self.param('w_value', fnn.initializers.glorot_uniform(), (dim, self.hidden_dim))
            w_out = self.param('w_out', fnn.initializers.glorot_uniform(), (self.hidden_dim, dim))
            
            gate = jax.nn.silu(jnp.dot(x, w_gate))
            value = jnp.dot(x, w_value)
            hidden = gate * value
            return jnp.dot(hidden, w_out)
    
    print('JAX/Flax SwiGLU available')
    
except ImportError:
    print('JAX not available')

# Test all implementations
print('\nTesting SwiGLU variants...')
dim = 512
batch_size = 2
seq_len = 64
x = torch.randn(batch_size, seq_len, dim)

# Test different GLU variants
variants = ['swiglu', 'geglu', 'reglu', 'glu']
for variant in variants:
    glu = SwiGLUVariant(variant, dim, int(dim * 2.67))
    output = glu(x)
    print(f'{variant.upper()}: output shape {output.shape}, mean {output.mean().item():.4f}')

# Memory efficient version
mem_eff = MemoryEfficientSwiGLU(dim, int(dim * 2.67))
output = mem_eff(x)
print(f'\nMemory Efficient SwiGLU: output shape {output.shape}')

# Compare activations
swiglu_mod = LlamaSwiGLU(dim)
gelu_ffn = nn.Sequential(
    nn.Linear(dim, int(dim * 4)),
    nn.GELU(),
    nn.Linear(int(dim * 4), dim)
)

with torch.no_grad():
    out_swiglu = swiglu_mod(x)
    out_gelu = gelu_ffn(x)
    
print('\nActivation statistics:')
print(f'SwiGLU: mean={out_swiglu.mean():.4f}, std={out_swiglu.std():.4f}')
print(f'GELU: mean={out_gelu.mean():.4f}, std={out_gelu.std():.4f}')

When to Use

✅ Appropriate Use Cases:

  • Large language models where GELU or ReLU underperform
  • Transformer feed-forward layers seeking better expressiveness
  • When training stability with deep networks is important
  • Reproducing Llama, PaLM, or Mistral architectures
  • Tasks where smooth gating behavior (vs hard ReLU) is beneficial
  • When you want comparable performance to GELU with better gradient flow

❌ Avoid When:

  • Very small models where activation choice matters less
  • When parameter budget is tight (SwiGLU needs careful dimension tuning)
  • Inference-only deployment where GELU is already optimized on hardware
  • If you're fine-tuning a pretrained model that used a different activation
  • When you need activations bounded in a specific range (SwiGLU is unbounded)
  • Resource-constrained edge devices without SwiGLU kernel optimizations

Common Pitfalls

  • Incorrect hidden dimension (should be ~2.67x input for SwiGLU vs 4x for GELU)
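  • Adding biases to the projections (modern SwiGLU FFNs such as Llama's typically use bias=False throughout)
  • Forgetting that at the same hidden dimension SwiGLU has 1.5x the parameters and FLOPs of a standard FFN (three weight matrices instead of two); the ~2.67x hidden size exists to cancel this out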
  • Using SwiGLU with post-LN instead of pre-LN architecture
  • Not accounting for increased memory usage from dual projections
  • Mismatched activation (using ReLU or GELU instead of Swish/SiLU for gate)
  • Initialization issues - gate and value projections need careful init
  • Training instability if learning rate too high with SwiGLU (use warmup)
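
The hidden-dimension and parameter-count pitfalls above come down to simple arithmetic. The sketch below reproduces it with the standard library only; the function names are hypothetical, the rounding rule mirrors the `multiple_of` logic used in this article's `FeedForwardSwiGLU`, and biases are ignored.

```python
def swiglu_hidden_dim(dim, multiple_of=256):
    """2/3 of the usual 4 * dim, rounded up to a multiple of multiple_of."""
    hidden = int(2 * (4 * dim) / 3)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

def standard_ffn_params(dim, hidden):
    """Two weight matrices: dim x hidden and hidden x dim (no biases)."""
    return 2 * dim * hidden

def swiglu_ffn_params(dim, hidden):
    """Three weight matrices: gate, value, and output (no biases)."""
    return 3 * dim * hidden

dim = 4096
baseline = standard_ffn_params(dim, 4 * dim)
naive = swiglu_ffn_params(dim, 4 * dim)               # hidden size left at 4 * dim
hidden = swiglu_hidden_dim(dim)
matched = swiglu_ffn_params(dim, hidden)              # hidden size reduced to about 8/3 * dim

print(f"SwiGLU hidden dim for dim={dim}: {hidden}")
print(f"same hidden size:    {naive / baseline:.2f}x the baseline parameters")
print(f"reduced hidden size: {matched / baseline:.3f}x the baseline parameters")
```

Keeping the hidden size at 4 * dim inflates the layer by exactly half, while the reduced size lands within about one percent of the baseline; the small remainder comes from rounding up to a multiple of 256.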