Stochastic Gradient Descent


Definition

Stochastic Gradient Descent (SGD) is an optimization algorithm that approximates the true gradient of the loss function using only a single randomly selected training example (or a small minibatch) at each iteration. Unlike batch gradient descent, which computes the exact gradient over the entire dataset, SGD uses a noisy but computationally cheap estimate of it. This noise turns out to be partly a feature rather than a bug: it can help the optimizer escape shallow local minima and saddle points where batch gradient descent might stall. SGD and its minibatch variants are the workhorse of modern deep learning, making it practical to train on datasets for which computing a full gradient at every step would be prohibitively expensive. The term 'stochastic' refers to the random sampling of training examples, which injects randomness into the optimization process and is often credited with an implicit regularization effect.

Intuition


Picture hiking down a mountain in thick fog where you can only see the ground immediately around your feet. Batch gradient descent would require you to carefully survey the entire landscape before each step—accurate but incredibly slow. SGD is like feeling your way down using only the slope directly beneath you. Sometimes this local slope differs from the overall direction, causing you to zigzag, but you move much faster and can sometimes stumble through small barriers that would block a careful surveyor. Minibatch SGD strikes a balance—you survey a small patch of ground (16-512 samples) to get a better sense of direction while still moving quickly. The noise in SGD acts like random vibrations that shake you loose from small potholes (poor local minima) while the general downward trend carries you toward the valley.

Mathematical Formula

\[ \theta_{t+1} = \theta_t - \eta \nabla_\theta L(\theta_t; x^{(i)}, y^{(i)}) \]
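As a quick numerical check of the update rule, the sketch below uses a hypothetical toy linear-regression problem with squared-error loss (the data, the helper `per_example_grad`, and the learning rate are illustrative assumptions). It confirms that averaging the per-example gradients reproduces the full-batch gradient, which is why a single uniformly sampled example gives an unbiased, if noisy, estimate, and then applies one SGD step.

```python
import numpy as np

# Hypothetical toy problem: linear regression with squared-error loss
# L(theta; x, y) = (x @ theta - y) ** 2, evaluated one example at a time.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=200)
theta = np.zeros(3)
eta = 0.05  # learning rate (illustrative value)

def per_example_grad(theta, x_i, y_i):
    """Gradient of (x_i @ theta - y_i) ** 2 with respect to theta."""
    return 2.0 * (x_i @ theta - y_i) * x_i

# Full-batch gradient of the mean loss over all n examples.
full_grad = 2.0 / len(X) * X.T @ (X @ theta - y)

# Averaging the per-example gradients recovers the full gradient exactly,
# so one uniformly sampled example is an unbiased estimate of it.
all_grads = np.array([per_example_grad(theta, X[i], y[i]) for i in range(len(X))])
print(np.allclose(all_grads.mean(axis=0), full_grad))  # True

# One SGD step: theta_{t+1} = theta_t - eta * grad L(theta_t; x^(i), y^(i))
i = rng.integers(len(X))
theta_next = theta - eta * per_example_grad(theta, X[i], y[i])
```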

Step-by-Step Explanation:

  1. Randomly shuffle the training dataset at the beginning of each epoch.
  2. Take the next single training example \(x^{(i)}, y^{(i)}\) (or the next minibatch) from the shuffled order; shuffling once and then walking through the data amounts to sampling without replacement.
  3. Compute the loss for just this sample: \(L(\theta_t; x^{(i)}, y^{(i)})\).
  4. Calculate the gradient of this sample loss with respect to the parameters, \(\nabla_\theta L(\theta_t; x^{(i)}, y^{(i)})\).
  5. Update the parameters with this stochastic gradient estimate: \(\theta_{t+1} = \theta_t - \eta \nabla_\theta L(\theta_t; x^{(i)}, y^{(i)})\).
  6. Repeat until convergence; one full pass over the dataset is an epoch, and training typically runs for many epochs.
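Steps 1 and 2 can be sketched with a small index generator, assuming NumPy's `Generator` API; `minibatch_indices` is a hypothetical helper, not part of any library. Shuffling once and then slicing in order means every example is used exactly once per epoch; drawing each batch independently with `rng.integers` would instead sample with replacement.

```python
import numpy as np

def minibatch_indices(n_samples, batch_size, rng):
    """Yield index arrays for one epoch: shuffle once, then slice in order.

    Every example appears exactly once per epoch (sampling without
    replacement); the last batch is smaller if batch_size does not divide n.
    """
    order = rng.permutation(n_samples)
    for start in range(0, n_samples, batch_size):
        yield order[start:start + batch_size]

rng = np.random.default_rng(42)
batches = list(minibatch_indices(10, 4, rng))
print([len(b) for b in batches])                 # [4, 4, 2]
print(sorted(np.concatenate(batches).tolist()))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

Calling the generator again at the start of the next epoch produces a fresh permutation, which is what step 1 asks for.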

Real-World Use Cases

Large-Scale Image Classification

Training ResNet or Vision Transformer models on ImageNet (1.28M training images). Because each step needs only one minibatch, images can be streamed from disk incrementally instead of loading the entire dataset into memory, so per-step memory depends on the batch size rather than the dataset size.

Natural Language Processing

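Pretraining BERT or GPT-style models on billions of tokens from web text corpora. Training proceeds on large minibatches of text sequences; in practice these models are usually optimized with adaptive variants of stochastic gradient methods, such as Adam or AdamW, rather than plain SGD.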

Online Learning Systems

Real-time recommendation systems that update models as users interact. SGD allows single-example updates when a user clicks an item, enabling immediate model adaptation.

Streaming Data Processing

Fraud detection systems processing millions of transactions per day. SGD updates models on each transaction as it arrives without storing historical data.
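For the online and streaming cases above, a minimal sketch of single-example SGD on a stream might look like the following. The generator `transaction_stream`, its synthetic linear target, and the learning rate are hypothetical stand-ins for a real feed; the point is that each example updates the weights once and is then discarded, so memory use does not grow with the number of examples seen.

```python
import numpy as np

def transaction_stream(n, rng):
    """Hypothetical stand-in for a live feed: yields (features, target) pairs
    one at a time; nothing is kept after a pair is consumed."""
    true_w = np.array([0.8, -1.2, 2.0])
    for _ in range(n):
        x = rng.normal(size=3)
        yield x, x @ true_w + 0.1 * rng.normal()

rng = np.random.default_rng(0)
w = np.zeros(3)
eta = 0.01  # illustrative learning rate

for x, y in transaction_stream(5000, rng):
    # Single-example SGD step on squared error; the example is then discarded.
    error = x @ w - y
    w -= eta * 2.0 * error * x

print(np.round(w, 2))
```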

Mobile and Edge Devices

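On-device personalization where models update based on user behavior. Plain SGD (without momentum) keeps no per-parameter optimizer state, unlike Adam, which stores two extra moment estimates per parameter, so its small memory footprint suits resource-constrained devices.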

Implementation

Manual Implementation (No Libraries)

The implementation supports both pure SGD (batch_size=1) and minibatch SGD. It shuffles the data at the start of each epoch, then processes samples in minibatches. The gradient is computed only on the current batch, giving a noisy but cheap approximation of the full gradient. This reduces the cost of each gradient computation from O(n·d) over the whole dataset to O(batch_size·d), where d is the number of features.
import numpy as np

def stochastic_gradient_descent(X, y, learning_rate=0.01, n_epochs=100, 
                                  batch_size=1, shuffle=True):
    """
    Stochastic Gradient Descent with minibatch support.
    
    Args:
        X: Feature matrix (n_samples, n_features)
        y: Target vector (n_samples,)
        learning_rate: Step size for parameter updates
        n_epochs: Number of passes through the entire dataset
        batch_size: Number of samples per gradient computation (1 for pure SGD)
        shuffle: Whether to shuffle data each epoch
    
    Returns:
        theta: Optimized parameters
        loss_history: Average loss per epoch
    """
    n_samples, n_features = X.shape
    theta = np.random.randn(n_features) * 0.01
    loss_history = []
    
    for epoch in range(n_epochs):
        # Shuffle data at the beginning of each epoch
        if shuffle:
            indices = np.random.permutation(n_samples)
            X_shuffled = X[indices]
            y_shuffled = y[indices]
        else:
            X_shuffled, y_shuffled = X, y
        
        epoch_loss = 0
        n_batches = 0
        
        # Iterate through minibatches
        for i in range(0, n_samples, batch_size):
            # Get minibatch
            X_batch = X_shuffled[i:i+batch_size]
            y_batch = y_shuffled[i:i+batch_size]
            
            # Forward pass
            y_pred = X_batch @ theta
            
            # Compute batch loss
            batch_loss = np.mean((y_pred - y_batch) ** 2)
            epoch_loss += batch_loss
            n_batches += 1
            
            # Compute gradient for this batch
            gradient = (2 / len(X_batch)) * X_batch.T @ (y_pred - y_batch)
            
            # Update parameters
            theta = theta - learning_rate * gradient
        
        # Record average loss for this epoch
        avg_loss = epoch_loss / n_batches
        loss_history.append(avg_loss)
        
        if epoch % 10 == 0:
            print(f'Epoch {epoch}: Loss = {avg_loss:.6f}')
    
    return theta, loss_history

# Example with different batch sizes
np.random.seed(42)
X = np.random.randn(1000, 5)
true_theta = np.array([1.5, -2.0, 0.5, 3.0, -1.0])
y = X @ true_theta + np.random.randn(1000) * 0.5

# Pure SGD (batch_size=1)
print('=== Pure SGD (batch_size=1) ===')
theta_sgd, losses_sgd = stochastic_gradient_descent(
    X, y, learning_rate=0.01, n_epochs=50, batch_size=1
)

# Minibatch SGD (batch_size=32)
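print('\n=== Minibatch SGD (batch_size=32) ===')
theta_mbsgd, losses_mbsgd = stochastic_gradient_descent(
    X, y, learning_rate=0.01, n_epochs=50, batch_size=32
)

# Both runs should end up close to the parameters used to generate y
print('\nTrue theta:      ', true_theta)
print('Pure SGD theta:  ', np.round(theta_sgd, 3))
print('Minibatch theta: ', np.round(theta_mbsgd, 3))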
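To make the "noisy but fast" trade-off concrete, the sketch below measures how far a random minibatch gradient lands from the full gradient for several batch sizes, on synthetic data shaped like the example above (the data, the fixed evaluation point `theta = 0`, and the trial count are assumptions). The mean squared error of the estimate should shrink roughly in proportion to 1/batch_size, slightly faster here because batches are drawn without replacement from a finite dataset.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 1000, 5
X = rng.normal(size=(n, d))
y = X @ np.array([1.5, -2.0, 0.5, 3.0, -1.0]) + 0.5 * rng.normal(size=n)
theta = np.zeros(d)  # evaluate all gradients at one fixed point

def mse_grad(Xb, yb, theta):
    """Gradient of mean((Xb @ theta - yb) ** 2), same form as the manual code."""
    return (2.0 / len(Xb)) * Xb.T @ (Xb @ theta - yb)

full = mse_grad(X, y, theta)

def mean_sq_error_of_estimate(batch_size, trials=2000):
    """Average squared distance between a random minibatch gradient and the full gradient."""
    errs = []
    for _ in range(trials):
        idx = rng.choice(n, size=batch_size, replace=False)
        errs.append(np.sum((mse_grad(X[idx], y[idx], theta) - full) ** 2))
    return np.mean(errs)

for b in (1, 32, 256):
    print(b, round(mean_sq_error_of_estimate(b), 3))
```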

Using Libraries (torch.optim.SGD, torch.utils.data.DataLoader, tensorflow.keras.optimizers.SGD)

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

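# Create a synthetic linear-regression dataset and wrap it for DataLoader
X = torch.randn(1000, 5)
true_w = torch.tensor([[1.5], [-2.0], [0.5], [3.0], [-1.0]])
y = X @ true_w + 0.5 * torch.randn(1000, 1)
dataset = TensorDataset(X, y)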

# Pure SGD: batch_size=1, reshuffled every epoch
print('=== Pure SGD (batch_size=1) ===')
dataloader_sgd = DataLoader(dataset, batch_size=1, shuffle=True)

model = nn.Linear(5, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

# Training loop
for epoch in range(50):
    epoch_loss = 0
    for batch_X, batch_y in dataloader_sgd:
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    
    if epoch % 10 == 0:
        print(f'Epoch {epoch}: Loss = {epoch_loss/len(dataloader_sgd):.6f}')

# Minibatch SGD with DataLoader
print('\n=== Minibatch SGD (batch_size=32) ===')
dataloader_mb = DataLoader(dataset, batch_size=32, shuffle=True)
model_mb = nn.Linear(5, 1)
optimizer_mb = torch.optim.SGD(model_mb.parameters(), lr=0.01)

for epoch in range(50):
    epoch_loss = 0
    for batch_X, batch_y in dataloader_mb:
        optimizer_mb.zero_grad()
        outputs = model_mb(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer_mb.step()
        epoch_loss += loss.item()
    
    if epoch % 10 == 0:
        print(f'Epoch {epoch}: Loss = {epoch_loss/len(dataloader_mb):.6f}')

# TensorFlow/Keras implementation (same data, minibatch SGD with batch_size=32)
import tensorflow as tf

tf_model = tf.keras.Sequential([
    tf.keras.Input(shape=(5,)),
    tf.keras.layers.Dense(1),
])
tf_model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
                 loss='mse')

# fit() shuffles the data each epoch by default and takes one SGD step per batch
history = tf_model.fit(X.numpy(), y.numpy(), epochs=50,
                       batch_size=32, verbose=0)
print(f'TensorFlow final loss: {history.history["loss"][-1]:.6f}')

When to Use

✅ Appropriate Use Cases:

  • When training on large datasets that don't fit in memory
  • For online learning scenarios with streaming data
  • When computational efficiency is more important than exact convergence
  • To help escape shallow local minima and saddle points in non-convex landscapes
  • For deep learning, where full-batch methods are computationally prohibitive
  • When the implicit regularization from gradient noise is beneficial
  • For real-time systems requiring immediate model updates
  • When training on distributed systems (minibatch gradients can be computed data-parallel across workers and averaged)
  • For recommendation systems with sparse, high-dimensional data

❌ Avoid When:

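  • When precise convergence to a specific minimum is required (with a constant learning rate, SGD keeps fluctuating around the minimum instead of settling on it)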
  • For small datasets where batch gradient descent is fast enough
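  • When per-step overhead, such as extremely expensive loss evaluations, dominates the cost, so a few accurate full-batch steps are cheaper than many noisy ones (use batch methods)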
  • For convex problems where exact gradient is easily computable
  • When training requires stable, reproducible paths to the solution and random seeds and data order cannot be fixed
  • When the noise would destabilize training (use larger batches or adaptive methods)

Common Pitfalls

  • High variance in gradients: Single-sample gradients have high variance, causing noisy updates and oscillation around the minimum; the loss curve looks erratic. Solution: increase the batch size, decay the learning rate over time, or add momentum to smooth out updates.
  • Slow convergence in flat regions: Where gradients are small, SGD moves very slowly and may appear stuck even when far from the minimum. Solution: use momentum, adaptive learning-rate methods (Adam, RMSprop), or a learning rate schedule with warmup.
  • Improper learning rate scaling: A learning rate that works for small batches may diverge with large batches, and vice versa. Solution: rescale the learning rate with the batch size. The linear scaling rule multiplies it by k when the batch size grows by a factor of k (usually combined with warmup); square-root scaling by sqrt(k) is a more conservative alternative.
  • Non-representative batches: Small batches may not reflect the true data distribution, especially with imbalanced classes. Solution: use stratified sampling so each batch has a reasonable class balance, or increase the batch size.
  • Noisy convergence detection: High gradient variance makes it hard to tell when training has converged. Solution: track moving averages of the loss, monitor validation metrics, or use early stopping with patience.
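Several of the remedies above (momentum, learning-rate decay, and a moving average of the loss) can be combined in a few lines. The following is a minimal NumPy sketch for the same kind of linear-regression setup as the manual implementation; the function name `sgd_momentum`, the 1/t decay schedule, the heavy-ball momentum form, and all default values are illustrative assumptions rather than recommended settings.

```python
import numpy as np

def sgd_momentum(X, y, lr0=0.01, momentum=0.9, decay=0.01, batch_size=32,
                 n_epochs=20, ema_beta=0.9, seed=0):
    """Minibatch SGD with heavy-ball momentum, 1/t learning-rate decay, and an
    exponential moving average (EMA) of the batch loss for monitoring.

    Schedule lr_t = lr0 / (1 + decay * t); defaults are illustrative, not tuned.
    """
    rng = np.random.default_rng(seed)
    n, d = X.shape
    theta = np.zeros(d)
    velocity = np.zeros(d)
    ema_loss = None
    step = 0
    for _ in range(n_epochs):
        order = rng.permutation(n)  # reshuffle every epoch
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            residual = X[idx] @ theta - y[idx]
            loss = np.mean(residual ** 2)
            grad = (2.0 / len(idx)) * X[idx].T @ residual
            lr = lr0 / (1.0 + decay * step)              # decaying step size
            velocity = momentum * velocity - lr * grad   # accumulate a smoothed direction
            theta = theta + velocity
            # The EMA smooths the noisy per-batch loss for convergence checks.
            ema_loss = loss if ema_loss is None else ema_beta * ema_loss + (1 - ema_beta) * loss
            step += 1
    return theta, ema_loss

rng = np.random.default_rng(42)
X = rng.normal(size=(1000, 5))
true_theta = np.array([1.5, -2.0, 0.5, 3.0, -1.0])
y = X @ true_theta + 0.5 * rng.normal(size=1000)
theta, ema = sgd_momentum(X, y)
print(np.round(theta, 2), round(ema, 3))
```

A larger `ema_beta` smooths the monitored loss more but reacts more slowly to real changes; for early stopping, a validation metric is usually the better signal to watch.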