Testing Spark Structured Streaming: Unit Tests, Micro-batch Simulation, and CI

Spark Structured Streaming tests fall into three layers: transformation unit tests using static DataFrames, micro-batch simulation using MemoryStream for source-side logic, and full integration tests with Testcontainers-Kafka. Watermark and late-data behavior requires careful trigger and clock control that MemoryStream provides without real streaming infrastructure.

Key Takeaways

Test transformations with static DataFrames first. Every transformation in a streaming query is also a valid batch transformation. Write pytest tests against static DataFrames before wiring them to a stream. This catches most logic bugs (parsing, joins, aggregation arithmetic) with sub-second feedback.

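Use MemoryStream to simulate micro-batches. MemoryStream lets you add data programmatically and process it batch by batch. Combined with processAllAvailable(), this gives deterministic control over micro-batch boundaries without real Kafka topics. MemoryStream is a Scala test source with no public PySpark API, so Python suites either reach it through the JVM bridge or use a directory of files as a replayable stand-in.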

Validate watermarks by inspecting StreamingQuery progress. The lastProgress dict on a StreamingQuery reports the current watermark under eventTime.watermark (present only when the query defines a watermark). Assert on it after advancing your input timestamps to verify late-data exclusion logic is correct.
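A minimal helper for that assertion might look like the sketch below. It assumes the dict shape PySpark 3.x returns from StreamingQuery.lastProgress and a watermark string formatted like 2026-05-17T10:00:30.000Z; the function name is hypothetical.

```python
from datetime import datetime, timezone


def watermark_from_progress(progress):
    """Return the watermark from a lastProgress-style dict as an aware datetime.

    Assumes eventTime.watermark is a UTC string such as "2026-05-17T10:00:30.000Z".
    Returns None when there is no progress yet or the query has no watermark.
    """
    if not progress:
        return None
    raw = (progress.get("eventTime") or {}).get("watermark")
    if raw is None:
        return None
    # Trailing "Z" means UTC; attach the tzinfo explicitly after parsing.
    return datetime.strptime(raw, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc)
```

In a test, call it on query.lastProgress after processAllAvailable() and compare against the watermark you expect from your input timestamps.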

The Testing Challenge in Structured Streaming

Spark Structured Streaming adds three dimensions that batch testing ignores:

  1. Unbounded input — data arrives continuously; tests must bound it artificially.
  2. Watermarks — event-time processing discards late data. Testing requires injecting late records and verifying they're excluded.
  3. Stateful operations: flatMapGroupsWithState (applyInPandasWithState in PySpark), streaming joins, and deduplication maintain state across micro-batches. Tests must drive multiple batches and inspect state between them.

The good news: Structured Streaming treats a stream as an infinite DataFrame. Any transformation you write is also a valid batch transformation. This means your test strategy starts with static DataFrames and only introduces streaming complexity when the feature requires it.

Setting Up pytest-spark

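pytest-spark supplies session-scoped spark_context and spark_session fixtures, configured from pytest.ini:

pip install pytest-spark pyspark pytest
# pytest.ini
[pytest]
# spark_home is only needed when Spark is not pip-installed, e.g. spark_home = /opt/spark
# pytest-spark expects one "key: value" pair per line in spark_options
spark_options =
    spark.master: local[2]
    spark.driver.memory: 2g
    spark.sql.shuffle.partitions: 2
    spark.sql.streaming.checkpointLocation: /tmp/spark-test-checkpoints

The tests in this article request a fixture named spark instead. To build that session yourself, define it in conftest.py (pytest-spark's spark_options do not apply to it):

# conftest.py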
import pytest
from pyspark.sql import SparkSession

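@pytest.fixture(scope="session")
def spark(tmp_path_factory):
    session = (
        SparkSession.builder
        .master("local[2]")
        .appName("streaming-tests")
        .config("spark.sql.shuffle.partitions", "2")
        # Fresh default checkpoint root per test session, so reruns never resume stale state
        .config("spark.sql.streaming.checkpointLocation",
                str(tmp_path_factory.mktemp("checkpoints")))
        # Pin the session time zone so timestamp literals mean the same thing on every machine
        .config("spark.sql.session.timeZone", "UTC")
        .config("spark.driver.memory", "2g")
        .getOrCreate()
    )
    session.sparkContext.setLogLevel("WARN")
    yield session
    session.stop()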

Layer 1: Unit Testing Transformations with Static DataFrames

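Every transformation function in a streaming job should accept a DataFrame and return a DataFrame, leaving readStream, writeStream, and checkpoint configuration to the code that wires up the query. Because the same function then runs unchanged on static input, it is trivially testable: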

# pipelines/transformations.py
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, LongType, DoubleType

EVENT_SCHEMA = StructType([
    StructField("event_id", StringType(), False),
    StructField("user_id", StringType(), False),
    StructField("event_type", StringType(), False),
    StructField("amount", DoubleType(), True),
    StructField("event_time", LongType(), False),
])

def parse_events(raw_df: DataFrame) -> DataFrame:
    """Parse JSON events from Kafka value column."""
    return raw_df.select(
        F.from_json(F.col("value").cast("string"), EVENT_SCHEMA).alias("data")
    ).select("data.*")


def enrich_events(events_df: DataFrame, users_df: DataFrame) -> DataFrame:
    """Join events with user dimension table."""
    return events_df.join(
        users_df.select("user_id", "country", "tier"),
        on="user_id",
        how="left"
    )


def aggregate_by_type(events_df: DataFrame) -> DataFrame:
    """Aggregate event counts and total amount by type."""
    return events_df.groupBy("event_type").agg(
        F.count("*").alias("event_count"),
        F.sum("amount").alias("total_amount"),
        F.approx_count_distinct("user_id").alias("unique_users")
    )
# tests/test_transformations.py
import pytest
import json
from pyspark.sql import Row
from pyspark.sql.types import *
from pipelines.transformations import parse_events, enrich_events, aggregate_by_type


def test_parse_events(spark):
    raw_data = [
        Row(value=json.dumps({
            "event_id": "e001",
            "user_id": "u001",
            "event_type": "purchase",
            "amount": 99.99,
            "event_time": 1716192000
        })),
        Row(value=json.dumps({
            "event_id": "e002",
            "user_id": "u002",
            "event_type": "view",
            "amount": None,
            "event_time": 1716192060
        })),
    ]
    raw_df = spark.createDataFrame(raw_data)
    result = parse_events(raw_df)

    assert result.count() == 2
    rows = {r.event_id: r for r in result.collect()}
    assert rows["e001"].event_type == "purchase"
    assert rows["e001"].amount == pytest.approx(99.99)
    assert rows["e002"].amount is None


def test_parse_events_malformed_json(spark):
    raw_data = [Row(value="{not valid json")]
    raw_df = spark.createDataFrame(raw_data)
    result = parse_events(raw_df)
    # from_json (PERMISSIVE mode, the default) yields nulls for unparseable input,
    # so the row survives with every parsed field null
    assert result.count() == 1
    assert result.collect()[0].event_id is None


def test_enrich_events(spark):
    events = spark.createDataFrame([
        Row(event_id="e001", user_id="u001", event_type="purchase",
            amount=50.0, event_time=1716192000),
        Row(event_id="e002", user_id="u999", event_type="view",
            amount=None, event_time=1716192001),  # unknown user
    ])
    users = spark.createDataFrame([
        Row(user_id="u001", country="US", tier="premium"),
    ])
    result = enrich_events(events, users)

    rows = {r.user_id: r for r in result.collect()}
    assert rows["u001"].country == "US"
    assert rows["u001"].tier == "premium"
    assert rows["u999"].country is None  # left join fills null


def test_aggregate_by_type(spark):
    events = spark.createDataFrame([
        Row(event_id="e1", user_id="u1", event_type="purchase",
            amount=100.0, event_time=1716192000),
        Row(event_id="e2", user_id="u2", event_type="purchase",
            amount=200.0, event_time=1716192001),
        Row(event_id="e3", user_id="u1", event_type="view",
            amount=None, event_time=1716192002),
    ])
    result = aggregate_by_type(events)

    purchases = result.filter("event_type = 'purchase'").collect()[0]
    assert purchases.event_count == 2
    assert purchases.total_amount == pytest.approx(300.0)
    assert purchases.unique_users == 2

    views = result.filter("event_type = 'view'").collect()[0]
    assert views.event_count == 1
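The expected values above are written by hand. For larger fixtures, a small plain-Python oracle can compute them instead. The sketch below mirrors aggregate_by_type under the assumption that fixtures stay small enough for approx_count_distinct to return exact counts (it is approximate in general); the function name is hypothetical.

```python
from collections import defaultdict


def expected_aggregates(events):
    """Plain-Python oracle mirroring aggregate_by_type for small fixtures.

    events: iterable of dicts with event_type, user_id and amount (amount may be None).
    Returns {event_type: (event_count, total_amount, unique_users)}.
    Like Spark's sum(), total_amount skips None and stays None when every
    amount in the group is None.
    """
    counts = defaultdict(int)
    totals = {}
    users = defaultdict(set)
    for e in events:
        t = e["event_type"]
        counts[t] += 1
        users[t].add(e["user_id"])
        totals.setdefault(t, None)
        if e["amount"] is not None:
            totals[t] = (totals[t] or 0.0) + e["amount"]
    return {t: (counts[t], totals[t], len(users[t])) for t in counts}
```

Compare its output against the collected Spark rows keyed by event_type, using pytest.approx for the float totals.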

Layer 2: MemoryStream for Micro-batch Simulation

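MemoryStream (org.apache.spark.sql.execution.streaming.MemoryStream) is the in-memory source Spark's own Scala test suites use. It lets you add batches of data programmatically and control when each micro-batch is processed. PySpark does not expose it as a public API: from Python you can reach it only through the internal py4j bridge, or you can feed micro-batches through a replayable source such as a directory of JSON files. Either way, writeStream works only on a streaming DataFrame (one created by spark.readStream); calling it on a static DataFrame raises an AnalysisException.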

# tests/test_streaming_logic.py
import pytest
from pyspark.sql import Row
from pyspark.sql import functions as F
from pyspark.sql.types import *
import time


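def create_memory_stream(spark, schema):
    """Create a JVM MemoryStream[Row] through the py4j bridge.

    MemoryStream is an internal Scala test class, not a public PySpark API.
    Its constructor is (id, sqlContext, numPartitions) plus a trailing
    implicit Encoder. RowEncoder.apply(schema) is assumed to build that
    encoder, which matches Spark 3.x internals but may change between versions.
    Adding rows still requires converting Python values to JVM Rows.
    """
    jvm = spark._jvm
    scala_schema = spark._jsparkSession.parseDataType(schema.json())
    encoder = jvm.org.apache.spark.sql.catalyst.encoders.RowEncoder.apply(scala_schema)
    return jvm.org.apache.spark.sql.execution.streaming.MemoryStream(
        1,                                   # source id
        spark._jsparkSession.sqlContext(),
        jvm.scala.Option.apply(None),        # numPartitions: use the default
        encoder,                             # implicit Encoder[Row]
    )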


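def test_running_count_per_user(spark, tmp_path):
    """
    Test a running count of events per user using streaming aggregation.
    writeStream needs a streaming DataFrame, so two micro-batches are fed
    through a file source: each JSON-lines file moved into input_dir is read
    by a later micro-batch, and processAllAvailable() waits for it.
    """
    import json

    schema = StructType([
        StructField("user_id", StringType()),
        StructField("event_type", StringType()),
    ])
    input_dir = tmp_path / "input"
    input_dir.mkdir()

    def add_batch(name, rows):
        # Stage outside input_dir, then rename, so Spark never reads a partial file
        staging = tmp_path / f"{name}.tmp"
        staging.write_text("\n".join(json.dumps(r) for r in rows))
        staging.rename(input_dir / f"{name}.json")

    def current_counts():
        # Row is a tuple, so r.count would be tuple.count; index the column by name
        return {r.user_id: r["count"] for r in spark.sql("select * from user_counts").collect()}

    # Write to memory sink (complete mode for aggregations)
    query = (
        spark.readStream
        .schema(schema)
        .json(str(input_dir))
        .groupBy("user_id")
        .count()
        .writeStream
        .format("memory")
        .queryName("user_counts")
        .outputMode("complete")
        .option("checkpointLocation", str(tmp_path / "checkpoint"))
        .start()
    )
    try:
        add_batch("batch1", [
            {"user_id": "alice", "event_type": "click"},
            {"user_id": "bob", "event_type": "view"},
            {"user_id": "alice", "event_type": "purchase"},
        ])
        query.processAllAvailable()
        assert current_counts() == {"alice": 2, "bob": 1}

        add_batch("batch2", [{"user_id": "bob", "event_type": "click"}])
        query.processAllAvailable()
        assert current_counts() == {"alice": 2, "bob": 2}
    finally:
        query.stop()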

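Using a Python MemoryStream Stand-in

For cleaner Python tests, you can use a small pure-Python stand-in instead of the JVM MemoryStream. The class below does not create a stream: it records each batch you add and replays all of them as one static DataFrame, which suits Layer 1 style tests: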

# tests/helpers/memory_stream.py
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.types import StructType


class MemoryStream:
    """Test helper that buffers batches in Python (it does not use the JVM MemoryStream)."""

    def __init__(self, spark: SparkSession, schema: StructType):
        self.spark = spark
        self.schema = schema
        self._batches = []  # one list of rows per add_data() call

    def add_data(self, rows: list) -> None:
        """Record a batch of rows."""
        self._batches.append(list(rows))

    def to_df(self) -> DataFrame:
        """Return a static DataFrame with all added data (for unit tests)."""
        all_rows = [r for batch in self._batches for r in batch]
        return self.spark.createDataFrame(all_rows, self.schema)
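If the stand-in should drive real micro-batches, one option (an assumption, not part of the helper above) is to turn each add_data() call into a new JSON-lines file in a directory that spark.readStream.schema(...).json(...) watches. The hypothetical FileStreamFeeder below only does the file handling, so it runs without Spark.

```python
import json
import os
from pathlib import Path


class FileStreamFeeder:
    """Hypothetical helper: each add_data() call becomes one new input file.

    Point spark.readStream.schema(...).json(str(feeder.input_dir)) at input_dir.
    New files are picked up by later micro-batches (set the file source option
    maxFilesPerTrigger=1 for exactly one file per batch), and
    query.processAllAvailable() waits until they have been processed.
    """

    def __init__(self, root):
        self.root = Path(root)
        self.input_dir = self.root / "input"
        self.input_dir.mkdir(parents=True, exist_ok=True)
        self._count = 0

    def add_data(self, rows):
        """Write one batch of dict rows as JSON lines and return the final path."""
        name = f"batch_{self._count:05d}.json"
        staging = self.root / f"{name}.tmp"  # outside input_dir, invisible to Spark
        with open(staging, "w", encoding="utf-8") as f:
            for row in rows:
                f.write(json.dumps(row) + "\n")
        final = self.input_dir / name
        # os.replace is atomic on one filesystem, so Spark never sees a half-written file
        os.replace(staging, final)
        self._count += 1
        return final
```

Staging outside the watched directory matters because the file source may list a file while it is still being written.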

Testing Windowed Aggregations

def test_tumbling_window_aggregation(spark, tmp_path):
    """
    Test a 1-minute tumbling window sum of purchase amounts.
    Uses static DataFrame with event timestamps to simulate streaming.
    """
    from pyspark.sql.types import TimestampType

    schema = StructType([
        StructField("user_id", StringType()),
        StructField("amount", DoubleType()),
        StructField("event_time", TimestampType()),
    ])

    from datetime import datetime, timedelta, timezone

    def ts(minutes_offset):
        base = datetime(2026, 5, 17, 10, 0, 0, tzinfo=timezone.utc)
        return base + timedelta(minutes=minutes_offset)

    data = [
        Row(user_id="alice", amount=10.0, event_time=ts(0)),
        Row(user_id="alice", amount=20.0, event_time=ts(0.5)),  # same window
        Row(user_id="alice", amount=30.0, event_time=ts(1.5)),  # next window
        Row(user_id="bob", amount=15.0, event_time=ts(0)),
    ]

    df = spark.createDataFrame(data, schema)

    result = df.groupBy(
        "user_id",
        F.window("event_time", "1 minute")
    ).agg(F.sum("amount").alias("window_total"))

    # Window [10:00, 10:01). Comparing against a timezone-aware Python datetime
    # avoids depending on spark.sql.session.timeZone to parse a string literal.
    window1 = result.filter(
        (F.col("user_id") == "alice") &
        (F.col("window.start") == F.lit(ts(0)))
    ).collect()
    assert len(window1) == 1
    assert window1[0].window_total == pytest.approx(30.0)

    # Window [10:01, 10:02)
    window2 = result.filter(
        (F.col("user_id") == "alice") &
        (F.col("window.start") == F.lit(ts(1)))
    ).collect()
    assert len(window2) == 1
    assert window2[0].window_total == pytest.approx(30.0)
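The expected totals in this test follow from how tumbling windows are assigned. The sketch below reproduces that assignment in plain Python, assuming F.window's default alignment to the Unix epoch (no slide duration, no startTime offset); both function names are hypothetical.

```python
from collections import defaultdict
from datetime import datetime, timedelta, timezone

EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)


def tumbling_window_start(ts, size):
    """Start of the tumbling window containing ts, aligned to the Unix epoch."""
    whole_windows = (ts - EPOCH) // size  # timedelta // timedelta -> int
    return EPOCH + whole_windows * size


def expected_window_totals(rows, size):
    """Sum amounts per (user_id, window_start) for (user_id, amount, ts) tuples."""
    totals = defaultdict(float)
    for user_id, amount, ts in rows:
        totals[(user_id, tumbling_window_start(ts, size))] += amount
    return dict(totals)
```

With the test's four rows and size=timedelta(minutes=1), this gives 30.0 for both of alice's windows and 15.0 for bob's, matching the assertions.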

Layer 3: Testing Watermarks and Late Data

Watermarks control how long Spark waits for late-arriving events before closing a window. Testing watermark behavior requires driving the event-time clock forward.

def test_watermark_excludes_late_data(spark, tmp_path):
    """
    Watermark delay of 10 seconds on 30-second tumbling windows (append mode).
    A window is final, and emitted, once the watermark (max event time seen
    minus 10 s) reaches its end; rows older than that are treated as late.
    """
    import json
    from datetime import datetime, timezone

    base = int(datetime(2026, 5, 17, 10, 0, 0, tzinfo=timezone.utc).timestamp())
    input_dir = tmp_path / "input"
    input_dir.mkdir()

    # Batch 1: early events advance watermark to t=30 (max=40, watermark=40-10=30)
    events = [
        {"event_id": "e1", "event_time": base + 10, "value": 1.0},
        {"event_id": "e2", "event_time": base + 20, "value": 2.0},
        {"event_id": "e3", "event_time": base + 40, "value": 3.0},  # advances watermark
    ]
    staging = tmp_path / "batch1.tmp"
    staging.write_text("\n".join(json.dumps(e) for e in events))
    staging.rename(input_dir / "batch1.json")  # atomic, so Spark never sees a partial file

    schema = StructType([
        StructField("event_id", StringType()),
        StructField("event_time", LongType()),  # epoch seconds in the JSON
        StructField("value", DoubleType()),
    ])

    query = (
        spark.readStream.schema(schema).json(str(input_dir))
        .withColumn("event_time", F.timestamp_seconds("event_time"))
        .withWatermark("event_time", "10 seconds")
        .groupBy(F.window("event_time", "30 seconds"))
        .agg(F.sum("value").alias("total"))
        .writeStream
        .format("memory")
        .queryName("watermark_test")
        .outputMode("append")
        .option("checkpointLocation", str(tmp_path / "checkpoint"))
        .start()
    )
    try:
        # The new watermark only takes effect in the next micro-batch. With
        # spark.sql.streaming.noDataMicroBatches.enabled (the default), Spark runs
        # that batch without new input, and processAllAvailable() waits for it.
        query.processAllAvailable()
        rows = (
            spark.table("watermark_test")
            .select(F.unix_timestamp("window.start").alias("start"), "total")
            .collect()
        )
    finally:
        query.stop()

    # Only [0, 30), holding the events at t=10 and t=20, is final; [30, 60) stays open.
    assert len(rows) == 1, f"Expected one closed window, got: {rows}"
    assert rows[0].start - base == 0
    assert rows[0].total == pytest.approx(3.0)


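def test_late_data_dropped(spark, tmp_path):
    """Verify late events (beyond watermark) are excluded from aggregations."""
    import json
    from datetime import datetime, timezone

    base = int(datetime(2026, 5, 17, 10, 0, 0, tzinfo=timezone.utc).timestamp())
    input_dir = tmp_path / "input"
    input_dir.mkdir()

    def add_batch(name, events):
        # Stage outside input_dir, then rename, so the file source never reads a partial file
        staging = tmp_path / f"{name}.tmp"
        staging.write_text("\n".join(json.dumps(e) for e in events))
        staging.rename(input_dir / f"{name}.json")

    schema = StructType([
        StructField("event_id", StringType()),
        StructField("event_time", LongType()),  # epoch seconds in the JSON
        StructField("value", DoubleType()),
    ])

    query = (
        spark.readStream.schema(schema).json(str(input_dir))
        .withColumn("event_time", F.timestamp_seconds("event_time"))
        .withWatermark("event_time", "5 seconds")
        .groupBy(F.window("event_time", "10 seconds"))
        .agg(F.count("*").alias("event_count"))
        .writeStream
        .format("memory")
        .queryName("late_data_test")
        .outputMode("append")
        .option("checkpointLocation", str(tmp_path / "checkpoint"))
        .start()
    )
    try:
        # Advance watermark to t=45 first (event at t=50, delay 5 s)
        add_batch("early", [{"event_id": "e1", "event_time": base + 50, "value": 100.0}])
        query.processAllAvailable()

        # Late event at t=10 belongs to window [10, 20), which ended before the
        # watermark (t=45), so Spark drops it instead of aggregating it
        add_batch("late", [{"event_id": "e2", "event_time": base + 10, "value": 1.0}])
        query.processAllAvailable()

        # Both records were read (idle and no-data batches report 0 input rows)...
        assert sum(p["numInputRows"] for p in query.recentProgress) == 2
        # ...but nothing was emitted: [50, 60) is still open and [10, 20) was dropped
        assert spark.sql("select * from late_data_test").count() == 0
    finally:
        query.stop()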
Using StreamingQueryListener for Test Assertions

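StreamingQueryListener (usable from Python since Spark 3.4) hooks into query lifecycle events. Use it to assert on streaming metrics, keeping in mind that callbacks are delivered asynchronously on Spark's listener bus, so a test may need to wait briefly before the events arrive: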

# tests/helpers/test_listener.py
from pyspark.sql.streaming import StreamingQueryListener
import threading


class CapturingListener(StreamingQueryListener):
    """Captures streaming query progress for test assertions."""

    def __init__(self):
        self.events = []
        self._lock = threading.Lock()

    def onQueryStarted(self, event):
        with self._lock:
            self.events.append({"type": "started", "id": str(event.id)})

    def onQueryProgress(self, event):
        with self._lock:
            progress = event.progress
            self.events.append({
                "type": "progress",
                "numInputRows": progress.numInputRows,
                "processedRowsPerSecond": progress.processedRowsPerSecond,
                "watermark": progress.eventTime.get("watermark"),
                "batchId": progress.batchId,
            })

    def onQueryTerminated(self, event):
        with self._lock:
            self.events.append({
                "type": "terminated",
                "id": str(event.id),
                "exception": event.exception,
            })

    def get_progress_events(self):
        with self._lock:
            return [e for e in self.events if e["type"] == "progress"]


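def test_streaming_throughput(spark, tmp_path):
    import json
    import time

    listener = CapturingListener()
    spark.streams.addListener(listener)

    try:
        # A static DataFrame cannot be written with writeStream, so stage the
        # 1,000 input rows as one JSON-lines file and read it as a file stream
        input_dir = tmp_path / "input"
        input_dir.mkdir()
        (input_dir / "events.json").write_text(
            "\n".join(json.dumps({"value": "event"}) for _ in range(1000))
        )

        query = (
            spark.readStream
            .schema("value STRING")
            .json(str(input_dir))
            .writeStream
            .format("memory")
            .queryName("throughput_test")
            .outputMode("append")
            .option("checkpointLocation", str(tmp_path / "checkpoint"))
            .start()
        )
        query.processAllAvailable()
        query.stop()

        # Listener callbacks arrive asynchronously; poll briefly before asserting
        deadline = time.time() + 10
        while time.time() < deadline:
            if sum(e["numInputRows"] for e in listener.get_progress_events()) >= 1000:
                break
            time.sleep(0.1)

        progress_events = listener.get_progress_events()
        assert len(progress_events) > 0

        total_rows = sum(e["numInputRows"] for e in progress_events)
        assert total_rows == 1000, f"Expected 1000 rows processed, got {total_rows}"

    finally:
        spark.streams.removeListener(listener)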
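Beyond row totals, the progress dicts CapturingListener records can be summarized in plain Python. The hypothetical helper below, a sketch that assumes the dict keys used in onQueryProgress above, reports total input rows, the batch ids seen, and whether the watermark ever moved backwards. Spark formats watermarks as fixed-width UTC strings, so string order matches time order.

```python
def summarize_progress(progress_events):
    """Summarize progress dicts shaped like CapturingListener's "progress" entries."""
    total_rows = sum(e["numInputRows"] for e in progress_events)
    batch_ids = sorted({e["batchId"] for e in progress_events})
    # Keep only events that carry a watermark, in batch order
    ordered = sorted(progress_events, key=lambda e: e["batchId"])
    watermarks = [e["watermark"] for e in ordered if e.get("watermark")]
    # A watermark must never decrease from one batch to the next
    monotonic = all(a <= b for a, b in zip(watermarks, watermarks[1:]))
    return {"total_rows": total_rows, "batch_ids": batch_ids, "watermark_monotonic": monotonic}
```

A test can then assert summarize_progress(listener.get_progress_events())["watermark_monotonic"] alongside the row count.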

Integration Testing with Kafka Containers

For end-to-end pipeline tests, use Testcontainers to spin up real Kafka:

# tests/integration/test_kafka_pipeline.py
import json

import pytest
from kafka import KafkaProducer
from pyspark.sql import functions as F


@pytest.fixture(scope="module")
def kafka_container():
    """Start a real Kafka broker via Testcontainers."""
    try:
        from testcontainers.kafka import KafkaContainer
    except ImportError:
        pytest.skip("testcontainers not installed")

    with KafkaContainer("confluentinc/cp-kafka:7.5.0") as kafka:
        yield kafka


@pytest.fixture
def kafka_bootstrap(kafka_container):
    return kafka_container.get_bootstrap_server()


def produce_events(bootstrap_servers, topic, events):
    producer = KafkaProducer(
        bootstrap_servers=bootstrap_servers,
        value_serializer=lambda v: json.dumps(v).encode("utf-8"),
    )
    for event in events:
        producer.send(topic, value=event)
    producer.flush()
    producer.close()


def test_kafka_to_memory_pipeline(spark, kafka_bootstrap, tmp_path):
    """
    Full pipeline test: produce events to Kafka, consume with Spark Streaming,
    transform, and assert on memory sink output.
    """
    input_topic = "test-events-input"
    checkpoint_dir = str(tmp_path / "kafka-checkpoint")

    # Produce test data
    events = [
        {"event_id": f"e{i:04d}", "user_id": f"u{i % 5:03d}",
         "event_type": "click" if i % 2 == 0 else "purchase",
         "amount": float(i * 10),
         "event_time": 1716192000 + i}
        for i in range(100)
    ]
    produce_events(kafka_bootstrap, input_topic, events)

    # Spark reads from Kafka
    kafka_df = (
        spark.readStream
        .format("kafka")
        .option("kafka.bootstrap.servers", kafka_bootstrap)
        .option("subscribe", input_topic)
        .option("startingOffsets", "earliest")
        .load()
    )

    from pyspark.sql.types import StructType, StructField, StringType, LongType, DoubleType
    schema = StructType([
        StructField("event_id", StringType()),
        StructField("user_id", StringType()),
        StructField("event_type", StringType()),
        StructField("amount", DoubleType()),
        StructField("event_time", LongType()),
    ])

    parsed = kafka_df.select(
        F.from_json(F.col("value").cast("string"), schema).alias("data")
    ).select("data.*")

    query = (
        parsed.groupBy("event_type")
        .agg(F.count("*").alias("cnt"), F.sum("amount").alias("total"))
        .writeStream
        .format("memory")
        .queryName("kafka_agg_test")
        .outputMode("complete")
        .option("checkpointLocation", checkpoint_dir)
        .trigger(availableNow=True)  # process all available data, then stop (Spark 3.3+)
        .start()
    )

    assert query.awaitTermination(timeout=60), "query did not finish within 60 s"

    result = spark.sql("select * from kafka_agg_test")
    rows = {r.event_type: r for r in result.collect()}

    assert "click" in rows
    assert "purchase" in rows
    assert rows["click"].cnt == 50
    assert rows["purchase"].cnt == 50
    assert rows["click"].total == pytest.approx(sum(i * 10.0 for i in range(0, 100, 2)))

CI Configuration

# .github/workflows/spark-tests.yaml
name: Spark Streaming Tests

on: [push, pull_request]

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10", "3.11"]
        spark-version: ["3.4.0", "3.5.0"]

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Java
        uses: actions/setup-java@v4
        with:
          java-version: "11"
          distribution: "temurin"

      - name: Install dependencies
        run: |
          pip install pyspark==${{ matrix.spark-version }} pytest pytest-spark \
            pytest-timeout testcontainers kafka-python pyarrow

      - name: Run unit tests (no Kafka)
        run: pytest tests/test_transformations.py tests/test_streaming_logic.py -v

      - name: Run integration tests (with Kafka)
        run: pytest tests/integration/ -v --timeout=120
        env:
          PYSPARK_SUBMIT_ARGS: "--packages org.apache.spark:spark-sql-kafka-0-10_2.12:${{ matrix.spark-version }} pyspark-shell"

Best Practices

Separate transformation logic from streaming wiring. Functions that take DataFrames are easy to test. Functions that configure sources, sinks, and checkpoints are hard. Keep them in separate modules.

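Use a bounded trigger for integration tests. trigger(availableNow=True) (Spark 3.3+) tells Spark to process all data available at start and then stop, turning an infinite stream into a bounded job that exits cleanly. It supersedes trigger(once=True), which does the same in a single batch and is deprecated as of Spark 3.4.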

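Assert on lastProgress for operational metrics. Row counts (numInputRows), the watermark (eventTime.watermark), and batch durations (durationMs) are all in query.lastProgress. Checking them in tests catches watermark-lag regressions reliably; throughput numbers vary with CI hardware, so assert on them only with generous bounds.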

Set spark.sql.shuffle.partitions=2 in tests. The default of 200 partitions makes local tests slow. Two is enough for correctness testing.

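Clean up checkpoints between tests. Stale checkpoint data causes test interference, because a query restarted on an old checkpoint resumes from its saved offsets and state. Use the tmp_path fixture for checkpoint directories: each test gets a fresh, unique directory, and pytest prunes old base directories automatically, keeping only the most recent few runs.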

