TensorFlow Model Validation with TFMA: Testing ML Models in Production Pipelines

TensorFlow Model Analysis (TFMA) provides slice-based model evaluation — testing not just overall accuracy but performance across demographic and feature subgroups. This guide covers TFMA setup, fairness testing, model blessing, and integration into TFX pipelines.


What Is TensorFlow Model Analysis?

TFMA is part of the TensorFlow Extended (TFX) ecosystem, designed to answer questions that overall accuracy metrics hide:

  • Does the model perform equally well across user age groups?
  • Does accuracy drop for a specific geographic region?
  • How does model quality compare to the currently-deployed version?

TFMA uses Apache Beam to evaluate models over large datasets in a distributed fashion, computing metrics on arbitrary "slices": subsets of the data defined by feature values.
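
To make slicing concrete, here is a toy, pure-Python illustration of how an aggregate metric can hide a weak subgroup. It is not TFMA code, and the records and the `region` feature are made up for the example:

```python
from collections import defaultdict

# Hypothetical evaluation records: one feature, a label, and a prediction
records = [
    {"region": "north", "label": 1, "pred": 1},
    {"region": "north", "label": 0, "pred": 0},
    {"region": "north", "label": 1, "pred": 1},
    {"region": "north", "label": 0, "pred": 0},
    {"region": "north", "label": 1, "pred": 1},
    {"region": "north", "label": 0, "pred": 0},
    {"region": "south", "label": 1, "pred": 0},
    {"region": "south", "label": 0, "pred": 0},
]


def accuracy(rows):
    """Fraction of rows whose prediction matches the label."""
    return sum(r["label"] == r["pred"] for r in rows) / len(rows)


def accuracy_by_slice(rows, feature):
    """Group rows by the value of `feature` and compute accuracy per group."""
    groups = defaultdict(list)
    for r in rows:
        groups[r[feature]].append(r)
    return {value: accuracy(group) for value, group in groups.items()}


overall = accuracy(records)
per_region = accuracy_by_slice(records, "region")
print(f"overall: {overall:.3f}")
for region, acc in sorted(per_region.items()):
    print(f"{region}: {acc:.3f}")
```

Overall accuracy here is 0.875, which looks healthy, while the `south` slice sits at 0.5. TFMA computes the same kind of per-slice breakdown, but at scale and for many metrics at once.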


Installation

pip install tensorflow-model-analysis
pip install 'apache-beam[gcp]'  # GCS I/O and Dataflow runner, for distributed evaluation (quotes keep zsh from expanding the brackets)
pip install tensorflow tfx

Basic TFMA Evaluation

Model Setup

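TFMA evaluates models exported in the SavedModel format:

import tensorflow as tf
import tensorflow_model_analysis as tfma

# Optional sanity check: make sure the SavedModel loads before evaluating it
model = tf.keras.models.load_model('saved_model/my_model')

# Wrap the SavedModel for TFMA. For a standard serving export (not a legacy
# EvalSavedModel), also pass eval_config=eval_config (see the next section)
# so TFMA can infer the model type.
eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path='path/to/saved_model',
    tags=[tf.saved_model.SERVING]
)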

Defining Evaluation Config

eval_config = tfma.EvalConfig(
    model_specs=[
        tfma.ModelSpec(
            label_key='income_bracket',  # Your target column
            prediction_key='output_0'    # Model output tensor name
        )
    ],
    slicing_specs=[
        # Overall metrics (no slicing)
        tfma.SlicingSpec(),
        
        # Slice by age group
        tfma.SlicingSpec(feature_keys=['age_group']),
        
        # Slice by gender
        tfma.SlicingSpec(feature_keys=['gender']),
        
        # Cross-slice: age × gender
        tfma.SlicingSpec(feature_keys=['age_group', 'gender']),
        
        # Slice by specific feature value
        tfma.SlicingSpec(
            feature_values={'country': 'US'}
        )
    ],
    metrics_specs=[
        tfma.MetricsSpec(metrics=[
            tfma.MetricConfig(class_name='BinaryAccuracy'),
            tfma.MetricConfig(class_name='AUC'),
            tfma.MetricConfig(class_name='Precision'),
            tfma.MetricConfig(class_name='Recall'),
            tfma.MetricConfig(class_name='ExampleCount')
        ])
    ]
)
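
Each SlicingSpec expands into one slice per distinct combination of the listed feature values that actually occurs in the data. The sketch below is a plain-Python model of that expansion, not TFMA's implementation. It represents each slice key as a tuple of `(feature, value)` pairs with `()` for the overall slice, which is, to our understanding, the shape TFMA uses for slice keys in loaded results; the function name and the sample examples are hypothetical:

```python
def slice_keys_for_spec(examples, feature_keys=(), feature_values=None):
    """Toy model of the slices one SlicingSpec-like spec produces (not TFMA code).

    feature_keys: one slice per distinct value combination of these features.
    feature_values: keep only examples whose features equal these exact values.
    An empty spec yields only the overall slice, represented as ().
    """
    feature_values = feature_values or {}
    keys = set()
    for ex in examples:
        # Drop examples that fail the fixed feature_values filter
        if any(ex.get(k) != v for k, v in feature_values.items()):
            continue
        # Drop examples that lack any of the slicing features
        if any(k not in ex for k in feature_keys):
            continue
        pairs = list(feature_values.items()) + [(k, ex[k]) for k in feature_keys]
        # Each slice key is a tuple of (feature, value) pairs, sorted for determinism
        keys.add(tuple(sorted(pairs)))
    return keys


# Hypothetical examples using the feature names from the config above
examples = [
    {"age_group": "18-24", "gender": "F", "country": "US"},
    {"age_group": "18-24", "gender": "M", "country": "CA"},
    {"age_group": "25-34", "gender": "F", "country": "US"},
]

for key in sorted(slice_keys_for_spec(examples, ["age_group", "gender"])):
    print(key)
```

The `age_group` x `gender` spec yields three slices here, not four: a combination with no examples simply does not appear, which is one reason cross-slices can end up with very few examples.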

Running Evaluation

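import apache_beam as beam

# Simplest option: run_model_analysis builds and runs the Beam pipeline
# itself and returns an EvalResult.
eval_result = tfma.run_model_analysis(
    eval_shared_model=eval_shared_model,
    eval_config=eval_config,
    data_location='gs://your-bucket/test-data/test-*.tfrecords',
    output_path='./eval_output'
)

# To embed evaluation in your own Beam pipeline, use the
# ExtractEvaluateAndWriteResults PTransform instead.
with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | 'ReadData' >> beam.io.ReadFromTFRecord(
            'gs://your-bucket/test-data/test-*.tfrecords'
        )
        | 'EvaluateModel' >> tfma.ExtractEvaluateAndWriteResults(
            eval_shared_model=eval_shared_model,
            eval_config=eval_config,
            output_path='./eval_output'
        )
    )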

Testing Slice-Based Performance

After evaluation, write assertions against the slice metrics:

import tensorflow_model_analysis as tfma

def load_eval_results(eval_path: str) -> tfma.EvalResult:
    return tfma.load_eval_result(eval_path)


def test_overall_accuracy_meets_threshold():
    """Overall model accuracy must be >= 0.82."""
    result = load_eval_results('./eval_output')

    # get_metrics_for_all_slices() returns a dict keyed by slice key;
    # the overall slice (from the empty SlicingSpec) has an empty key
    all_metrics = result.get_metrics_for_all_slices()
    overall = next(
        m for s, m in all_metrics.items()
        if not s  # Empty slice key = overall
    )

    accuracy = overall['binary_accuracy']['doubleValue']
    assert accuracy >= 0.82, f"Overall accuracy {accuracy:.3f} below threshold 0.82"


def test_accuracy_consistent_across_gender_slices():
    """Accuracy gap between gender groups must be < 5 percentage points."""
    result = load_eval_results('./eval_output')

    gender_metrics = {}
    for slice_key, metrics in result.get_metrics_for_all_slices().items():
        # A slice key is a tuple of (feature, value) pairs; keep only the
        # single-feature gender slices, not the age_group x gender cross-slices
        if len(slice_key) == 1 and slice_key[0][0] == 'gender':
            value = slice_key[0][1]
            if isinstance(value, bytes):
                value = value.decode()
            gender_metrics[value] = metrics['binary_accuracy']['doubleValue']

    # Fail loudly instead of passing vacuously when the slices are missing
    assert len(gender_metrics) >= 2, f"Expected at least two gender slices, got {gender_metrics}"
    accuracies = list(gender_metrics.values())
    gap = max(accuracies) - min(accuracies)
    assert gap < 0.05, \
        f"Accuracy gap across gender slices is {gap:.3f}, exceeding the 0.05 fairness threshold. Slices: {gender_metrics}"


def test_minimum_examples_per_slice():
    """Every evaluated slice must have at least 100 examples for reliable metrics."""
    result = load_eval_results('./eval_output')

    for slice_key, metrics in result.get_metrics_for_all_slices().items():
        if slice_key:  # Skip the overall slice (empty key)
            example_count = metrics.get('example_count', {}).get('doubleValue', 0)
            assert example_count >= 100, \
                f"Slice {slice_key} has only {example_count} examples; metrics unreliable"
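
The gender test above and the race test later in this guide both parse slice keys the same way, so a small helper can factor that logic out. This sketch assumes `get_metrics_for_all_slices()` returns a dict mapping slice keys (tuples of `(feature, value)` pairs, `()` for overall) to metric dicts shaped like `{'binary_accuracy': {'doubleValue': ...}}`, as the tests here use. The helper names are hypothetical, and it is exercised with a hand-built dict rather than a real EvalResult:

```python
def metric_by_slice_value(all_metrics, feature, metric_name):
    """Map each value of a single-feature slice to one metric's value.

    all_metrics: dict of slice key -> {metric_name: {'doubleValue': float}}.
    Only slices keyed on exactly `feature` are kept; cross-slices are skipped.
    """
    result = {}
    for slice_key, metrics in all_metrics.items():
        if len(slice_key) != 1 or slice_key[0][0] != feature:
            continue
        value = slice_key[0][1]
        # Assumption: values may arrive as bytes, so decode defensively
        if isinstance(value, bytes):
            value = value.decode()
        if metric_name in metrics:
            result[value] = metrics[metric_name]["doubleValue"]
    return result


def max_gap(values_by_group):
    """Largest difference between any two groups (needs at least one group)."""
    values = list(values_by_group.values())
    return max(values) - min(values)


# Hand-built stand-in for result.get_metrics_for_all_slices()
fake_metrics = {
    (): {"binary_accuracy": {"doubleValue": 0.85}},
    (("gender", "female"),): {"binary_accuracy": {"doubleValue": 0.86}},
    (("gender", "male"),): {"binary_accuracy": {"doubleValue": 0.83}},
    (("gender", b"nonbinary"),): {"binary_accuracy": {"doubleValue": 0.84}},
    (("age_group", "18-24"), ("gender", "male")): {"binary_accuracy": {"doubleValue": 0.70}},
}

by_gender = metric_by_slice_value(fake_metrics, "gender", "binary_accuracy")
gap = max_gap(by_gender)
print(by_gender, f"gap={gap:.3f}")
```

Skipping cross-slices matters: the `age_group` x `gender` slice at 0.70 would otherwise inflate the gender gap.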

Fairness Indicators

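Fairness Indicators builds on TFMA: its FairnessIndicators metric computes rates such as false positive and false negative rate for every slice at each configured decision threshold, which is how group-fairness criteria like equality of opportunity are measured:

pip install fairness-indicators

import tensorflow_model_analysis as tfma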

fairness_eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key='label')],
    slicing_specs=[
        tfma.SlicingSpec(),
        tfma.SlicingSpec(feature_keys=['race']),
        tfma.SlicingSpec(feature_keys=['gender']),
    ],
    metrics_specs=[
        tfma.MetricsSpec(metrics=[
            # Per-slice rates (FPR, FNR, ...) at each of these decision thresholds
            tfma.MetricConfig(class_name='FairnessIndicators',
                config='{"thresholds": [0.25, 0.5, 0.75]}'),
            # Recall = TPR at the default 0.5 threshold; read by the equal-opportunity test
            tfma.MetricConfig(class_name='Recall')
        ])
    ]
)


def test_equal_opportunity_across_race_slices():
    """
    Equal opportunity: True Positive Rate (TPR) must not differ by more than
    10 percentage points across racial subgroups.
    """
    result = load_eval_results('./fairness_eval_output')

    tpr_by_race = {}
    for slice_key, metrics in result.get_metrics_for_all_slices().items():
        # Keep only single-feature race slices, e.g. (('race', value),)
        if len(slice_key) == 1 and slice_key[0][0] == 'race':
            race_value = slice_key[0][1]
            # TPR = TP / (TP + FN) = recall for the positive class
            tpr = metrics.get('recall', {}).get('doubleValue')
            if tpr is not None:
                tpr_by_race[race_value] = tpr

    assert len(tpr_by_race) >= 2, f"Expected at least two race slices, got {tpr_by_race}"
    tprs = list(tpr_by_race.values())
    tpr_gap = max(tprs) - min(tprs)
    assert tpr_gap <= 0.10, \
        f"TPR gap across race slices is {tpr_gap:.3f}, exceeds 0.10. Values: {tpr_by_race}"
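
Equal opportunity compares true positive rates, TPR = TP / (TP + FN), across groups, and FairnessIndicators reports rates like this at every configured threshold. As a hand-checkable illustration (pure Python, made-up groups and scores, not TFMA output), the sketch computes per-group TPR at the same three thresholds; here a score strictly above the threshold counts as a positive prediction:

```python
from collections import defaultdict

# Hypothetical scored examples: (group, true label, model score)
examples = [
    ("a", 1, 0.9), ("a", 1, 0.6), ("a", 1, 0.3), ("a", 0, 0.7),
    ("b", 1, 0.8), ("b", 1, 0.4), ("b", 1, 0.2), ("b", 0, 0.1),
]


def tpr_by_group(rows, threshold):
    """TPR per group: positives scored above `threshold` / all positives in the group."""
    true_pos = defaultdict(int)
    positives = defaultdict(int)
    for group, label, score in rows:
        if label == 1:
            positives[group] += 1
            if score > threshold:
                true_pos[group] += 1
    return {g: true_pos[g] / positives[g] for g in positives}


gaps = {}
for t in [0.25, 0.5, 0.75]:
    tprs = tpr_by_group(examples, t)
    gaps[t] = max(tprs.values()) - min(tprs.values())
    print(f"threshold {t}: TPR {tprs}, gap {gaps[t]:.3f}")
```

In this toy data the TPR gap is about 0.333 at thresholds 0.25 and 0.5 but 0.0 at 0.75, so a fairness verdict can depend on the operating threshold. That is the reason to evaluate several thresholds rather than one.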

Model Blessing for Promotion

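Before promotion, the candidate model should be compared against the baseline (the model currently in production). A hand-written version of that check, using two saved evaluation results: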

def test_candidate_model_is_better_than_baseline():
    """Candidate model must outperform the current production model."""
    
    candidate_result = load_eval_results('./candidate_eval')
    baseline_result = load_eval_results('./baseline_eval')
    
    def get_auc(result):
        metrics = result.get_metrics_for_all_slices()
        overall = next(m for s, m in metrics.items() if not s)  # Empty key = overall
        return overall['auc']['doubleValue']
    
    candidate_auc = get_auc(candidate_result)
    baseline_auc = get_auc(baseline_result)
    
    improvement = candidate_auc - baseline_auc
    
    assert improvement >= 0.0, \
        f"Candidate AUC ({candidate_auc:.4f}) is worse than baseline ({baseline_auc:.4f})"
    
    print(f"AUC improvement: +{improvement:.4f} ({baseline_auc:.4f} -> {candidate_auc:.4f})")
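
In TFMA itself, this kind of gate is usually declared rather than hand-coded: a `tfma.MetricThreshold` on a metric can combine a `value_threshold` (an absolute floor) with a `change_threshold` (a limit on regression relative to the baseline model). The pure-Python function below only mimics that idea so the decision logic is easy to unit-test; it is not TFMA's implementation, and `should_bless` and its parameters are hypothetical names:

```python
def should_bless(candidate, baseline, lower_bound, max_regression=0.0):
    """Decide whether a higher-is-better metric value passes both gates.

    lower_bound: absolute floor the candidate must reach (value-threshold analogue).
    max_regression: how far below the baseline the candidate may fall
        (change-threshold analogue); 0.0 means "no worse than baseline".
    Returns (blessed, reasons), where reasons lists every failed check.
    """
    reasons = []
    if candidate < lower_bound:
        reasons.append(f"value {candidate:.4f} below floor {lower_bound:.4f}")
    change = candidate - baseline
    if change < -max_regression:
        reasons.append(f"change {change:+.4f} exceeds allowed regression {max_regression:.4f}")
    return (not reasons, reasons)


print(should_bless(0.91, 0.90, lower_bound=0.85))
print(should_bless(0.89, 0.90, lower_bound=0.85))
print(should_bless(0.89, 0.90, lower_bound=0.85, max_regression=0.02))
```

Allowing a small `max_regression` is a design choice: evaluation noise can make an equally good candidate score slightly below the baseline, and a strict "must not be worse" rule would then block it.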

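In TFX pipelines, the Evaluator component performs this comparison itself: given a baseline model, it checks the metric thresholds declared in eval_config (for example, a tfma.MetricThreshold on AUC) and marks the candidate as blessed only if they pass. The Pusher then deploys only blessed models: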

from tfx import v1 as tfx

# example_gen, trainer, model_resolver and serving_model_dir are assumed to be
# defined earlier in the pipeline (model_resolver resolves the latest blessed model)
evaluator = tfx.components.Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    baseline_model=model_resolver.outputs['model'],
    eval_config=eval_config
)

# Pusher only pushes the model if the Evaluator blessed it
pusher = tfx.components.Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=tfx.proto.PushDestination(
        filesystem=tfx.proto.PushDestination.Filesystem(
            base_directory=serving_model_dir
        )
    )
)

Testing TFX Pipeline Components

Test individual TFX components in isolation:

import tensorflow_data_validation as tfdv
from tfx import v1 as tfx
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

def test_statistics_generation():
    """StatisticsGen must run without errors on test data."""
    context = InteractiveContext()

    # Assumes ./test_data holds CSV files; CsvExampleGen splits them into train/eval
    example_gen = tfx.components.CsvExampleGen(input_base='./test_data')
    context.run(example_gen)

    stats_gen = tfx.components.StatisticsGen(examples=example_gen.outputs['examples'])
    context.run(stats_gen)

    # Verify the statistics file for the train split was generated
    stats_uri = stats_gen.outputs['statistics'].get()[0].uri
    stats = tfdv.load_stats_binary(stats_uri + '/Split-train/FeatureStats.pb')

    assert stats is not None
    assert len(stats.datasets) > 0


import pandas as pd
import tensorflow_data_validation as tfdv

def test_schema_validation_detects_anomalies():
    """Data that violates the schema must be flagged by TFDV validation,
    the same check the ExampleValidator component runs."""
    # Load a schema with known constraints
    schema = tfdv.load_schema_text('./schemas/feature_schema.pbtxt')

    # Hypothetical data that omits a feature the schema marks as required;
    # adjust the column names to match your own schema
    anomalous_df = pd.DataFrame({'age': [25, 40, 61]})

    # Run validation
    anomalies = tfdv.validate_statistics(
        statistics=tfdv.generate_statistics_from_dataframe(anomalous_df),
        schema=schema
    )

    assert len(anomalies.anomaly_info) > 0, "Validator should detect anomalies"
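
TFDV's anomaly detection covers far more cases, but the core idea of a schema check fits in a few lines. The toy validator below uses hypothetical names and is not the TFDV API; it flags records that miss a required feature or carry a value outside a declared domain:

```python
# Hypothetical, minimal "schema": required features and allowed string domains
SCHEMA = {
    "required": {"age", "gender"},
    "domains": {"gender": {"female", "male", "nonbinary"}},
}


def find_anomalies(records, schema):
    """Return a sorted list of human-readable anomaly descriptions."""
    anomalies = set()
    for i, rec in enumerate(records):
        # Required features absent from this record
        for feature in schema["required"] - set(rec):
            anomalies.add(f"record {i}: missing required feature '{feature}'")
        # Present features whose value falls outside the declared domain
        for feature, allowed in schema["domains"].items():
            if feature in rec and rec[feature] not in allowed:
                anomalies.add(f"record {i}: '{feature}' value {rec[feature]!r} not in domain")
    return sorted(anomalies)


good = [{"age": 30, "gender": "female"}]
bad = [{"age": 41}, {"age": 22, "gender": "unknown"}]
print(find_anomalies(bad, SCHEMA))
```

The test above relies on the first kind of anomaly (a missing required feature); TFDV reports it at the level of dataset statistics rather than per record.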

CI Integration with TFX

name: TFX Model Validation
on:
  schedule:
    - cron: '0 2 * * *'  # Nightly evaluation

jobs:
  evaluate:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      
      - name: Install TFX and TFMA
        run: pip install tensorflow-model-analysis tfx apache-beam
      
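      - name: Authenticate to Google Cloud
        uses: google-github-actions/auth@v2
        with:
          credentials_json: ${{ secrets.GCP_SA_KEY }}  # hypothetical secret name

      - name: Run TFMA evaluation
        env:
          GCS_BUCKET: ${{ secrets.GCS_BUCKET }}
        run: python scripts/run_evaluation.py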
      
      - name: Test slice metrics
        run: pytest tests/model_validation/ -v
      
      - name: Bless and push if passing
        if: success() && github.ref == 'refs/heads/main'
        run: python scripts/bless_and_push.py

Monitoring Model Performance

TFMA evaluates a model offline, against a fixed dataset; it does not watch the deployed model. Production needs separate, continuous checks:

# Monitor your TF Serving endpoint
helpmetest health tf-serving-api 5m

# Alert if prediction volume drops (model stopped being called)
helpmetest health ml-prediction-volume 5m

Summary

TFMA brings rigor to ML model validation through:

  • Slice-based evaluation — overall accuracy hides subgroup performance gaps
  • Fairness metrics — equality of opportunity requires explicit measurement
  • Model blessing — compare candidate vs. baseline before promotion
  • TFX integration — evaluation gates deployment in production pipelines

The principle is the same as software testing: don't ship without running your checks. For ML, those checks include not just "does it work" but "does it work fairly and consistently across all user segments."
