Model Drift Detection and Monitoring in Production
A model you deploy today will degrade over time. The world changes — user behavior shifts, economic conditions fluctuate, upstream data systems evolve. The model doesn't know. Without drift monitoring, you find out when a user complains or a business metric tanks.
Drift detection is the practice of monitoring model inputs and outputs for changes that signal the model needs retraining or investigation.
Types of Drift
Data drift (covariate shift): The distribution of input features changes. Example: a recommendation model trained on pre-pandemic browsing behavior sees different patterns post-pandemic.
Concept drift: The relationship between inputs and the target variable changes. Example: "good credit score" meant something different in 2008 than it does now.
Label drift (target shift): The distribution of the target variable changes. Example: the base rate of fraud changes due to economic conditions.
Prediction drift: The distribution of model outputs changes. This is the most observable signal — you can monitor it even when ground truth labels are delayed.
Monitoring Input Feature Distributions
Kolmogorov-Smirnov Test (KS Test)
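The two-sample KS test checks whether two samples are drawn from the same continuous distribution. Use it to compare the feature distribution at training time vs. current serving time. With large samples, even negligible shifts produce tiny p-values, so also look at the KS statistic itself (the largest gap between the two empirical CDFs) to judge how big the shift actually is: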
import numpy as np
from scipy import stats
import pandas as pd
def detect_numerical_drift(reference_data: np.ndarray,
                           current_data: np.ndarray,
                           feature_name: str,
                           alpha: float = 0.05) -> dict:
    """
    Detect drift in a numerical feature with the two-sample KS test.

    NaN values are dropped from both samples before testing. Every value in
    the returned dict is a plain Python str/float/bool, so the result can be
    passed straight to ``json.dumps``.

    With large samples the KS test flags even tiny, practically irrelevant
    shifts (p < alpha). Read ``ks_statistic`` as the size of the shift, not
    just the drift flag.

    Args:
        reference_data: Feature values from the reference period (e.g. training).
        current_data: Feature values from the period being checked.
        feature_name: Name copied into the result for reporting.
        alpha: Significance level; drift is flagged when ``p_value < alpha``.

    Returns:
        dict with the feature name, the KS statistic (rounded to 4 dp), the
        p-value (rounded to 6 dp), the drift flag, and the mean and
        (population) standard deviation of each cleaned sample.

    Raises:
        ValueError: if either sample has no non-NaN values.
    """
    reference = np.asarray(reference_data, dtype=float)
    current = np.asarray(current_data, dtype=float)
    # NaNs would distort the empirical CDFs and turn the means/stds into NaN.
    reference = reference[~np.isnan(reference)]
    current = current[~np.isnan(current)]
    if reference.size == 0 or current.size == 0:
        raise ValueError(
            f"{feature_name}: need at least one non-NaN value in each sample"
        )

    ks_stat, p_value = stats.ks_2samp(reference, current)
    return {
        "feature": feature_name,
        "ks_statistic": round(float(ks_stat), 4),
        "p_value": round(float(p_value), 6),
        # Decide on the unrounded p-value. bool() converts numpy's bool_,
        # which json.dumps rejects, into a plain Python bool.
        "drift_detected": bool(p_value < alpha),
        "reference_mean": float(np.mean(reference)),
        "current_mean": float(np.mean(current)),
        "reference_std": float(np.std(reference)),
        "current_std": float(np.std(current)),
    }
# Example usage
reference_df = pd.read_parquet("data/training_features.parquet")
current_df = pd.read_parquet("data/serving_features_last_7d.parquet")
for feature in ["purchase_amount", "session_duration", "page_views"]:
result = detect_numerical_drift(
reference_df[feature].dropna().values,
current_df[feature].dropna().values,
feature_name=feature
)
if result["drift_detected"]:
print(f"DRIFT: {feature} — KS={result['ks_statistic']:.4f}, p={result['p_value']:.6f}")
print(f" Mean: {result['reference_mean']:.2f} → {result['current_mean']:.2f}")Population Stability Index (PSI)
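PSI is widely used in credit scoring and financial ML. It reports the size of a distributional shift on a continuous scale rather than a significance level, which makes it easier to interpret than a p-value. A large sample by itself does not push PSI up the way it pushes a p-value down: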
def calculate_psi(reference: np.ndarray, current: np.ndarray, n_bins: int = 10) -> float:
    """
    Population Stability Index of ``current`` relative to ``reference``.

    Bin edges are the ``n_bins``-quantiles of the reference sample, with the
    outermost edges widened to -inf/+inf so every current value lands in a
    bin. PSI = sum over bins of (cur% - ref%) * ln(cur% / ref%).

    Interpretation:
        PSI < 0.10: No significant drift
        0.10 ≤ PSI < 0.25: Moderate drift — investigate
        PSI ≥ 0.25: Significant drift — likely need retraining

    NaN values are dropped from both inputs before binning. Monitor the
    missing-value rate separately if it matters.

    Args:
        reference: Baseline sample (e.g. training-time values).
        current: Sample to compare against the baseline.
        n_bins: Number of quantile bins built from ``reference``.

    Returns:
        PSI rounded to 4 decimal places, as a Python float.

    Raises:
        ValueError: if ``n_bins`` < 1 or either input has no non-NaN values.
    """
    if n_bins < 1:
        raise ValueError("n_bins must be at least 1")
    reference = np.asarray(reference, dtype=float)
    current = np.asarray(current, dtype=float)
    # NaNs fall into no histogram bin. If kept, they would still count in
    # len(...) (so proportions no longer sum to 1) and would turn the
    # reference quantile edges into NaN.
    reference = reference[~np.isnan(reference)]
    current = current[~np.isnan(current)]
    if reference.size == 0 or current.size == 0:
        raise ValueError("calculate_psi needs at least one non-NaN value in each input")

    # Create bins from reference quantiles; open the outer edges so values
    # outside the reference range are still counted.
    edges = np.quantile(reference, np.linspace(0, 1, n_bins + 1))
    edges[0] = -np.inf
    edges[-1] = np.inf

    # Count samples in each bin
    ref_counts, _ = np.histogram(reference, bins=edges)
    cur_counts, _ = np.histogram(current, bins=edges)

    # Normalize to proportions (small epsilon avoids log(0) and division by
    # zero for bins that are empty on one side)
    epsilon = 1e-8
    ref_pct = ref_counts / reference.size + epsilon
    cur_pct = cur_counts / current.size + epsilon

    # PSI formula
    psi = np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct))
    return round(float(psi), 4)
# Track PSI over time
for date in weekly_dates:
weekly_df = load_serving_data(date)
psi = calculate_psi(reference_df["purchase_amount"].values,
weekly_df["purchase_amount"].values)
print(f"Week {date}: PSI={psi:.4f}", end=" ")
if psi < 0.10:
print("✓ Stable")
elif psi < 0.25:
print("⚠ Investigate")
else:
print("🚨 Retrain needed")Chi-Squared Test for Categorical Features
from scipy.stats import chi2_contingency
def detect_categorical_drift(reference: pd.Series,
                             current: pd.Series,
                             feature_name: str,
                             alpha: float = 0.05) -> dict:
    """
    Detect drift in a categorical feature with a chi-squared test.

    Builds a 2 x k contingency table (reference counts vs. current counts for
    each category seen in either series) and tests it with
    ``scipy.stats.chi2_contingency``. Missing values (NaN/None) are excluded
    from both the table and ``new_categories``.

    Args:
        reference: Category values from the reference period.
        current: Category values from the period being checked.
        feature_name: Name copied into the result for reporting.
        alpha: Significance level; drift is flagged when ``p_value < alpha``.

    Returns:
        dict with the chi-squared statistic (4 dp), p-value (6 dp), drift
        flag, and the categories present in ``current`` but not in
        ``reference`` (sorted by their string form).

    Raises:
        ValueError: if either series has no non-missing values.
    """
    # value_counts() drops missing values by default.
    ref_counts = reference.value_counts()
    cur_counts = current.value_counts()
    if ref_counts.sum() == 0 or cur_counts.sum() == 0:
        raise ValueError(f"{feature_name}: need non-missing values in both series")

    # Align both count vectors on the union of categories; a category absent
    # from one side gets a 0 count on that side.
    table = pd.DataFrame({"reference": ref_counts, "current": cur_counts}).fillna(0)
    # A category with zero count on BOTH sides (e.g. an unused level of a
    # pandas Categorical) yields a zero expected frequency, which
    # chi2_contingency rejects with ValueError.
    table = table[(table > 0).any(axis=1)]

    # Chi-squared test on the 2 x k table (rows: reference, current)
    chi2, p_value, _, _ = chi2_contingency(table.to_numpy().T)

    # Check for new categories (ignoring missing values)
    new_categories = sorted(
        set(current.dropna().unique()) - set(reference.dropna().unique()),
        key=str,
    )
    return {
        "feature": feature_name,
        "chi2_statistic": round(float(chi2), 4),
        "p_value": round(float(p_value), 6),
        "drift_detected": bool(p_value < alpha),
        "new_categories": new_categories,
        "new_category_count": len(new_categories),
    }
# Example
result = detect_categorical_drift(
reference_df["payment_method"],
current_df["payment_method"],
"payment_method"
)
if result["new_categories"]:
print(f"New payment methods appearing: {result['new_categories']}")
# Model was never trained on these — predictions may be unreliableMonitoring Prediction Distributions
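Even without ground-truth labels, you can watch the distribution of the model's outputs for prediction drift. Treat it as a proxy signal. Output shifts usually reflect shifts in the inputs the model sees. Pure concept drift, where the input–target relationship changes while the inputs stay the same, leaves the output distribution unchanged and only becomes visible once labels arrive: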
import matplotlib.pyplot as plt
from collections import deque
class PredictionDriftMonitor:
    """
    Sliding-window PSI monitor for a stream of model predictions.

    Keeps the most recent ``window_size`` predictions. Once the window is
    full, the monitor does the following every ``evaluation_interval`` new
    predictions:

    - compares the window with ``reference_predictions`` using ``calculate_psi``;
    - appends the PSI to ``psi_history``;
    - calls ``alert`` when the PSI is at or above ``alert_threshold``.
    """

    def __init__(self, reference_predictions: np.ndarray, window_size: int = 1000,
                 alert_threshold: float = 0.25, evaluation_interval=None):
        """
        Args:
            reference_predictions: Model outputs from the reference period
                (e.g. predictions on training data).
            window_size: Number of most recent predictions compared per check.
            alert_threshold: PSI at or above this value triggers ``alert``.
            evaluation_interval: Number of new predictions between checks once
                the window is full. ``None`` (default) means ``window_size``,
                i.e. non-overlapping windows. ``1`` re-evaluates after every
                prediction.

        Raises:
            ValueError: if ``window_size`` or ``evaluation_interval`` < 1.
        """
        if window_size < 1:
            raise ValueError("window_size must be at least 1")
        if evaluation_interval is None:
            evaluation_interval = window_size
        if evaluation_interval < 1:
            raise ValueError("evaluation_interval must be at least 1")
        self.reference = np.asarray(reference_predictions)
        self.window_size = window_size
        self.alert_threshold = alert_threshold
        self.evaluation_interval = evaluation_interval
        self.current_window = deque(maxlen=window_size)
        self.psi_history = []
        # Predictions received since the last PSI evaluation.
        self._since_last_eval = 0

    def add_prediction(self, prediction: float):
        """Record one prediction and run a PSI check if one is due."""
        self.current_window.append(prediction)
        self._since_last_eval += 1
        # Evaluating after every prediction would re-bin the entire reference
        # each time. It would also fire one alert per prediction for as long
        # as drift persists. So check only once the window is full, and then
        # at most once per evaluation_interval predictions.
        if (len(self.current_window) < self.window_size
                or self._since_last_eval < self.evaluation_interval):
            return
        self._since_last_eval = 0
        psi = calculate_psi(self.reference, np.array(self.current_window))
        self.psi_history.append(psi)
        if psi >= self.alert_threshold:
            self.alert(psi)

    def alert(self, psi: float):
        """Report a drift alert for the given PSI value."""
        print(f"🚨 Prediction drift alert: PSI={psi:.4f}")
        # Send to monitoring system, PagerDuty, Slack, etc.
# Usage
monitor = PredictionDriftMonitor(reference_predictions=train_predictions)
for batch in serving_batches:
predictions = model.predict(batch.features)
for pred in predictions:
monitor.add_prediction(pred)Evidently: Production-Grade Drift Monitoring
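Evidently AI is an open-source library that automates drift reporting. The example below uses its legacy `Report`/preset API (the 0.4.x-era interface). Later releases reorganized these modules, so pin the Evidently version your code targets: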
from evidently.report import Report
from evidently.metric_preset import DataDriftPreset, TargetDriftPreset
report = Report(metrics=[
DataDriftPreset(drift_share=0.3), # Alert if >30% of features drift
TargetDriftPreset(),
])
reference_data = pd.read_parquet("data/training_features.parquet")
current_data = pd.read_parquet("data/serving_features_last_7d.parquet")
# Include target if you have delayed labels
reference_data["churn"] = training_labels
current_data["churn"] = delayed_labels # May be null for recent rows
report.run(reference_data=reference_data, current_data=current_data)
report.save_html("drift_report.html")
# Programmatic access to results
drift_results = report.as_dict()
drifted_features = [
result["feature_name"]
for result in drift_results["metrics"][0]["result"]["drift_by_columns"].values()
if result["drift_detected"]
]
if len(drifted_features) > 5:
trigger_retraining_pipeline()Setting Up Automated Drift Alerts
# drift_monitor.py — run daily via cron or Airflow
import boto3
import json
from datetime import datetime, timedelta
def daily_drift_check():
reference_df = load_training_features()
yesterday_df = load_serving_features(
start=datetime.now() - timedelta(days=1),
end=datetime.now()
)
drift_report = {}
alerts = []
for feature in NUMERICAL_FEATURES:
psi = calculate_psi(reference_df[feature].values,
yesterday_df[feature].values)
drift_report[feature] = {"psi": psi}
if psi >= 0.25:
alerts.append(f"{feature}: PSI={psi:.4f} (CRITICAL)")
elif psi >= 0.10:
alerts.append(f"{feature}: PSI={psi:.4f} (WARNING)")
# Store results
s3 = boto3.client("s3")
s3.put_object(
Bucket="ml-monitoring",
Key=f"drift/{datetime.now().strftime('%Y/%m/%d')}/report.json",
Body=json.dumps(drift_report)
)
# Alert if needed
if any("CRITICAL" in alert for alert in alerts):
send_pagerduty_alert(f"Model drift detected:\n" + "\n".join(alerts))
elif alerts:
send_slack_notification(f"Drift warnings:\n" + "\n".join(alerts))
if __name__ == "__main__":
daily_drift_check()When to Retrain
Drift detection tells you something changed. Retraining is not always the answer:
| Drift Type | Likely Cause | Response |
|---|---|---|
| Sudden input drift | Upstream data bug | Investigate data pipeline first |
| Gradual input drift | Natural world change | Schedule retraining |
| New categories | Product change | Add to training data, retrain |
| Prediction distribution shift | Input drift reaching the model (concept drift can only be confirmed with labels) | Check input features first; retrain with recent labeled data |
| Performance drop (if labeled) | Any drift type | Retrain immediately |
Building drift detection into your ML pipeline is not optional for production models. The questions are when you build it (before or after the first production incident) and how sophisticated your response automation becomes over time.