Testing ML Models with MLflow: Tracking, Validation, and Registry

MLflow provides the infrastructure for systematic ML model testing: experiment tracking for comparison, model validation for quality gates, and a model registry for staging promotion. This guide covers all three with practical test patterns for CI/ML pipelines.


Why ML Models Need Tests

Software tests assert that code does what it's supposed to do. ML model tests assert that models perform at acceptable levels. Without them:

  • A retrained model silently degrades on edge cases
  • A preprocessing change introduces a bug that shows up only at inference time
  • A model passes accuracy metrics but fails fairness requirements
  • A new model version is deployed that's 3x slower than the previous one

MLflow doesn't write your model tests, but it provides the infrastructure to make them systematic and reproducible.


MLflow Basics: What You're Testing

MLflow has four components:

  1. Tracking — logs parameters, metrics, and artifacts for each training run
  2. Projects — packages code for reproducible runs
  3. Models — standard format for saving and loading models
  4. Registry — version control and staging for production models

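The tests in this guide use three of them: Tracking (logged runs and metrics), Models (loading a candidate for evaluation), and the Registry (stage checks and promotion gates).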


Setting Up MLflow for Testing

pip install mlflow pytest scikit-learn numpy pandas

# Start a local tracking server backed by SQLite (or point at an existing remote tracking server)
mlflow server --backend-store-uri sqlite:///mlflow.db --port 5000

import mlflow
import os

# Point to your tracking server
mlflow.set_tracking_uri(os.environ.get('MLFLOW_TRACKING_URI', 'http://localhost:5000'))

Testing Training Reproducibility

A model that trains differently on identical data is unreliable. Test for determinism:

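import mlflow
import mlflow.sklearn
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split


def train_model(model_seed=42, data_seed=0):
    # Keep the data fixed so only the model's randomness varies between runs
    X, y = make_classification(n_samples=1000, random_state=data_seed)
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.25, random_state=data_seed
    )

    with mlflow.start_run() as run:
        clf = RandomForestClassifier(n_estimators=100, random_state=model_seed)
        clf.fit(X_train, y_train)

        # Training accuracy of a random forest is usually ~1.0, so log held-out accuracy
        val_accuracy = clf.score(X_val, y_val)
        mlflow.log_metric("val_accuracy", val_accuracy)
        mlflow.log_param("n_estimators", 100)
        mlflow.log_param("model_seed", model_seed)
        mlflow.log_param("data_seed", data_seed)
        mlflow.sklearn.log_model(clf, "model")

    # Predicted probabilities are a much stricter fingerprint than one accuracy number
    return run.info.run_id, clf.predict_proba(X_val)


def test_training_is_deterministic():
    _, proba_1 = train_model(model_seed=42)
    _, proba_2 = train_model(model_seed=42)

    # Same seeds → identical predictions, not just identical accuracy
    np.testing.assert_array_equal(proba_1, proba_2, err_msg="Non-deterministic training")


def test_model_seed_actually_changes_the_model():
    _, proba_42 = train_model(model_seed=42)
    _, proba_99 = train_model(model_seed=99)

    # Sanity check that the seed is wired through: bootstrap and feature sampling
    # should make two 100-tree forests disagree on at least some probabilities
    assert not np.array_equal(proba_42, proba_99), \
        "Different seeds produced identical models; is random_state being ignored?"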
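Exact equality is a fair check for a scikit-learn random forest with a fixed random_state. Some training setups (multi-threaded or GPU-based, for example) are not bit-for-bit reproducible, and there comparing logged metrics with a tolerance is the practical check. A minimal sketch, assuming you compare the run.data.metrics dictionaries of two MLflow runs; compare_run_metrics and its default tolerances are hypothetical, not an MLflow API:

```python
import math


def compare_run_metrics(metrics_a, metrics_b, rel_tol=1e-3, abs_tol=1e-6):
    """Return human-readable differences between two metric dicts.

    Hypothetical helper: pass it run.data.metrics from two MLflow runs,
    e.g. client.get_run(run_id).data.metrics. An empty list means the
    runs agree within tolerance.
    """
    problems = []
    # A metric logged by only one run is a reproducibility failure too
    for key in sorted(set(metrics_a) ^ set(metrics_b)):
        problems.append(f"{key}: logged in only one run")
    # Shared metrics must match within the relative/absolute tolerance
    for key in sorted(set(metrics_a) & set(metrics_b)):
        a, b = metrics_a[key], metrics_b[key]
        if not math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol):
            problems.append(f"{key}: {a} vs {b}")
    return problems
```

In a test, problems = compare_run_metrics(client.get_run(run_id_1).data.metrics, client.get_run(run_id_2).data.metrics) followed by assert not problems, problems reports every mismatching metric at once instead of stopping at the first.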

Testing Model Quality Gates

Define minimum acceptable performance thresholds:

import time

import mlflow
import mlflow.pyfunc
import mlflow.sklearn
import numpy as np
import pandas as pd
import pytest
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score, f1_score

ACCURACY_THRESHOLD = 0.85
F1_THRESHOLD = 0.82
LATENCY_THRESHOLD_MS = 10  # P99 single-sample prediction latency


def test_model_meets_accuracy_threshold():
    """Staging model must reach 85%+ accuracy and 0.82+ F1 on held-out data."""
    # Stand-in for a frozen evaluation set: replace with real held-out data
    # drawn from the same distribution the model was trained on
    X, y = make_classification(n_samples=2000, n_features=20, random_state=42)

    # Load the candidate model from the registry
    model = mlflow.sklearn.load_model("models:/MyModel/Staging")

    # Score the loaded model directly. cross_val_score would clone and refit it,
    # measuring a freshly trained copy rather than the weights being promoted.
    predictions = model.predict(X)
    accuracy = accuracy_score(y, predictions)
    f1 = f1_score(y, predictions)

    assert accuracy >= ACCURACY_THRESHOLD, \
        f"Model accuracy {accuracy:.3f} below threshold {ACCURACY_THRESHOLD}"
    assert f1 >= F1_THRESHOLD, \
        f"Model F1 {f1:.3f} below threshold {F1_THRESHOLD}"


def test_model_prediction_latency():
    """P99 single-sample prediction latency must be under 10 ms."""
    model = mlflow.pyfunc.load_model("models:/MyModel/Staging")
    sample = pd.DataFrame([[1.0] * 20], columns=[f"feature_{i}" for i in range(20)])

    # Warm up so one-time costs (lazy initialization, caches) don't skew the tail
    for _ in range(10):
        model.predict(sample)

    latencies = []
    for _ in range(100):
        start = time.perf_counter()
        model.predict(sample)
        latencies.append((time.perf_counter() - start) * 1000)

    p99_latency = np.percentile(latencies, 99)
    assert p99_latency < LATENCY_THRESHOLD_MS, \
        f"P99 latency {p99_latency:.1f}ms exceeds threshold {LATENCY_THRESHOLD_MS}ms"


def test_model_handles_edge_case_inputs():
    """Model must return one prediction per row on boundary inputs."""
    model = mlflow.pyfunc.load_model("models:/MyModel/Staging")

    edge_cases = pd.DataFrame([
        [0.0] * 20,                        # All zeros
        [1e6] * 20,                        # Very large values
        [-1e6] * 20,                       # Very negative values
        [float('nan')] + [1.0] * 19,       # Missing value
    ], columns=[f"feature_{i}" for i in range(20)])

    # Mirror serving-time preprocessing. Filling NaN here means the test checks the
    # model, not null handling; drop fillna(0) if the model must accept NaN itself.
    try:
        predictions = model.predict(edge_cases.fillna(0))
    except Exception as e:
        pytest.fail(f"Model failed on edge case input: {e}")
    assert len(predictions) == 4
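Hard-coding each threshold inside its own test works, but the same thresholds can also live in one declarative table so that minimums (accuracy, F1) and maximums (latency) are checked the same way and a missing metric blocks promotion instead of passing silently. A minimal sketch under that assumption; GATES and check_gates are hypothetical names, and the metric values would come from your evaluation code or from a run's logged metrics:

```python
import operator

# Hypothetical gate table: metric name -> (comparison, threshold)
GATES = {
    "accuracy": (operator.ge, 0.85),
    "f1": (operator.ge, 0.82),
    "p99_latency_ms": (operator.lt, 10.0),
}


def check_gates(metrics, gates=GATES):
    """Return {metric: reason} for every failing gate; an empty dict means all pass."""
    failures = {}
    for name, (compare, threshold) in gates.items():
        if name not in metrics:
            # A metric that was never computed must block promotion
            failures[name] = "metric not computed"
        elif not compare(metrics[name], threshold):
            failures[name] = f"{metrics[name]} fails {compare.__name__} {threshold}"
    return failures
```

A single assert not check_gates(metrics), check_gates(metrics) then reports every failing gate in one message.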

Testing the MLflow Model Registry

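The Registry manages model lifecycle: None → Staging → Production → Archived. Recent MLflow releases (2.9 onward) deprecate stages in favor of model version aliases (for example, models:/MyModel@champion); the stage-based tests below apply to releases and workflows that still use stages.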

from mlflow.tracking import MlflowClient

client = MlflowClient()


def test_staging_model_exists_before_promotion():
    """There must be a Staging model before promoting to Production."""
    versions = client.search_model_versions("name='MyModel'")
    staging_versions = [v for v in versions if v.current_stage == "Staging"]
    
    assert len(staging_versions) > 0, \
        "No model in Staging stage — cannot promote to Production"


def test_staging_model_has_required_tags():
    """Staging models must have accuracy and dataset tags set."""
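    # search_model_versions filters on name, run_id, source_path, and tags, not stage
    versions = client.search_model_versions("name='MyModel'")
    staging = [v for v in versions if v.current_stage == "Staging"]
    assert staging, "No model in Staging stage"
    latest_staging = max(staging, key=lambda v: int(v.version))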
    
    required_tags = ["accuracy", "training_dataset_version", "feature_schema_version"]
    
    for tag in required_tags:
        assert tag in latest_staging.tags, \
            f"Required tag '{tag}' missing from Staging model v{latest_staging.version}"


def test_model_version_has_complete_run_data():
    """Every registered model must link to a complete training run."""
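    # Stage is not a supported search filter, so filter the results in Python
    versions = [
        v for v in client.search_model_versions("name='MyModel'")
        if v.current_stage == "Staging"
    ]
    assert versions, "No model in Staging stage"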
    
    for version in versions:
        run_id = version.run_id
        run = client.get_run(run_id)
        
        # Run must have logged the required metrics
        assert "val_accuracy" in run.data.metrics, \
            f"Model v{version.version} run missing val_accuracy metric"
        assert "val_f1" in run.data.metrics, \
            f"Model v{version.version} run missing val_f1 metric"
        
        # Run must have logged the model artifact
        artifacts = [a.path for a in client.list_artifacts(run_id)]
        assert any("model" in a for a in artifacts), \
            f"Model v{version.version} run has no model artifact"
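Checking that a tag exists does not catch a stale tag, such as an accuracy tag copied over from a previous version. A minimal sketch that cross-checks a version's tag against the metric its training run actually logged; it assumes the accuracy tag stores the validation accuracy as a number string and that the run logged val_accuracy (the tag and metric names used in the tests above). tag_matches_metric and its tolerance are hypothetical:

```python
import math


def tag_matches_metric(tags, metrics, tag_key="accuracy",
                       metric_key="val_accuracy", abs_tol=1e-4):
    """Return None if the tag agrees with the logged metric, else a reason string.

    `tags` is a model version's tags dict (MLflow tag values are strings);
    `metrics` is the training run's run.data.metrics dict.
    """
    if tag_key not in tags:
        return f"tag '{tag_key}' missing"
    if metric_key not in metrics:
        return f"metric '{metric_key}' missing"
    try:
        tagged = float(tags[tag_key])
    except ValueError:
        return f"tag '{tag_key}' is not numeric: {tags[tag_key]!r}"
    logged = metrics[metric_key]
    # Absolute tolerance only: accuracy-style metrics live on a fixed 0..1 scale
    if not math.isclose(tagged, logged, rel_tol=0.0, abs_tol=abs_tol):
        return f"tag says {tagged}, run logged {logged}"
    return None
```

Inside test_model_version_has_complete_run_data, asserting tag_matches_metric(version.tags, run.data.metrics) is None for each Staging version would fail the build on a mismatch.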

Automated Promotion Pipeline

Gate Production promotion on automated test results:

def promote_to_production_if_tests_pass(model_name: str, staging_version: str):
    """
    Run all quality gates and promote only if they all pass.
    Use this in your CI/ML pipeline.

    The gates are the test functions defined above. They read the latest Staging
    version of "MyModel", so staging_version should be that version.
    """
    client = MlflowClient()

    gates = {
        "accuracy": test_model_meets_accuracy_threshold,
        "latency": test_model_prediction_latency,
        "metadata": test_staging_model_has_required_tags,
    }

    results = {}
    for gate_name, gate in gates.items():
        try:
            gate()
            results[gate_name] = "PASS"
        except AssertionError as e:
            # Any other exception propagates and also blocks promotion
            results[gate_name] = f"FAIL: {e}"

    all_passed = all(v == "PASS" for v in results.values())
    
    if all_passed:
        client.transition_model_version_stage(
            name=model_name,
            version=staging_version,
            stage="Production",
            archive_existing_versions=True
        )
        print(f"✓ Model {model_name} v{staging_version} promoted to Production")
    else:
        failed = {k: v for k, v in results.items() if v != "PASS"}
        raise RuntimeError(f"Promotion blocked. Failed gates: {failed}")
    
    return results
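Absolute thresholds alone let a retrained model slip below the version it replaces while still clearing the bar. A common extra gate compares the candidate with the current Production version. A minimal sketch, assuming both versions' training runs logged val_accuracy and val_f1 (the metrics the registry tests require); no_regression and the 0.01 allowance are hypothetical choices:

```python
def no_regression(candidate, production,
                  keys=("val_accuracy", "val_f1"), max_drop=0.01):
    """Return a list of regressions of the candidate relative to Production.

    `candidate` and `production` are metric dicts, e.g.
    client.get_run(version.run_id).data.metrics. Pass production=None when
    nothing is in Production yet (first deployment) to skip the comparison.
    """
    if production is None:
        return []
    regressions = []
    for key in keys:
        if key not in candidate or key not in production:
            regressions.append(f"{key}: missing on one side")
            continue
        # Allow small noise-level drops; flag anything larger
        drop = production[key] - candidate[key]
        if drop > max_drop:
            regressions.append(
                f"{key}: {candidate[key]:.3f} vs Production {production[key]:.3f}"
            )
    return regressions
```

The Production version can be found the same way the registry tests find Staging: filter search_model_versions("name='MyModel'") on current_stage == "Production". Wrapped in an assertion, the check can run as one more gate before transition_model_version_stage is called.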

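CI/CD Integration

A GitHub Actions workflow (saved under .github/workflows/) that trains the model, registers it to Staging, runs the quality gates, and promotes only from main. Steps run only if the earlier ones succeeded, so a failing gate stops the promotion step: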

name: ML Model Quality Gates
on:
  push:
    paths:
      - 'models/**'
      - 'training/**'

jobs:
  model-quality:
    runs-on: ubuntu-latest
    env:
      MLFLOW_TRACKING_URI: ${{ secrets.MLFLOW_TRACKING_URI }}
      MLFLOW_TRACKING_TOKEN: ${{ secrets.MLFLOW_TRACKING_TOKEN }}
    
    steps:
      - uses: actions/checkout@v4
      
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      
      - name: Install dependencies
        run: pip install mlflow scikit-learn pytest numpy pandas
      
      - name: Train model and register to Staging
        run: python training/train.py --register-to-staging
      
      - name: Run model quality gates
        run: pytest tests/model/ -v --tb=short
      
      - name: Promote to Production (if tests pass)
        if: github.ref == 'refs/heads/main'
        run: python scripts/promote_model.py
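The workflow calls scripts/promote_model.py without arguments, so the script has to find the version to promote on its own; its contents are not shown here. A minimal sketch, assuming promote_to_production_if_tests_pass lives in a hypothetical importable module named promotion and that the candidate is the highest-numbered Staging version of MyModel:

```python
"""Hypothetical sketch of scripts/promote_model.py."""
import sys


def latest_version_in_stage(versions, stage="Staging"):
    """Return the highest version in `stage`, or None if there is none."""
    in_stage = [v for v in versions if v.current_stage == stage]
    if not in_stage:
        return None
    # Compare numerically; int() also copes with versions stored as strings
    return max(in_stage, key=lambda v: int(v.version)).version


def main(model_name="MyModel"):
    # Imported here so the helper above can be unit-tested without MLflow
    from mlflow.tracking import MlflowClient
    from promotion import promote_to_production_if_tests_pass  # hypothetical module

    client = MlflowClient()  # uses MLFLOW_TRACKING_URI from the environment
    versions = client.search_model_versions(f"name='{model_name}'")
    version = latest_version_in_stage(versions)
    if version is None:
        print(f"No Staging version of {model_name} to promote", file=sys.stderr)
        return 1
    try:
        promote_to_production_if_tests_pass(model_name, version)
    except RuntimeError as exc:  # raised when any gate fails
        print(exc, file=sys.stderr)
        return 1
    return 0
```

End the file with if __name__ == "__main__": sys.exit(main()) so the workflow step gets a non-zero exit code, and a red build, when promotion is blocked.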

Monitoring Models in Production

MLflow tracks model versions but not production behavior. Add real-time monitoring:

# Monitor your model serving endpoint
helpmetest health ml-model-api 5m

# For batch prediction jobs, use health checks to verify completion
helpmetest health ml-batch-predictions 1h

Combined with MLflow experiment tracking, this gives you a complete picture: historical performance in MLflow, current serving health in HelpMeTest.


Summary

MLflow testing covers the full model lifecycle:

  • Tracking — log all training runs for comparison; test that metrics are being logged
  • Validation — enforce accuracy, latency, and edge case thresholds before promotion
  • Registry — gate stage transitions on automated quality checks
  • CI integration — run quality gates on every training run, promote only on green

The key shift is treating model performance requirements as executable tests rather than documentation. If it's not in a test, it will eventually drift.
