Testing ML Experiments with Weights & Biases: Assertions, Alerts, and CI Checks
Weights & Biases (W&B) tracks ML experiments — but it also provides the infrastructure to assert training quality, detect regressions, and gate deployments on metric thresholds. This post covers how to use W&B as a testing tool, not just a logging tool.
W&B as a Testing Tool
Beyond logging, W&B enables:
- Metric assertions — fail a run if accuracy drops below a threshold
- Automated alerts — notify on Slack/email when metrics cross boundaries
- Model registry gates — require passing eval metrics before promoting to production
- CI integration — use the W&B API to query past runs and compare
Basic Run with Assertions
import wandb
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score
# Quality gates: the minimum scores this run must reach (your test spec)
ACCURACY_THRESHOLD = 0.90
F1_THRESHOLD = 0.88

# Hyperparameters are recorded in the run config and later unpacked
# straight into the classifier constructor via wandb.config.
hyperparameters = {
    "n_estimators": 100,
    "max_depth": 5,
    "random_state": 42,
}
run = wandb.init(project="iris-classifier", config=hyperparameters)

# Hold out 20% of the iris data for evaluation, with a fixed split seed.
features, labels = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    features, labels, test_size=0.2, random_state=42
)

clf = RandomForestClassifier(**wandb.config)
clf.fit(X_train, y_train)

# Score the held-out split and log both metrics in one call.
preds = clf.predict(X_test)
accuracy = accuracy_score(y_test, preds)
f1 = f1_score(y_test, preds, average="weighted")
metrics = {"accuracy": accuracy, "f1": f1}
wandb.log(metrics)

# Enforce the gates: a failed assertion aborts the script before the
# success message below is printed.
accuracy_failure = f"Accuracy {accuracy:.3f} below threshold {ACCURACY_THRESHOLD}"
assert accuracy >= ACCURACY_THRESHOLD, accuracy_failure
f1_failure = f"F1 {f1:.3f} below threshold {F1_THRESHOLD}"
assert f1 >= F1_THRESHOLD, f1_failure

print(f"All checks passed: accuracy={accuracy:.3f}, f1={f1:.3f}")
wandb.finish()
Querying Previous Runs for Regression Detection
Use the W&B API to compare a new run against historical baselines:
import wandb
api = wandb.Api()
def get_best_accuracy(project: str, metric: str = "accuracy") -> float:
    """Look up the top value of ``metric`` among a project's finished runs.

    Runs are requested from the module-level ``api`` client, filtered to
    state "finished" and ordered by ``-summary_metrics.<metric>``; the
    first run in that ordering supplies the value.

    Args:
        project: Project path passed to ``api.runs``.
        metric: Summary key to rank by and read back.

    Returns:
        The first run's summary value for ``metric``; 0.0 when the query
        yields no runs or that run's summary lacks the key.
    """
    ranked_runs = api.runs(
        project,
        filters={"state": "finished"},
        order=f"-summary_metrics.{metric}",
    )
    if ranked_runs:
        top_run = ranked_runs[0]
        return top_run.summary.get(metric, 0.0)
    # No finished runs yet: report a baseline of zero.
    return 0.0
def assert_no_regression(
    current_accuracy: float,
    project: str,
    tolerance: float = 0.02,
):
    """Assert that the current accuracy is within ``tolerance`` of the best.

    The historical best comes from ``get_best_accuracy(project)``.

    Args:
        current_accuracy: Accuracy of the run being checked.
        project: Project path whose history provides the baseline.
        tolerance: Largest drop below the historical best that still passes.

    Raises:
        AssertionError: If ``current_accuracy`` is below ``best - tolerance``.
    """
    best = get_best_accuracy(project)
    failure = (
        f"Regression detected: current={current_accuracy:.3f}, "
        f"best={best:.3f}, tolerance={tolerance}"
    )
    assert current_accuracy >= best - tolerance, failure
# In your training script:
assert_no_regression(accuracy, project="my-team/iris-classifier")
pytest Integration
Write W&B checks as pytest tests for CI:
# tests/test_training_quality.py
import pytest
import wandb
PROJECT = "my-team/iris-classifier"
MIN_ACCURACY = 0.90
MIN_F1 = 0.88
MAX_REGRESSION = 0.02 # allow 2% drop from historical best
@pytest.fixture(scope="session")
def api():
    """Provide a ``wandb.Api`` client, created once per test session."""
    client = wandb.Api()
    return client
@pytest.fixture(scope="session")
def latest_run(api):
    """Return the first finished run of ``PROJECT`` ordered by ``-created_at``.

    Raises:
        AssertionError: If the query returns no runs.
    """
    finished_runs = api.runs(
        PROJECT,
        filters={"state": "finished"},
        order="-created_at",
    )
    assert len(finished_runs) > 0, "No completed runs found"
    newest = finished_runs[0]
    return newest
def test_accuracy_above_threshold(latest_run):
    """Check that the latest run logged an accuracy of at least ``MIN_ACCURACY``."""
    summary = latest_run.summary
    accuracy = summary.get("accuracy")
    # A missing metric is reported separately from a metric that is too low.
    assert accuracy is not None, "No accuracy metric in latest run"
    too_low = f"Accuracy {accuracy:.3f} < threshold {MIN_ACCURACY}"
    assert accuracy >= MIN_ACCURACY, too_low
def test_f1_above_threshold(latest_run):
    """Check that the latest run logged an F1 score of at least ``MIN_F1``.

    A missing metric is reported separately from a metric that is too low,
    and both failures carry a message so the CI log shows what went wrong.
    """
    f1 = latest_run.summary.get("f1")
    assert f1 is not None, "No f1 metric in latest run"
    assert f1 >= MIN_F1, (
        f"F1 {f1:.3f} < threshold {MIN_F1}"
    )
