Unit Testing PyTorch Models with pytest: A Practical Guide
PyTorch models fail silently — a wrong tensor shape produces garbage outputs without errors. pytest-based unit tests catch shape mismatches, gradient issues, and training bugs early. This guide covers practical patterns for testing PyTorch models at every level.
Why PyTorch Models Need Unit Tests
Deep learning bugs are subtle:
- Shape mismatches — a model that takes `(batch, channels, height, width)` but receives `(batch, height, width, channels)` often runs without error but produces wrong predictions
- Gradient flow — a bug in the loss function can zero out gradients for specific layers, silently preventing learning
- Numerical instability — NaN or Inf in weights or activations propagates silently
- Training vs inference behavior — models behave differently in `train()` vs `eval()` mode (dropout, BatchNorm)
None of these produce exceptions by default. Without tests, you find out via degraded model performance — too late.
Setup
pip install torch pytest numpy
A standard PyTorch test file structure:
tests/
  models/
    test_cnn.py
    test_transformer.py
    test_loss_functions.py
  training/
    test_training_loop.py
    test_data_pipeline.py
  conftest.py
Testing Model Architecture
Shape Tests
Shape tests are the most important tests to write first. They catch 80% of architectural bugs immediately:
import pytest
import torch
import torch.nn as nn
class SimpleCNN(nn.Module):
    """Small two-convolution CNN classifier used as the model under test.

    Both convolutions use ``kernel_size=3, padding=1`` and each is followed
    by a 2x2 max-pool. ``fc1`` expects ``64 * 8 * 8`` input features, which
    corresponds to 3-channel 32x32 inputs (32 -> 16 -> 8 after two pools).

    Args:
        num_classes: Number of output logits per sample.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        Raises:
            RuntimeError: If the flattened feature size does not match what
                ``fc1`` expects (e.g. the input is not 32x32).
        """
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. A fixed
        # ``view(-1, 64 * 8 * 8)`` would silently fold extra spatial
        # positions into the batch dimension for wrong-sized inputs.
        x = torch.flatten(x, start_dim=1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)
@pytest.fixture
def model():
    """Provide a ``SimpleCNN`` configured with 10 output classes."""
    return SimpleCNN(num_classes=10)
@pytest.fixture
def sample_batch():
    """Standard CIFAR-10 style batch: 4 samples, 3 channels, 32×32."""
    return torch.randn(4, 3, 32, 32)
def test_output_shape(model, sample_batch):
    """Output must be (batch_size, num_classes)."""
    output = model(sample_batch)
    # Batch size comes from the fixture so the two cannot drift apart; the
    # class count stays a literal because it is the spec being checked.
    expected = (sample_batch.shape[0], 10)
    assert output.shape == expected, f"Expected {expected}, got {output.shape}"
def test_batch_size_invariance(model):
    """Model must handle different batch sizes."""
    for batch_size in (1, 4, 16, 32):
        x = torch.randn(batch_size, 3, 32, 32)
        output = model(x)
        assert output.shape == (batch_size, 10), \
            f"Failed for batch_size={batch_size}: got {output.shape}"
def test_output_has_no_nans(model, sample_batch):
    """Forward pass must produce finite values."""
    output = model(sample_batch)
    # NaN and Inf are checked separately so the failure message says which.
    assert not torch.isnan(output).any(), "NaN in model output"
    assert not torch.isinf(output).any(), "Inf in model output"
def test_output_requires_no_grad_in_eval(model, sample_batch):
    """In eval mode, output should not retain gradients for inference efficiency."""
    model.eval()
    # ``eval()`` only switches layer behavior; it is ``no_grad()`` that
    # keeps the output from being attached to the autograd graph.
    with torch.no_grad():
        output = model(sample_batch)
    assert not output.requires_grad


# Testing Gradient Flow
Verify that gradients flow to all trainable parameters:
def test_gradients_flow_to_all_layers(model, sample_batch):
    """All parameters must receive gradients after a backward pass."""
    model.train()
    output = model(sample_batch)
    # Use a simple loss. Label count and class range are taken from the
    # output so the test keeps working if the fixtures change size.
    batch_size, num_classes = output.shape
    labels = torch.randint(0, num_classes, (batch_size,))
    loss = nn.CrossEntropyLoss()(output, labels)
    loss.backward()
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        assert param.grad is not None, f"No gradient for parameter: {name}"
        assert not torch.isnan(param.grad).any(), f"NaN gradient for: {name}"
        assert not (param.grad == 0).all(), f"Zero gradient for: {name} (possible dead neuron)"
