TensorFlow Model Validation with TFMA: Testing ML Models in Production Pipelines
TensorFlow Model Analysis (TFMA) provides slice-based model evaluation — testing not just overall accuracy but performance across demographic and feature subgroups. This guide covers TFMA setup, fairness testing, model blessing, and integration into TFX pipelines.
What Is TensorFlow Model Analysis?
TFMA is part of the TensorFlow Extended (TFX) ecosystem, designed to answer questions that overall accuracy metrics hide:
- Does the model perform equally well across user age groups?
- Does accuracy drop for a specific geographic region?
- How does model quality compare to the currently deployed version?
TFMA evaluates models on large datasets in a distributed fashion, computing metrics on arbitrary "slices" — subsets of data defined by feature values.
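To make the idea of a slice concrete, here is a minimal pure-Python sketch of per-slice accuracy. It is an illustration only, not how TFMA is implemented: the function name `accuracy_by_slice` and the sample records are hypothetical. Slice keys are written as tuples of (column, value) pairs, with the empty tuple standing for the whole dataset.

```python
from collections import defaultdict


def accuracy_by_slice(examples, feature_keys):
    """Return accuracy per slice, keyed by a tuple of (column, value) pairs.

    An empty feature_keys list yields one overall slice keyed by ().
    """
    correct = defaultdict(int)
    total = defaultdict(int)
    for ex in examples:
        slice_key = tuple((k, ex[k]) for k in feature_keys)
        total[slice_key] += 1
        correct[slice_key] += int(ex["label"] == ex["prediction"])
    return {key: correct[key] / total[key] for key in total}


# Hypothetical evaluation records
examples = [
    {"country": "US", "label": 1, "prediction": 1},
    {"country": "US", "label": 0, "prediction": 0},
    {"country": "DE", "label": 1, "prediction": 0},
    {"country": "DE", "label": 0, "prediction": 0},
]

overall = accuracy_by_slice(examples, [])
by_country = accuracy_by_slice(examples, ["country"])
print(overall)     # {(): 0.75}
print(by_country)  # {(('country', 'US'),): 1.0, (('country', 'DE'),): 0.5}
```

The overall accuracy of 0.75 hides that the DE slice sits at 0.5, which is the kind of gap slice-based evaluation is meant to surface.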
Installation
pip install tensorflow-model-analysis
pip install apache-beam[gcp] # For distributed evaluation
pip install tensorflow tfx
Basic TFMA Evaluation
Model Setup
TFMA works with models exported in the SavedModel format:
import tensorflow as tf
import tensorflow_model_analysis as tfma
# Optional: load the model directly to inspect it or run ad hoc predictions
model = tf.keras.models.load_model('saved_model/my_model')
# Wrap the SavedModel path so TFMA can load the model during evaluation
eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path='path/to/saved_model',
    tags=[tf.saved_model.SERVING]
)
Defining Evaluation Config
eval_config = tfma.EvalConfig(
    model_specs=[
        tfma.ModelSpec(
            label_key='income_bracket',  # Your target column
            prediction_key='output_0'  # Model output tensor name
        )
    ],
    slicing_specs=[
        # Overall metrics (no slicing)
        tfma.SlicingSpec(),
        # Slice by age group
        tfma.SlicingSpec(feature_keys=['age_group']),
        # Slice by gender
        tfma.SlicingSpec(feature_keys=['gender']),
        # Cross-slice: age × gender
        tfma.SlicingSpec(feature_keys=['age_group', 'gender']),
        # Slice by specific feature value
        tfma.SlicingSpec(
            feature_values={'country': 'US'}
        )
    ],
    metrics_specs=[
        tfma.MetricsSpec(metrics=[
            tfma.MetricConfig(class_name='BinaryAccuracy'),
            tfma.MetricConfig(class_name='AUC'),
            tfma.MetricConfig(class_name='Precision'),
            tfma.MetricConfig(class_name='Recall'),
            tfma.MetricConfig(class_name='ExampleCount')
        ])
    ]
)
Running Evaluation
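import apache_beam as beam
# Run evaluation on test data. Results are written to output_path and can be
# read back afterwards with tfma.load_eval_result().
with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | 'ReadData' >> beam.io.ReadFromTFRecord(
            'gs://your-bucket/test-data/test-*.tfrecords'
        )
        # ExtractEvaluateAndWriteResults is TFMA's Beam transform;
        # tfma.run_model_analysis is a standalone function, not a PTransform.
        | 'EvaluateModel' >> tfma.ExtractEvaluateAndWriteResults(
            eval_shared_model=eval_shared_model,
            eval_config=eval_config,
            output_path='./eval_output'
        )
    )
Testing Slice-Based Performance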
After evaluation, write assertions against the slice metrics:
import tensorflow_model_analysis as tfma

def load_eval_results(eval_path: str) -> tfma.EvalResult:
    return tfma.load_eval_result(eval_path)

def test_overall_accuracy_meets_threshold():
    """Overall model accuracy must be >= 0.82."""
    result = load_eval_results('./eval_output')
    # get_metrics_for_all_slices() returns a dict keyed by slice; the overall
    # (unsliced) result uses the empty tuple as its key
    overall_metrics = result.get_metrics_for_all_slices()
    overall = next(
        m for s, m in overall_metrics.items()
        if not s  # Empty slice key = overall
    )
    accuracy = overall['binary_accuracy']['doubleValue']
    assert accuracy >= 0.82, f"Overall accuracy {accuracy:.3f} below threshold 0.82"

def test_accuracy_consistent_across_gender_slices():
    """Accuracy gap between gender groups must be < 5 percentage points."""
    result = load_eval_results('./eval_output')
    gender_metrics = {}
    for slice_key, metrics in result.get_metrics_for_all_slices().items():
        # A slice key is a tuple of (column, value) pairs, e.g. (('gender', 'female'),).
        # Keep only single-column gender slices, not the age_group x gender crosses.
        if len(slice_key) == 1 and slice_key[0][0] == 'gender':
            value = slice_key[0][1]
            gender_metrics[value] = metrics['binary_accuracy']['doubleValue']
    if len(gender_metrics) >= 2:
        accuracies = list(gender_metrics.values())
        gap = max(accuracies) - min(accuracies)
        assert gap < 0.05, (
            f"Accuracy gap across gender slices is {gap:.3f}, exceeding the "
            f"0.05 fairness threshold. Slices: {gender_metrics}"
        )

def test_minimum_examples_per_slice():
    """Every evaluated slice must have at least 100 examples for reliable metrics."""
    result = load_eval_results('./eval_output')
    for slice_key, metrics in result.get_metrics_for_all_slices().items():
        if slice_key:  # Skip overall
            example_count = metrics.get('example_count', {}).get('doubleValue', 0)
            assert example_count >= 100, (
                f"Slice {slice_key} has only {example_count} examples; metrics unreliable"
            )
Fairness Indicators
TensorFlow Fairness Indicators extends TFMA with equality-of-opportunity metrics:
pip install tensorflow_model_remediation fairness-indicators
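Equality of opportunity compares the true positive rate (TPR) across groups. As a library-independent illustration, the sketch below computes TPR per group and the largest gap between groups. The function names and sample records are hypothetical, and this is not meant to reproduce how Fairness Indicators computes its metrics.

```python
def tpr_by_group(records, group_key):
    """Compute TPR = TP / (TP + FN) for each value of group_key."""
    tp, fn = {}, {}
    for r in records:
        if r["label"] != 1:
            continue  # TPR only considers actual positives
        group = r[group_key]
        tp.setdefault(group, 0)
        fn.setdefault(group, 0)
        if r["prediction"] == 1:
            tp[group] += 1
        else:
            fn[group] += 1
    return {g: tp[g] / (tp[g] + fn[g]) for g in tp}


def max_gap(values_by_group):
    """Largest difference between any two groups' values."""
    values = list(values_by_group.values())
    return max(values) - min(values)


# Hypothetical labeled predictions for two groups
records = [
    {"group": "a", "label": 1, "prediction": 1},
    {"group": "a", "label": 1, "prediction": 1},
    {"group": "a", "label": 1, "prediction": 0},
    {"group": "a", "label": 1, "prediction": 1},
    {"group": "b", "label": 1, "prediction": 1},
    {"group": "b", "label": 1, "prediction": 0},
    {"group": "b", "label": 0, "prediction": 1},  # negative label: ignored for TPR
]

tprs = tpr_by_group(records, "group")
gap = max_gap(tprs)
print(tprs, gap)
```

Here group "a" has a TPR of 0.75 and group "b" a TPR of 0.5, so a test that allows a gap of at most 0.10 would fail on this data.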
import tensorflow_model_analysis as tfma
import fairness_indicators

fairness_eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key='label')],
    slicing_specs=[
        tfma.SlicingSpec(),
        tfma.SlicingSpec(feature_keys=['race']),
        tfma.SlicingSpec(feature_keys=['gender']),
    ],
    metrics_specs=[
        tfma.MetricsSpec(metrics=[
            tfma.MetricConfig(class_name='FairnessIndicators',
                              config='{"thresholds": [0.25, 0.5, 0.75]}'),
            # Recall equals TPR for the positive class; the test below reads it
            tfma.MetricConfig(class_name='Recall')
        ])
    ]
)

def test_equal_opportunity_across_race_slices():
    """
    Equal opportunity: True Positive Rate (TPR) must not differ by more than
    0.10 (10 percentage points) across racial subgroups.
    """
    result = load_eval_results('./fairness_eval_output')
    tpr_by_race = {}
    for slice_key, metrics in result.get_metrics_for_all_slices().items():
        # Keep only single-column race slices, e.g. (('race', 'some_value'),)
        if len(slice_key) == 1 and slice_key[0][0] == 'race':
            race_value = slice_key[0][1]
            # TPR = TP / (TP + FN) = Recall for positive class
            tpr = metrics.get('recall', {}).get('doubleValue', None)
            if tpr is not None:
                tpr_by_race[race_value] = tpr
    if len(tpr_by_race) >= 2:
        tprs = list(tpr_by_race.values())
        tpr_gap = max(tprs) - min(tprs)
        assert tpr_gap <= 0.10, (
            f"TPR gap across race slices is {tpr_gap:.3f}, exceeds 0.10. "
            f"Values: {tpr_by_race}"
        )
Model Blessing for Promotion
TFMA's blessing mechanism compares a candidate model against a baseline model. The same comparison can be written as an explicit test:
def test_candidate_model_is_better_than_baseline():
    """Candidate model must perform at least as well as the current production model."""
    candidate_result = load_eval_results('./candidate_eval')
    baseline_result = load_eval_results('./baseline_eval')

    def get_auc(result):
        # The overall (unsliced) metrics are stored under the empty slice key
        metrics = result.get_metrics_for_all_slices()
        overall = next(m for s, m in metrics.items() if not s)
        return overall['auc']['doubleValue']

    candidate_auc = get_auc(candidate_result)
    baseline_auc = get_auc(baseline_result)
    improvement = candidate_auc - baseline_auc
    assert improvement >= 0.0, \
        f"Candidate AUC ({candidate_auc:.4f}) is worse than baseline ({baseline_auc:.4f})"
    print(f"AUC improvement: +{improvement:.4f} ({baseline_auc:.4f} → {candidate_auc:.4f})")
In TFX pipelines, the Evaluator component handles blessing automatically:
from tfx.components import Evaluator, Pusher
from tfx.proto import pusher_pb2
# example_gen, trainer, model_resolver, and serving_model_dir are defined
# earlier in the pipeline
evaluator = Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    baseline_model=model_resolver.outputs['model'],
    eval_config=eval_config
)
# Pusher only runs if model is blessed
pusher = Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=serving_model_dir
        )
    )
)
Testing TFX Pipeline Components
Test individual TFX components in isolation:
import tensorflow_data_validation as tfdv
from tfx.components import CsvExampleGen, StatisticsGen
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

def test_statistics_generation():
    """StatisticsGen must run without errors on test data."""
    context = InteractiveContext()
    # Assumption: the test data is CSV; use the ExampleGen variant that matches your data
    example_gen = CsvExampleGen(input_base='./test_data')
    context.run(example_gen)
    stats_gen = StatisticsGen(examples=example_gen.outputs['examples'])
    context.run(stats_gen)
    # Verify statistics file was generated
    stats = tfdv.load_stats_binary(
        stats_gen.outputs['statistics'].get()[0].uri + '/Split-train/FeatureStats.pb'
    )
    assert stats is not None
    assert len(stats.datasets) > 0

def test_schema_validation_detects_anomalies():
    """Schema validation must flag data that violates the schema."""
    # Load a schema with known constraints
    schema = tfdv.load_schema_text('./schemas/feature_schema.pbtxt')
    # Create data that violates schema (e.g., missing required feature)
    anomalous_df = ...  # pandas DataFrame with a required feature missing
    # Run validation
    anomalies = tfdv.validate_statistics(
        statistics=tfdv.generate_statistics_from_dataframe(anomalous_df),
        schema=schema
    )
    assert len(anomalies.anomaly_info) > 0, "Validator should detect anomalies"
CI Integration with TFX
name: TFX Model Validation
on:
  schedule:
    - cron: '0 2 * * *'  # Nightly evaluation
jobs:
  evaluate:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Install TFX, TFMA, and pytest
        run: pip install tensorflow-model-analysis tfx apache-beam pytest
      - name: Run TFMA evaluation
        env:
          GCS_BUCKET: ${{ secrets.GCS_BUCKET }}
        run: python scripts/run_evaluation.py
      - name: Test slice metrics
        run: pytest tests/model_validation/ -v
      - name: Bless and push if passing
        if: success() && github.ref == 'refs/heads/main'
        run: python scripts/bless_and_push.py
Monitoring Model Performance
TFMA runs at evaluation time. For continuous production monitoring:
# Monitor your TF Serving endpoint
helpmetest health tf-serving-api 5m
# Alert if prediction volume drops (model stopped being called)
helpmetest health ml-prediction-volume 5m
Summary
TFMA brings rigor to ML model validation through:
- Slice-based evaluation — overall accuracy hides subgroup performance gaps
- Fairness metrics — equality of opportunity requires explicit measurement
- Model blessing — compare candidate vs. baseline before promotion
- TFX integration — evaluation gates deployment in production pipelines
The principle is the same as software testing: don't ship without running your checks. For ML, those checks include not just "does it work" but "does it work fairly and consistently across all user segments."