def test_no_accuracy_regression(api, latest_run):
    """Check the latest run is within ``MAX_REGRESSION`` of the top accuracy.

    The baseline is the first finished run of ``PROJECT`` when ordered by
    ``-summary_metrics.accuracy``. A missing accuracy on either run is
    treated as 0.
    """
    ranked_runs = api.runs(
        PROJECT,
        filters={"state": "finished"},
        order="-summary_metrics.accuracy",
    )
    top_run = ranked_runs[0]
    best_accuracy = top_run.summary.get("accuracy", 0)
    current_accuracy = latest_run.summary.get("accuracy", 0)
    lowest_allowed = best_accuracy - MAX_REGRESSION
    assert current_accuracy >= lowest_allowed, (
        f"Regression: current={current_accuracy:.3f}, best={best_accuracy:.3f}"
    )
def test_training_time_acceptable(latest_run):
duration = latest_run.summary.get("_wandb", {}).get("runtime", 0)
assert duration < 3600, f"Training took {duration}s, max is 3600s"
Automated Alerts
Configure alerts in W&B to notify on threshold violations:
import wandb
run = wandb.init(project="production-model")
# Log metrics during training
for epoch in range(100):
loss = train_one_epoch()
val_loss = evaluate()
wandb.log({"loss": loss, "val_loss": val_loss, "epoch": epoch})
# Programmatic alert if diverging
if val_loss > loss * 1.5:
wandb.alert(
title="Validation loss diverging",
text=f"val_loss={val_loss:.4f} is 50% above train_loss={loss:.4f} at epoch {epoch}",
level=wandb.AlertLevel.WARN,
)
Configure alert destinations (Slack, email) in W&B settings under Alerts.
Model Registry Gates
Use the W&B Model Registry to enforce quality gates before promoting a model:
import wandb
api = wandb.Api()
def promote_if_quality_passes(
    artifact_name: str,
    version: str,
    project: str,
    accuracy_threshold: float = 0.92,
):
    """Add the "production" alias to an artifact if its run's accuracy passes.

    The artifact is fetched as ``{project}/{artifact_name}:{version}`` through
    the module-level ``api`` client, and the accuracy is read from the summary
    of the run that logged it (a missing accuracy counts as 0).

    Args:
        artifact_name: Name of the model artifact.
        version: Artifact version or alias to evaluate, e.g. ``"v3"``.
        project: Project path that contains the artifact.
        accuracy_threshold: Minimum accuracy required for promotion.

    Raises:
        ValueError: If no logging run can be found for the artifact, or if
            the run's accuracy is below ``accuracy_threshold``.
    """
    artifact = api.artifact(f"{project}/{artifact_name}:{version}")
    run = artifact.logged_by()
    if run is None:
        # Without a source run there are no metrics to gate on; refuse to
        # promote instead of failing with an AttributeError on `run.summary`.
        raise ValueError(
            f"Model not promoted: no logging run found for {artifact_name}:{version}"
        )
    accuracy = run.summary.get("accuracy", 0)
    if accuracy >= accuracy_threshold:
        # Link to the "production" alias
        artifact.aliases.append("production")
        artifact.save()
        print(f"Promoted {artifact_name}:{version} to production (accuracy={accuracy:.3f})")
    else:
        raise ValueError(
            f"Model not promoted: accuracy={accuracy:.3f} < {accuracy_threshold}"
        )
promote_if_quality_passes(
"iris-classifier",
"v3",
"my-team/iris-classifier",
)
CI Pipeline
# .github/workflows/ml-quality-gate.yml
name: ML Quality Gate
on:
push:
paths:
- "train.py"
- "model/**"
jobs:
train-and-test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- run: pip install wandb scikit-learn pytest
- name: Train model
env:
WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
run: python train.py
- name: Run quality gate tests
env:
WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
run: pytest tests/test_training_quality.py -v
Sweep Testing (Hyperparameter Validation)
Test that your hyperparameter sweep produces at least one configuration above the quality bar:
import wandb
api = wandb.Api()
def test_sweep_produced_valid_model(sweep_id: str, min_accuracy: float = 0.90):
    """Assert that a sweep's best run reaches ``min_accuracy``.

    The best run is taken from ``sweep.best_run(order="accuracy")`` on the
    sweep fetched through the module-level ``api`` client. A best run whose
    summary lacks "accuracy" counts as 0.

    Args:
        sweep_id: Sweep path passed to ``api.sweep``.
        min_accuracy: Minimum accuracy the best run must reach.

    Raises:
        AssertionError: If the sweep has no best run, or if the best run's
            accuracy is below ``min_accuracy``.
    """
    sweep = api.sweep(sweep_id)
    best_run = sweep.best_run(order="accuracy")
    # An empty sweep yields no best run; fail the gate explicitly instead of
    # hitting an AttributeError on `best_run.summary`.
    assert best_run is not None, f"Sweep {sweep_id} has no runs to evaluate"
    accuracy = best_run.summary.get("accuracy", 0)
    assert accuracy >= min_accuracy, (
        f"Best sweep run only achieved accuracy={accuracy:.3f}, "
        f"minimum required is {min_accuracy}"
    )
test_sweep_produced_valid_model("my-team/iris-classifier/sweep-abc123")
Key Takeaways
- Add explicit assert statements in training scripts to fail runs on quality regressions
- Use the W&B API in pytest to query historical runs and detect metric degradation
- Configure wandb.alert() for runtime notifications when training goes wrong
- Gate model promotion with the Model Registry using alias-based quality checks
- Run W&B-backed pytest tests in CI after every model training job