def test_specific_layer_receives_gradient(model, sample_batch):
    """Targeted gradient check for a specific layer."""
    model.train()
    output = model(sample_batch)
    # Derive label count and class range from the output instead of
    # repeating the fixture's sizes as literals.
    batch_size, num_classes = output.shape
    labels = torch.randint(0, num_classes, (batch_size,))
    loss = nn.CrossEntropyLoss()(output, labels)
    loss.backward()
    # Check the first conv layer specifically
    conv1_grad = model.conv1.weight.grad
    assert conv1_grad is not None
    assert conv1_grad.shape == model.conv1.weight.shape
    assert not torch.all(conv1_grad == 0)


# Testing Train vs Eval Mode Behavior
Dropout and BatchNorm behave differently in train vs eval mode. Test both:
class ModelWithDropout(nn.Module):
    """Minimal model (linear 10 -> 5, then dropout p=0.5) for mode tests."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Apply the linear layer, then dropout, to ``x``."""
        return self.dropout(self.linear(x))
def test_dropout_active_in_train_mode():
    """Dropout should make outputs non-deterministic in train mode."""
    model = ModelWithDropout()
    model.train()
    x = torch.ones(100, 10)  # Fixed input
    outputs = [model(x) for _ in range(10)]
    # With dropout, not all outputs should be identical
    first, *rest = outputs
    all_same = all(torch.equal(first, o) for o in rest)
    assert not all_same, "Dropout should cause variation in train mode"
def test_dropout_disabled_in_eval_mode():
    """Dropout should be deterministic (and a no-op) in eval mode."""
    model = ModelWithDropout()
    model.eval()
    x = torch.ones(100, 10)
    with torch.no_grad():
        output1 = model(x)
        output2 = model(x)
        pre_dropout = model.linear(x)
    assert torch.equal(output1, output2), "Eval mode outputs should be deterministic"
    # Determinism alone would also hold for any fixed transformation, so
    # also check that the dropout layer passed the activations through.
    assert torch.equal(output1, pre_dropout), "Dropout should be a no-op in eval mode"


# Testing Loss Functions
Custom loss functions are a common source of bugs:
def focal_loss(predictions, targets, gamma=2.0, alpha=0.25):
    """Focal loss for class imbalance.

    Computes ``alpha * (1 - pt) ** gamma * bce`` per element, where ``bce``
    is the binary cross-entropy on logits and ``pt = exp(-bce)``, then
    averages over all elements.

    Args:
        predictions: Raw logits (no sigmoid applied).
        targets: Binary labels with the same shape as ``predictions``.
        gamma: Focusing exponent applied to ``(1 - pt)``.
        alpha: Scale factor. NOTE: applied uniformly to every element
            here, not as a per-class weight.

    Returns:
        Scalar tensor holding the mean focal loss.
    """
    # Cast targets to the logits' dtype (instead of always float32) so the
    # loss is computed in whatever precision the predictions use.
    bce = nn.functional.binary_cross_entropy_with_logits(
        predictions, targets.to(predictions.dtype), reduction="none"
    )
    pt = torch.exp(-bce)
    focal = alpha * (1 - pt) ** gamma * bce
    return focal.mean()
def test_focal_loss_positive_value():
    """Focal loss on random logits must be positive and not NaN."""
    preds = torch.randn(8, 1)
    targets = torch.randint(0, 2, (8, 1))
    loss = focal_loss(preds, targets)
    assert loss > 0, "Loss should be positive"
    assert not torch.isnan(loss), "Loss should not be NaN"
def test_focal_loss_perfect_predictions_is_low():
    """Perfect predictions should produce near-zero loss."""
    # Very confident correct predictions
    preds = torch.tensor([[10.0], [10.0], [-10.0], [-10.0]])
    targets = torch.tensor([[1], [1], [0], [0]])
    loss = focal_loss(preds, targets)
    assert loss < 0.01, f"Perfect prediction loss should be near 0, got {loss:.4f}"
