Unit Testing PyTorch Models with pytest: A Practical Guide

PyTorch models fail silently — a wrong tensor shape produces garbage outputs without errors. pytest-based unit tests catch shape mismatches, gradient issues, and training bugs early. This guide covers practical patterns for testing PyTorch models at every level.


Why PyTorch Models Need Unit Tests

Deep learning bugs are subtle:

  • Shape mismatches — a model that expects (batch, channels, height, width) but receives (batch, height, width, channels) can run without error whenever the sizes happen to be compatible, yet produce wrong predictions
  • Gradient flow — a bug in the loss function can zero out gradients for specific layers, silently preventing learning
  • Numerical instability — NaN or Inf in weights or activations propagates silently
  • Training vs inference behavior — models behave differently in train() vs eval() mode (dropout, BatchNorm)

None of these produce exceptions by default. Without tests, you find out via degraded model performance — too late.
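
To make the first bullet concrete, here is a minimal sketch of a layout bug that raises no exception. The tiny flatten-plus-linear model and the tensor sizes are assumptions chosen for illustration only:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model: flattening erases the layout, so any tensor with
# 3 * 8 * 8 values per sample is accepted regardless of axis order.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4))

nchw = torch.randn(2, 3, 8, 8)                 # layout the model was designed for
nhwc = nchw.permute(0, 2, 3, 1).contiguous()   # same pixels, channels-last layout

with torch.no_grad():
    out_expected = model(nchw)
    out_wrong_layout = model(nhwc)

# Both calls succeed and return the same shape; only the values differ.
same_shape = out_expected.shape == out_wrong_layout.shape
same_values = torch.allclose(out_expected, out_wrong_layout)
print(same_shape, same_values)
```

Nothing in the forward pass signals the mistake, which is why the shape and layout assumptions need explicit assertions.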


Setup

pip install torch pytest numpy

A standard PyTorch test file structure:

tests/
  models/
    test_cnn.py
    test_transformer.py
    test_loss_functions.py
  training/
    test_training_loop.py
    test_data_pipeline.py
conftest.py
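
conftest.py is where pytest looks for fixtures shared across test files. A minimal sketch of what it might contain is shown below; the helper `seed_everything` and the fixture name `fixed_seed` are hypothetical names picked for this example:

```python
# conftest.py (sketch)
import random

import numpy as np
import pytest
import torch


def seed_everything(seed: int = 0) -> None:
    """Seed the Python, NumPy, and PyTorch random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


@pytest.fixture(autouse=True)
def fixed_seed():
    """Reseed before every test so randomly generated inputs are reproducible."""
    seed_everything(0)
```

Because the fixture is declared with `autouse=True`, tests do not need to request it by name.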

Testing Model Architecture

Shape Tests

Shape tests are the best tests to write first, because they catch a large share of architectural bugs immediately:

import pytest
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.flatten(1)  # keep the batch dimension; a wrong input size now fails in fc1
        x = self.relu(self.fc1(x))
        return self.fc2(x)


@pytest.fixture
def model():
    return SimpleCNN(num_classes=10)


@pytest.fixture
def sample_batch():
    """Standard CIFAR-10 style batch: 4 samples, 3 channels, 32×32."""
    return torch.randn(4, 3, 32, 32)


def test_output_shape(model, sample_batch):
    """Output must be (batch_size, num_classes)."""
    output = model(sample_batch)
    assert output.shape == (4, 10), f"Expected (4, 10), got {output.shape}"


def test_batch_size_invariance(model):
    """Model must handle different batch sizes."""
    for batch_size in [1, 4, 16, 32]:
        x = torch.randn(batch_size, 3, 32, 32)
        output = model(x)
        assert output.shape == (batch_size, 10), \
            f"Failed for batch_size={batch_size}: got {output.shape}"


def test_output_has_no_nans(model, sample_batch):
    """Forward pass must produce finite values."""
    output = model(sample_batch)
    assert not torch.isnan(output).any(), "NaN in model output"
    assert not torch.isinf(output).any(), "Inf in model output"


def test_output_requires_no_grad_in_eval(model, sample_batch):
    """Under torch.no_grad(), the output must not track gradients (eval() alone does not disable autograd)."""
    model.eval()
    with torch.no_grad():
        output = model(sample_batch)
    assert not output.requires_grad
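
The batch-dimension assertion matters more than it may seem. When a model flattens with `view(-1, features)`, an input with the wrong spatial size can be folded into the batch dimension instead of raising an error. The sketch below uses a hypothetical `TinyCNN`, a small stand-in written only for this example:

```python
import torch
import torch.nn as nn


class TinyCNN(nn.Module):
    """Hypothetical model that flattens with view(-1, ...), sized for 32x32 inputs."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(8 * 16 * 16, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv(x)))
        # -1 lets PyTorch infer the first dimension from the element count
        return self.fc(x.view(-1, 8 * 16 * 16))


model = TinyCNN().eval()
with torch.no_grad():
    good = model(torch.randn(4, 3, 32, 32))  # intended input size
    bad = model(torch.randn(4, 3, 64, 64))   # wrong size, but no exception

print(good.shape)  # torch.Size([4, 10])
print(bad.shape)   # torch.Size([16, 10]): 4 samples became 16 rows
```

Asserting `output.shape[0] == x.shape[0]`, or flattening with `x.flatten(1)` so the linear layer rejects the wrong size, turns this into a visible failure.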

Testing Gradient Flow

Verify that gradients flow to all trainable parameters:

def test_gradients_flow_to_all_layers(model, sample_batch):
    """All parameters must receive gradients after a backward pass."""
    model.train()
    output = model(sample_batch)
    
    # Use a simple loss
    labels = torch.randint(0, 10, (4,))
    loss = nn.CrossEntropyLoss()(output, labels)
    loss.backward()
    
    for name, param in model.named_parameters():
        if param.requires_grad:
            assert param.grad is not None, f"No gradient for parameter: {name}"
            assert not torch.isnan(param.grad).any(), f"NaN gradient for: {name}"
            assert not (param.grad == 0).all(), f"Zero gradient for: {name} (possible dead neuron)"


def test_specific_layer_receives_gradient(model, sample_batch):
    """Targeted gradient check for a specific layer."""
    model.train()
    output = model(sample_batch)
    labels = torch.randint(0, 10, (4,))
    loss = nn.CrossEntropyLoss()(output, labels)
    loss.backward()
    
    # Check the first conv layer specifically
    conv1_grad = model.conv1.weight.grad
    assert conv1_grad is not None
    assert conv1_grad.shape == model.conv1.weight.shape
    assert not torch.all(conv1_grad == 0)
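
To see what these checks catch, consider a deliberately broken model. In this sketch, `BrokenNet` and the helper `params_without_grad` are hypothetical names; the stray `.detach()` stands in for any bug that cuts a layer out of the autograd graph:

```python
import torch
import torch.nn as nn


class BrokenNet(nn.Module):
    """Hypothetical model with a bug: the encoder output is detached."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(6, 4)
        self.head = nn.Linear(4, 2)

    def forward(self, x):
        hidden = self.encoder(x).detach()  # bug: blocks gradients to the encoder
        return self.head(hidden)


def params_without_grad(model):
    """Names of trainable parameters that received no gradient."""
    return [
        name
        for name, param in model.named_parameters()
        if param.requires_grad and param.grad is None
    ]


torch.manual_seed(0)
net = BrokenNet()
loss = net(torch.randn(3, 6)).sum()
loss.backward()

missing = params_without_grad(net)
print(missing)  # ['encoder.weight', 'encoder.bias']
```

The forward pass and the loss look normal here; only the gradient check reveals that the encoder would never learn.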

Testing Train vs Eval Mode Behavior

Dropout and BatchNorm behave differently in train vs eval mode. Test both:

class ModelWithDropout(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)
        self.dropout = nn.Dropout(p=0.5)
    
    def forward(self, x):
        return self.dropout(self.linear(x))


def test_dropout_active_in_train_mode():
    """Dropout should make outputs non-deterministic in train mode."""
    model = ModelWithDropout()
    model.train()
    
    x = torch.ones(100, 10)  # Fixed input
    outputs = [model(x) for _ in range(10)]
    
    # With dropout, not all outputs should be identical
    all_same = all(torch.equal(outputs[0], o) for o in outputs[1:])
    assert not all_same, "Dropout should cause variation in train mode"


def test_dropout_disabled_in_eval_mode():
    """Dropout should be deterministic in eval mode."""
    model = ModelWithDropout()
    model.eval()
    
    x = torch.ones(100, 10)
    
    with torch.no_grad():
        output1 = model(x)
        output2 = model(x)
    
    assert torch.equal(output1, output2), "Eval mode outputs should be deterministic"
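
BatchNorm can be checked in the same spirit. A minimal sketch, assuming the default `track_running_stats=True`: the running statistics should move during a training-mode forward pass and stay frozen in eval mode.

```python
import torch
import torch.nn as nn


def test_batchnorm_running_stats_update_only_in_train_mode():
    torch.manual_seed(0)
    bn = nn.BatchNorm1d(4)
    x = torch.randn(16, 4) + 3.0  # shifted so the batch mean is clearly non-zero

    # Train mode: a forward pass updates running_mean (initially all zeros).
    bn.train()
    before = bn.running_mean.clone()
    bn(x)
    after_train = bn.running_mean.clone()
    assert not torch.equal(before, after_train), "running_mean did not update in train mode"

    # Eval mode: a forward pass must leave the running statistics untouched.
    bn.eval()
    with torch.no_grad():
        bn(x)
    assert torch.equal(after_train, bn.running_mean), "running_mean changed in eval mode"
```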

Testing Loss Functions

Custom loss functions are a common source of bugs:

def focal_loss(predictions, targets, gamma=2.0, alpha=0.25):
    """Focal loss for class imbalance."""
    bce = nn.BCEWithLogitsLoss(reduction='none')(predictions, targets.float())
    pt = torch.exp(-bce)
    focal = alpha * (1 - pt) ** gamma * bce
    return focal.mean()


def test_focal_loss_positive_value():
    preds = torch.randn(8, 1)
    targets = torch.randint(0, 2, (8, 1))
    loss = focal_loss(preds, targets)
    assert loss > 0, "Loss should be positive"
    assert not torch.isnan(loss), "Loss should not be NaN"


def test_focal_loss_perfect_predictions_is_low():
    """Perfect predictions should produce near-zero loss."""
    # Very confident correct predictions
    preds = torch.tensor([[10.0], [10.0], [-10.0], [-10.0]])
    targets = torch.tensor([[1], [1], [0], [0]])
    loss = focal_loss(preds, targets)
    assert loss < 0.01, f"Perfect prediction loss should be near 0, got {loss:.4f}"


def test_focal_loss_reduces_with_confidence():
    """Focal loss should be lower for a confident correct prediction than for an uncertain one."""
    easy_pred = torch.tensor([[5.0]])  # Confident correct
    hard_pred = torch.tensor([[0.1]])  # Uncertain
    target = torch.tensor([[1]])
    
    easy_loss = focal_loss(easy_pred, target)
    hard_loss = focal_loss(hard_pred, target)
    
    assert easy_loss < hard_loss, "Focal loss should weight hard examples more"
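
Another useful pattern is to test a custom loss against a known special case. With `gamma=0` the modulating factor `(1 - pt) ** gamma` equals 1, so the focal loss above should reduce to `alpha` times plain binary cross-entropy. A sketch of that check, with the loss repeated so the block runs on its own:

```python
import torch
import torch.nn as nn


def focal_loss(predictions, targets, gamma=2.0, alpha=0.25):
    """Copy of the focal loss defined above."""
    bce = nn.BCEWithLogitsLoss(reduction='none')(predictions, targets.float())
    pt = torch.exp(-bce)
    return (alpha * (1 - pt) ** gamma * bce).mean()


def test_focal_loss_with_gamma_zero_matches_scaled_bce():
    """With gamma=0, focal loss should equal alpha * BCE."""
    torch.manual_seed(0)
    preds = torch.randn(32, 1)
    targets = torch.randint(0, 2, (32, 1))
    alpha = 0.25

    expected = alpha * nn.BCEWithLogitsLoss()(preds, targets.float())
    actual = focal_loss(preds, targets, gamma=0.0, alpha=alpha)
    assert torch.allclose(actual, expected, atol=1e-6)
```

Reductions to a reference implementation like this tend to catch sign errors and misplaced weights that a simple "loss is positive" check would miss.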

Testing the Training Loop

A sanity check that a few optimizer steps actually reduce the loss on a fixed batch:

def test_model_loss_decreases_over_training_steps():
    """Model loss must decrease over multiple gradient steps."""
    torch.manual_seed(42)
    model = SimpleCNN(num_classes=10)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    
    X = torch.randn(16, 3, 32, 32)
    y = torch.randint(0, 10, (16,))
    
    losses = []
    
    model.train()
    for step in range(20):
        optimizer.zero_grad()
        output = model(X)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    
    # Loss should decrease over 20 steps
    assert losses[-1] < losses[0], \
        f"Loss did not decrease: {losses[0]:.4f} -> {losses[-1]:.4f}"


def test_model_can_overfit_small_dataset():
    """Model must be able to memorize a tiny dataset — validates model capacity."""
    torch.manual_seed(42)
    model = SimpleCNN(num_classes=10)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    
    # Tiny dataset — 8 samples
    X = torch.randn(8, 3, 32, 32)
    y = torch.randint(0, 10, (8,))
    
    model.train()
    for _ in range(100):
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()
    
    model.eval()
    with torch.no_grad():
        final_loss = criterion(model(X), y).item()
    
    assert final_loss < 0.1, \
        f"Model can't overfit tiny dataset (loss={final_loss:.4f}). Check architecture or learning rate."
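
A cheaper complementary check is that a single optimizer step changes every trainable parameter. The sketch below is one way to write it; the helper `parameters_changed` and the small Tanh network are hypothetical and chosen only to keep the example fast:

```python
import torch
import torch.nn as nn


def parameters_changed(model, optimizer, loss_fn, x, y):
    """Run one optimizer step and report, per parameter, whether its values changed."""
    before = {name: p.detach().clone() for name, p in model.named_parameters()}
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    optimizer.step()
    return {
        name: not torch.equal(before[name], p.detach())
        for name, p in model.named_parameters()
    }


def test_every_parameter_updates_after_one_step():
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(8, 16), nn.Tanh(), nn.Linear(16, 3))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(12, 8)
    y = torch.randint(0, 3, (12,))

    changed = parameters_changed(model, optimizer, nn.CrossEntropyLoss(), x, y)
    unchanged = [name for name, did_change in changed.items() if not did_change]
    assert not unchanged, f"Parameters not updated: {unchanged}"
```

A parameter that never changes usually means it was left out of the optimizer, frozen by accident, or disconnected from the loss.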

Testing Device Compatibility

Ensure the model works on both CPU and CUDA:

@pytest.mark.parametrize("device", [
    "cpu",
    pytest.param("cuda", marks=pytest.mark.skipif(
        not torch.cuda.is_available(), reason="CUDA not available"
    ))
])
def test_model_runs_on_device(device):
    model = SimpleCNN().to(device)
    x = torch.randn(4, 3, 32, 32).to(device)
    
    model.eval()
    with torch.no_grad():
        output = model(x)
    
    assert output.device.type == device
    assert output.shape == (4, 10)

CI Integration

name: PyTorch Model Tests
on: [push, pull_request]

jobs:
  test-cpu:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      - run: pip install torch --index-url https://download.pytorch.org/whl/cpu
      - run: pip install pytest numpy
      - run: pytest tests/models/ tests/training/ -v --tb=short

Run PyTorch tests on CPU in CI (faster, cheaper). Reserve GPU tests for nightly runs.
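
One possible way to split CPU and GPU runs is a custom pytest marker. The sketch below assumes a marker named `gpu`, which is a name chosen for this example and not something pytest predefines, and includes a hypothetical GPU-only test:

```python
# conftest.py (sketch): register a custom "gpu" marker
import pytest
import torch


def pytest_configure(config):
    """Register the marker so pytest does not warn about an unknown mark."""
    config.addinivalue_line("markers", "gpu: test requires a CUDA device")


@pytest.mark.gpu
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_matmul_on_gpu():
    """Hypothetical GPU-only test; in a real suite this would live in a test file."""
    x = torch.randn(8, 8, device="cuda")
    assert (x @ x).device.type == "cuda"
```

With the marker in place, the CPU job could run `pytest -m "not gpu"` and a nightly GPU job could run `pytest -m gpu`.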


Summary

PyTorch unit tests catch bugs that Python exceptions won't:

  • Shape tests first — wrong shapes propagate silently and produce wrong results
  • NaN/Inf checks — numerical instability is invisible without explicit assertions
  • Gradient flow tests — dead neurons and broken loss functions show up here
  • Train vs eval tests — dropout and BatchNorm bugs only appear in the wrong mode
  • Overfitting test — if a model can't memorize 8 samples, the architecture is broken

The overfit sanity check is particularly valuable: a model that can't fit tiny training data has a fundamental issue — wrong architecture, broken optimizer, or missing gradient flow — regardless of what the forward pass output looks like.
