Model Drift Detection and Monitoring in Production

A model you deploy today will degrade over time. The world changes — user behavior shifts, economic conditions fluctuate, upstream data systems evolve. The model doesn't know. Without drift monitoring, you find out when a user complains or a business metric tanks.

Drift detection is the practice of monitoring model inputs and outputs for changes that signal the model needs retraining or investigation.

Types of Drift

Data drift (covariate shift): The distribution of input features changes. Example: a recommendation model trained on pre-pandemic browsing behavior sees different patterns post-pandemic.

Concept drift: The relationship between inputs and the target variable changes. Example: "good credit score" meant something different in 2008 than it does now.

Label drift (target shift): The distribution of the target variable changes. Example: the base rate of fraud changes due to economic conditions.

Prediction drift: The distribution of model outputs changes. This is the most observable signal — you can monitor it even when ground truth labels are delayed.
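
To make the difference between data drift and concept drift concrete, here is a small synthetic sketch. The data and the threshold "model" are made up for illustration. The model predicts the positive class whenever x > 0. Under covariate shift the inputs move but that rule still holds. Under concept drift the inputs look exactly as before, but the rule reverses.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 10_000

def model(x: np.ndarray) -> np.ndarray:
    """Toy fixed model (illustrative): predict the positive class when x > 0."""
    return (x > 0).astype(int)

# Reference period: x ~ N(0, 1), true label is 1 when x > 0
x_ref = rng.normal(0.0, 1.0, n)
y_ref = (x_ref > 0).astype(int)

# Data (covariate) drift: P(x) moves, P(y | x) is unchanged
x_cov = rng.normal(1.0, 1.0, n)
y_cov = (x_cov > 0).astype(int)

# Concept drift: P(x) is unchanged, P(y | x) flips
x_con = rng.normal(0.0, 1.0, n)
y_con = (x_con <= 0).astype(int)

for name, x, y in [("reference", x_ref, y_ref),
                   ("covariate drift", x_cov, y_cov),
                   ("concept drift", x_con, y_con)]:
    preds = model(x)
    print(f"{name:>16}: mean(x)={x.mean():+.2f}  "
          f"positive-prediction rate={preds.mean():.2f}  "
          f"accuracy={(preds == y).mean():.2f}")
```

Input monitoring and prediction monitoring both flag the covariate-shift case, where mean(x) and the positive-prediction rate move, and the label base rate shifts with them. In the concept-drift case both stay put while accuracy collapses, and only labeled data reveals it.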

Monitoring Input Feature Distributions

Kolmogorov-Smirnov Test (KS Test)

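The two-sample KS test checks whether two samples plausibly come from the same continuous distribution. Its statistic D is the largest vertical gap between the two empirical CDFs (0 for identical samples, 1 for samples that do not overlap at all). The p-value says how surprising a gap that large would be if nothing had changed. With large serving volumes, even a negligible shift yields a tiny p-value, so read D alongside it. Use the test to compare a feature's distribution at training time with its distribution at serving time: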

import numpy as np
from scipy import stats
import pandas as pd

def detect_numerical_drift(reference_data: np.ndarray, 
                           current_data: np.ndarray,
                           feature_name: str,
                           alpha: float = 0.05) -> dict:
    """
    Detect drift in a numerical feature using KS test.
    
    Returns drift detection result with statistic and p-value.
    """
    ks_stat, p_value = stats.ks_2samp(reference_data, current_data)
    
    drift_detected = p_value < alpha
    
    return {
        "feature": feature_name,
        "ks_statistic": round(ks_stat, 4),
        "p_value": round(p_value, 6),
        "drift_detected": drift_detected,
        "reference_mean": np.mean(reference_data),
        "current_mean": np.mean(current_data),
        "reference_std": np.std(reference_data),
        "current_std": np.std(current_data)
    }

# Example usage
reference_df = pd.read_parquet("data/training_features.parquet")
current_df = pd.read_parquet("data/serving_features_last_7d.parquet")

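# Testing several features at alpha=0.05 each inflates the false-alarm rate;
# dividing alpha by the number of features (Bonferroni) is a simple guard
features = ["purchase_amount", "session_duration", "page_views"]

for feature in features:
    result = detect_numerical_drift(
        reference_df[feature].dropna().values,
        current_df[feature].dropna().values,
        feature_name=feature,
        alpha=0.05 / len(features)
    )
    
    if result["drift_detected"]:
        print(f"DRIFT: {feature} (KS={result['ks_statistic']:.4f}, p={result['p_value']:.6f})")
        print(f"  Mean: {result['reference_mean']:.2f} → {result['current_mean']:.2f}")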
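
The KS statistic is easy to compute by hand, which helps when interpreting its values. Evaluate both empirical CDFs at every observed point and take the largest absolute gap. The sketch below uses synthetic data and `ks_statistic`, a helper written only for illustration. It first checks the hand-rolled value against `scipy.stats.ks_2samp`, then shows the large-sample caveat. With 200,000 rows per sample, a mean shift of 0.03 standard deviations is flagged as significant even though D stays on the order of 0.01.

```python
import numpy as np
from scipy import stats

def ks_statistic(a: np.ndarray, b: np.ndarray) -> float:
    """Largest absolute gap between the empirical CDFs of a and b."""
    a, b = np.sort(a), np.sort(b)
    grid = np.concatenate([a, b])  # the ECDFs only jump at observed values
    cdf_a = np.searchsorted(a, grid, side="right") / len(a)
    cdf_b = np.searchsorted(b, grid, side="right") / len(b)
    return float(np.max(np.abs(cdf_a - cdf_b)))

rng = np.random.default_rng(0)

# Small samples: the hand-rolled statistic matches SciPy's
a = rng.normal(0.0, 1.0, 500)
b = rng.normal(0.3, 1.0, 700)
manual = ks_statistic(a, b)
scipy_stat = stats.ks_2samp(a, b).statistic
print(f"manual D={manual:.4f}  scipy D={scipy_stat:.4f}")

# Large samples: a tiny shift becomes "statistically significant"
ref = rng.normal(0.0, 1.0, 200_000)
cur = rng.normal(0.03, 1.0, 200_000)
large = stats.ks_2samp(ref, cur)
print(f"D={large.statistic:.4f}  p={large.pvalue:.2e}")
```

One common response is to require both a small p-value and a minimum effect size, such as a floor on D or a PSI threshold, before raising an alert.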

Population Stability Index (PSI)

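PSI is widely used in credit scoring and financial ML. Rather than a p-value, it produces a single effect-size number that grows with the amount of distributional shift. It is far less sensitive to sample size than a p-value, which makes it easier to interpret, to threshold, and to track over time: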

def calculate_psi(reference: np.ndarray, current: np.ndarray, n_bins: int = 10) -> float:
    """
    Population Stability Index.
    
    Interpretation:
    PSI < 0.10: No significant drift
    0.10 ≤ PSI < 0.25: Moderate drift — investigate
    PSI ≥ 0.25: Significant drift — likely need retraining
    """
    # Create bins based on reference distribution
    quantiles = np.quantile(reference, np.linspace(0, 1, n_bins + 1))
    quantiles[0] = -np.inf
    quantiles[-1] = np.inf
    
    # Count samples in each bin
    ref_counts, _ = np.histogram(reference, bins=quantiles)
    cur_counts, _ = np.histogram(current, bins=quantiles)
    
    # Normalize to proportions (add small epsilon to avoid log(0))
    epsilon = 1e-8
    ref_pct = (ref_counts / len(reference)) + epsilon
    cur_pct = (cur_counts / len(current)) + epsilon
    
    # PSI formula
    psi = np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct))
    return round(psi, 4)
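
A two-bin example makes the formula concrete. Suppose a feature splits 50/50 across two bins at training time and 60/40 in serving. Then PSI = (0.6 − 0.5)·ln(0.6/0.5) + (0.4 − 0.5)·ln(0.4/0.5) = 0.1·ln 1.5 ≈ 0.041, which is well inside the "stable" band. The same number falls out of the identity PSI = KL(current ‖ reference) + KL(reference ‖ current). In other words, PSI is the symmetrized Kullback-Leibler divergence between the binned distributions. The sketch below checks both on fixed proportions. The helper names are illustrative.

```python
import numpy as np

def psi_from_proportions(ref_pct: np.ndarray, cur_pct: np.ndarray) -> float:
    """PSI for already-binned proportions (same formula as calculate_psi)."""
    return float(np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct)))

def kl_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """KL(p || q) for discrete distributions with no zero entries."""
    return float(np.sum(p * np.log(p / q)))

ref_pct = np.array([0.5, 0.5])
cur_pct = np.array([0.6, 0.4])

psi = psi_from_proportions(ref_pct, cur_pct)
sym_kl = kl_divergence(cur_pct, ref_pct) + kl_divergence(ref_pct, cur_pct)

print(f"PSI = {psi:.4f}")        # PSI = 0.0405  (= 0.1 * ln 1.5)
print(f"KL sum = {sym_kl:.4f}")  # KL sum = 0.0405
```

The KL view also explains why an empty bin is so costly. As a proportion approaches zero, its log term grows without bound, which is why calculate_psi adds a small epsilon.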

# Track PSI over time (weekly_dates and load_serving_data are placeholders
# for your own serving-data access)
for date in weekly_dates:
    weekly_df = load_serving_data(date)
    psi = calculate_psi(reference_df["purchase_amount"].values,
                        weekly_df["purchase_amount"].values)
    
    print(f"Week {date}: PSI={psi:.4f}", end=" ")
    if psi < 0.10:
        print("✓ Stable")
    elif psi < 0.25:
        print("⚠ Investigate")
    else:
        print("🚨 Retrain needed")

Chi-Squared Test for Categorical Features

from scipy.stats import chi2_contingency

def detect_categorical_drift(reference: pd.Series, 
                              current: pd.Series,
                              feature_name: str,
                              alpha: float = 0.05) -> dict:
    """Detect drift in categorical features using chi-squared test."""
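    # Drop missing values: value_counts() ignores NaN, so a NaN "category"
    # would produce an all-zero column and make chi2_contingency raise
    reference = reference.dropna()
    current = current.dropna()
    
    # Get all categories from both datasets, in one fixed order
    all_categories = sorted(set(reference.unique()) | set(current.unique()), key=str)
    
    # Count occurrences in each dataset
    ref_counts = reference.value_counts()
    cur_counts = current.value_counts()
    
    # Align to same categories (fill missing with 0)
    ref_freq = np.array([ref_counts.get(cat, 0) for cat in all_categories])
    cur_freq = np.array([cur_counts.get(cat, 0) for cat in all_categories])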
    
    # Chi-squared test
    contingency_table = np.array([ref_freq, cur_freq])
    chi2, p_value, dof, _ = chi2_contingency(contingency_table)
    
    # Check for new categories
    new_categories = set(current.unique()) - set(reference.unique())
    
    return {
        "feature": feature_name,
        "chi2_statistic": round(chi2, 4),
        "p_value": round(p_value, 6),
        "drift_detected": p_value < alpha,
        "new_categories": list(new_categories),
        "new_category_count": len(new_categories)
    }

# Example
result = detect_categorical_drift(
    reference_df["payment_method"],
    current_df["payment_method"],
    "payment_method"
)

if result["new_categories"]:
    print(f"New payment methods appearing: {result['new_categories']}")
    # Model was never trained on these — predictions may be unreliable
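
Under the hood, chi2_contingency compares the observed 2×K table (one reference row, one current row) with the counts you would expect if both periods shared a single category mix. Each expected cell is row total × column total / grand total. The sketch below works through a binary feature with made-up counts. One detail is worth knowing: for a 2×2 table, meaning a feature with exactly two categories, SciPy applies Yates' continuity correction by default, which lowers the statistic. Passing correction=False gives the plain Pearson value.

```python
import numpy as np
from scipy.stats import chi2_contingency

# Rows: reference, current. Columns: two categories (made-up counts)
table = np.array([[50, 50],
                  [70, 30]])

# Expected counts if both periods had the same category mix
row_totals = table.sum(axis=1, keepdims=True)
col_totals = table.sum(axis=0, keepdims=True)
expected = row_totals * col_totals / table.sum()   # [[60, 40], [60, 40]]

chi2_manual = float(((table - expected) ** 2 / expected).sum())

chi2_plain, p_plain, dof, _ = chi2_contingency(table, correction=False)
chi2_yates, p_yates, _, _ = chi2_contingency(table)  # default: correction=True

print(f"manual={chi2_manual:.4f}  scipy={chi2_plain:.4f}  "
      f"scipy+Yates={chi2_yates:.4f}  dof={dof}")
# manual=8.3333  scipy=8.3333  scipy+Yates=7.5208  dof=1
```

detect_categorical_drift above uses SciPy's default, so for binary features it reports the corrected statistic. For features with three or more categories the table has more than one degree of freedom, and the correction does not apply.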

Monitoring Prediction Distributions

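Even without ground truth labels, you can watch how the model's outputs change over time. A shift in the prediction distribution usually means the inputs have drifted in ways that matter to the model. It cannot reveal pure concept drift, where the relationship between inputs and target changes but the inputs themselves do not. In that case a fixed model produces the same predictions as before, and only labeled performance data exposes the problem.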

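import numpy as np
from collections import deque

class PredictionDriftMonitor:
    """Compares each full window of live predictions against a reference set
    using calculate_psi() from above. Windows are tumbling (cleared after each
    check), so an ongoing drift raises one alert per window instead of one
    alert, and one full PSI recomputation, per incoming prediction."""

    def __init__(self, reference_predictions: np.ndarray, window_size: int = 1000,
                 psi_threshold: float = 0.25):
        self.reference = np.asarray(reference_predictions)
        self.window_size = window_size
        self.psi_threshold = psi_threshold
        self.current_window = deque(maxlen=window_size)
        self.psi_history = []
    
    def add_prediction(self, prediction: float):
        self.current_window.append(prediction)
        
        if len(self.current_window) == self.window_size:
            psi = calculate_psi(
                self.reference,
                np.array(self.current_window)
            )
            self.psi_history.append(psi)
            self.current_window.clear()  # start a fresh window
            
            if psi >= self.psi_threshold:
                self.alert(psi)
    
    def alert(self, psi: float):
        print(f"🚨 Prediction drift alert: PSI={psi:.4f}")
        # Send to monitoring system, PagerDuty, Slack, etc.

# Usage: train_predictions, serving_batches and model are placeholders.
# Predictions on a held-out validation set usually make a more representative
# reference than predictions on the data the model was fit to.
monitor = PredictionDriftMonitor(reference_predictions=train_predictions)

for batch in serving_batches:
    predictions = model.predict(batch.features)
    for pred in predictions:
        monitor.add_prediction(pred)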

Evidently: Production-Grade Drift Monitoring

Evidently AI is an open-source library that automates drift reporting:

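# Written against the Evidently 0.4.x API; newer releases reorganized imports
from evidently import ColumnMapping
from evidently.report import Report
from evidently.metric_preset import DataDriftPreset, TargetDriftPreset

report = Report(metrics=[
    DataDriftPreset(drift_share=0.3),   # Alert if >30% of features drift
    TargetDriftPreset(),
])

reference_data = pd.read_parquet("data/training_features.parquet")
current_data = pd.read_parquet("data/serving_features_last_7d.parquet")

# Include target if you have delayed labels (training_labels and delayed_labels
# are placeholders, aligned row-for-row with the feature frames)
reference_data["churn"] = training_labels
current_data["churn"] = delayed_labels  # May be null for recent rows

# Evidently only treats "churn" as the target if the column mapping says so
column_mapping = ColumnMapping(target="churn")

report.run(reference_data=reference_data, current_data=current_data,
           column_mapping=column_mapping)
report.save_html("drift_report.html")

# Programmatic access to results: per-column drift lives in the DataDriftTable
# metric (the first preset metric is the dataset-level summary)
drift_results = report.as_dict()
drift_table = next(m["result"] for m in drift_results["metrics"]
                   if m["metric"] == "DataDriftTable")
drifted_features = [
    column["column_name"]
    for column in drift_table["drift_by_columns"].values()
    if column["drift_detected"]
]

if len(drifted_features) > 5:
    trigger_retraining_pipeline()  # placeholder for your orchestration hook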

Setting Up Automated Drift Alerts

# drift_monitor.py — run daily via cron or Airflow
import boto3
import json
from datetime import datetime, timedelta

# load_training_features, load_serving_features, NUMERICAL_FEATURES,
# send_pagerduty_alert and send_slack_notification are placeholders for
# your own data access and alerting code; calculate_psi is defined above.
def daily_drift_check():
    now = datetime.now()  # one timestamp for both the data window and the S3 key
    reference_df = load_training_features()
    yesterday_df = load_serving_features(
        start=now - timedelta(days=1),
        end=now
    )
    
    drift_report = {}
    alerts = []
    
    for feature in NUMERICAL_FEATURES:
        # Drop missing values first; NaNs would corrupt the quantile bin edges
        psi = calculate_psi(reference_df[feature].dropna().values,
                            yesterday_df[feature].dropna().values)
        drift_report[feature] = {"psi": float(psi)}
        
        if psi >= 0.25:
            alerts.append(f"{feature}: PSI={psi:.4f} (CRITICAL)")
        elif psi >= 0.10:
            alerts.append(f"{feature}: PSI={psi:.4f} (WARNING)")
    
    # Store results
    s3 = boto3.client("s3")
    s3.put_object(
        Bucket="ml-monitoring",
        Key=f"drift/{now.strftime('%Y/%m/%d')}/report.json",
        Body=json.dumps(drift_report)
    )
    
    # Alert if needed
    if any("CRITICAL" in alert for alert in alerts):
        send_pagerduty_alert("Model drift detected:\n" + "\n".join(alerts))
    elif alerts:
        send_slack_notification("Drift warnings:\n" + "\n".join(alerts))

if __name__ == "__main__":
    daily_drift_check()

When to Retrain

Drift detection tells you something changed. Retraining is not always the answer:

Drift type | Likely cause | Response
--- | --- | ---
Sudden input drift | Upstream data bug | Investigate data pipeline first
Gradual input drift | Natural world change | Schedule retraining
New categories | Product change | Add to training data, retrain
Prediction distribution shift | Input drift reaching the model's outputs | Check feature drift, then retrain with recent labeled data
Performance drop (if labeled) | Any drift type, including concept drift | Retrain immediately

Building drift detection into your ML pipeline is not optional for production models. The questions are when you build it (before or after the first production incident) and how sophisticated your response automation becomes over time.