def test_focal_loss_reduces_with_confidence():
    """Focal loss should penalize easy examples less than BCE."""
    easy_pred = torch.tensor([[5.0]])  # Confident correct
    hard_pred = torch.tensor([[0.1]])  # Uncertain
    target = torch.tensor([[1]])
    easy_loss = focal_loss(easy_pred, target)
    hard_loss = focal_loss(hard_pred, target)
    assert easy_loss < hard_loss, "Focal loss should weight hard examples more"

    # Plain BCE also ranks easy below hard, so the check above cannot tell
    # focal loss from BCE. Compare each loss to its BCE counterpart: the
    # easy example must be scaled down by a larger factor than the hard one.
    bce = nn.functional.binary_cross_entropy_with_logits
    float_target = target.float()
    easy_ratio = easy_loss / bce(easy_pred, float_target)
    hard_ratio = hard_loss / bce(hard_pred, float_target)
    assert easy_ratio < hard_ratio, "Focal loss should down-weight easy examples relative to BCE"


# Testing the Training Loop
A multi-step sanity check that training actually reduces loss:
def test_model_loss_decreases_over_training_steps():
    """Model loss must decrease over multiple gradient steps."""
    # Fixed seed so the weights and the random batch are reproducible.
    torch.manual_seed(42)
    model = SimpleCNN(num_classes=10)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    X = torch.randn(16, 3, 32, 32)
    y = torch.randint(0, 10, (16,))
    losses = []
    model.train()
    for _ in range(20):
        optimizer.zero_grad()
        output = model(X)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    # Loss should decrease over 20 steps
    assert losses[-1] < losses[0], \
        f"Loss did not decrease: {losses[0]:.4f} → {losses[-1]:.4f}"
def test_model_can_overfit_small_dataset():
    """Model must be able to memorize a tiny dataset — validates model capacity."""
    torch.manual_seed(42)
    model = SimpleCNN(num_classes=10)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()
    # Tiny dataset — 8 samples
    X = torch.randn(8, 3, 32, 32)
    y = torch.randint(0, 10, (8,))
    model.train()
    for _ in range(100):
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()
    # Measure the final loss on the same 8 samples, without tracking grads.
    model.eval()
    with torch.no_grad():
        final_loss = criterion(model(X), y).item()
    assert final_loss < 0.1, \
        f"Model can't overfit tiny dataset (loss={final_loss:.4f}). Check architecture or learning rate."


# Testing Device Compatibility
Ensure the model works on both CPU and CUDA:
@pytest.mark.parametrize("device", [
    "cpu",
    pytest.param("cuda", marks=pytest.mark.skipif(
        not torch.cuda.is_available(), reason="CUDA not available"
    )),
])
def test_model_runs_on_device(device):
    """Model must run a forward pass on the given device and keep output there."""
    model = SimpleCNN().to(device)
    # Allocate the input on the target device directly rather than
    # creating it on the CPU and copying it over.
    x = torch.randn(4, 3, 32, 32, device=device)
    model.eval()
    with torch.no_grad():
        output = model(x)
    assert output.device.type == device
    assert output.shape == (4, 10)


# CI Integration
name: PyTorch Model Tests
on: [push, pull_request]
jobs:
  test-cpu:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      # Install torch from the CPU wheel index.
      - run: pip install torch --index-url https://download.pytorch.org/whl/cpu
      - run: pip install pytest numpy
      - run: pytest tests/models/ tests/training/ -v --tb=short
# Run PyTorch tests on CPU in CI (faster, cheaper). Reserve GPU tests for nightly runs.
Summary
PyTorch unit tests catch bugs that Python exceptions won't:
- Shape tests first — wrong shapes propagate silently and produce wrong results
- NaN/Inf checks — numerical instability is invisible without explicit assertions
- Gradient flow tests — dead neurons and broken loss functions show up here
- Train vs eval tests — dropout and BatchNorm bugs only appear in the wrong mode
- Overfitting test — if a model can't memorize 8 samples, the architecture is broken
The overfit sanity check is particularly valuable: a model that can't fit tiny training data has a fundamental issue — wrong architecture, broken optimizer, or missing gradient flow — regardless of what the forward pass output looks like.