Testing Apache Spark Applications with ScalaTest


Apache Spark applications are notoriously hard to test. The framework was designed for distributed execution at scale, but your tests need to run fast and locally. The good news: Spark runs in local mode, and with the right setup, you can have a test suite that gives real confidence without needing a cluster.

The Core Challenge

Spark tests are slow compared to typical unit tests: a SparkSession takes several seconds to start. The solution is to start one SparkSession and share it across every test in the suite.

Setting Up a Shared SparkSession

The standard pattern is a small SharedSparkSession trait, defined in your own test code, whose lazily built session is created once and reused by every test that mixes it in:

import org.apache.spark.sql.SparkSession

trait SharedSparkSession {
  lazy val spark: SparkSession = SparkSession.builder()
    .appName("test")
    .master("local[*]")
    .config("spark.sql.shuffle.partitions", "2")  // Critical for test speed
    .config("spark.default.parallelism", "2")
    .getOrCreate()

  lazy val sc = spark.sparkContext
}

Key configuration: spark.sql.shuffle.partitions = 2. The default is 200, so every DataFrame or SQL shuffle (joins, aggregations) is split into 200 tasks. With small test data nearly all of those tasks process no rows, and the scheduling overhead can make tests 10-50x slower than necessary.
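To see the setting at work, here is a small sketch that counts the partitions (and therefore tasks) produced by a single groupBy. The ShufflePartitionsDemo object is hypothetical, not part of Spark. It turns off adaptive query execution, which is enabled by default since Spark 3.2 and may coalesce small shuffles, so that the raw setting shows through.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Hypothetical demo: how many post-shuffle partitions does one groupBy produce?
object ShufflePartitionsDemo {
  def shufflePartitions(partitions: Int): Int = {
    val spark = SparkSession.builder()
      .appName("shuffle-partitions-demo")
      .master("local[2]")
      // AQE can coalesce tiny shuffles; disable it so the configured value is visible
      .config("spark.sql.adaptive.enabled", "false")
      .getOrCreate()

    // A runtime SQL setting: it applies to queries planned after this call
    spark.conf.set("spark.sql.shuffle.partitions", partitions.toString)

    // 30 ids in 3 input partitions, grouped by id % 3, forces a hash shuffle
    val counts = spark.range(0, 30, 1, 3).groupBy((col("id") % 3).as("k")).count()

    // Each post-shuffle partition is scheduled as one task
    counts.rdd.getNumPartitions
  }
}
```

With the default of 200, the three groups above land in 200 post-shuffle partitions, 197 of which are empty but still scheduled.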

Writing DataFrame Tests

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class SalesTransformerSpec extends AnyFlatSpec
    with Matchers
    with SharedSparkSession {

  import spark.implicits._

  "SalesTransformer" should "compute revenue per region" in {
    val input = Seq(
      ("North", "Widget", 100.0, 5),
      ("South", "Gadget", 50.0, 10),
      ("North", "Gadget", 50.0, 3)
    ).toDF("region", "product", "price", "quantity")

    val result = SalesTransformer.revenueByRegion(input)

    result.count() should be(2)

    val northRevenue = result
      .filter($"region" === "North")
      .select($"total_revenue")
      .as[Double]
      .collect()
      .head

    northRevenue should be(650.0)  // 100*5 + 50*3
  }

  it should "filter out cancelled orders" in {
    val input = Seq(
      ("ORD-001", "completed", 99.0),
      ("ORD-002", "cancelled", 49.0),
      ("ORD-003", "completed", 149.0)
    ).toDF("order_id", "status", "amount")

    val result = SalesTransformer.activeOrders(input)

    result.count() should be(2)
    result.filter($"status" === "cancelled").count() should be(0)
  }
}
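The spec never shows SalesTransformer itself. A minimal sketch of code that would satisfy it, assuming revenue is price times quantity and that every status other than "cancelled" counts as active; the object and both methods are hypothetical.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, sum}

// Hypothetical code under test for SalesTransformerSpec
object SalesTransformer {
  // One row per region: total_revenue = sum(price * quantity)
  def revenueByRegion(sales: DataFrame): DataFrame =
    sales
      .withColumn("revenue", col("price") * col("quantity"))
      .groupBy("region")
      .agg(sum("revenue").as("total_revenue"))

  // Keep orders whose status is not "cancelled".
  // A null status is dropped as well, since null =!= "cancelled" evaluates to null.
  def activeOrders(orders: DataFrame): DataFrame =
    orders.filter(col("status") =!= "cancelled")
}
```

Keeping transformations as plain DataFrame => DataFrame functions is what makes them easy to feed small, hand-built inputs in tests.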

Comparing DataFrames

Comparing DataFrames isn't straightforward: row order may differ between runs, and column types and nullability must match. A helper function saves significant boilerplate:

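import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.scalatest.matchers.should.Matchers

// A standalone helper object; extending Matchers brings in the `should` syntax
object DataFrameAssertions extends Matchers {
  def assertDataFrameEquals(
    actual: DataFrame,
    expected: DataFrame,
    checkRowOrder: Boolean = false
  ): Unit = {
    // Schema check: column names, types and nullability must all match
    actual.schema should equal(expected.schema)

    if (checkRowOrder) {
      val actualRows = actual.collect()
      val expectedRows = expected.collect()
      actualRows should equal(expectedRows)
    } else {
      // Order-independent comparison: sort both sides on every column, then compare
      val actualSorted = actual.orderBy(actual.columns.map(col): _*).collect()
      val expectedSorted = expected.orderBy(expected.columns.map(col): _*).collect()
      actualSorted should equal(expectedSorted)
    }
  }
}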
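The sort-based helper collects both DataFrames to the driver. An alternative sketch keeps the comparison inside Spark by treating the rows as multisets with Dataset.exceptAll (available since Spark 2.4): if neither side has rows the other lacks, counting duplicates, the contents match regardless of order. The DataFrameDiff name is hypothetical.

```scala
import org.apache.spark.sql.DataFrame

// Hypothetical helper: order-independent equality that still respects duplicate rows
object DataFrameDiff {
  def sameRows(actual: DataFrame, expected: DataFrame): Boolean =
    actual.schema == expected.schema &&     // names, types and nullability
      actual.exceptAll(expected).isEmpty && // nothing extra in actual
      expected.exceptAll(actual).isEmpty    // nothing missing from actual
}
```

When the check fails, calling actual.exceptAll(expected).show() lists exactly the rows that differ, which is often more useful than a full dump of both sides. Like any set operation in Spark, exceptAll rejects map-typed columns.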

Alternatively, use the spark-testing-base library, which provides DataFrameSuiteBase with built-in assertion methods.

Using spark-testing-base

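spark-testing-base is a widely used library of Spark testing utilities. Its version string combines the Spark version it targets (before the underscore) with the library's own version: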

// build.sbt
libraryDependencies += "com.holdenkarau" %% "spark-testing-base" % "3.3.1_1.4.5" % Test

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.scalatest.funsuite.AnyFunSuite

// DataFrameSuiteBase pairs with FunSuite-style test("...") blocks
class ETLPipelineSpec extends AnyFunSuite with DataFrameSuiteBase {

  test("transforms raw events to sessions") {
    val input = sqlContext.createDataFrame(Seq(
      ("user-1", 1704067200L, "login"),
      ("user-1", 1704067260L, "purchase"),
      ("user-2", 1704067200L, "login")
    )).toDF("user_id", "timestamp", "event")

    // Assuming buildSessions counts events with count(), which yields a LongType column
    val expected = sqlContext.createDataFrame(Seq(
      ("user-1", 2L),
      ("user-2", 1L)
    )).toDF("user_id", "event_count")

    val result = ETLPipeline.buildSessions(input)

    // assertDataFrameEquals checks the schema, then compares rows in order, so sort first
    assertDataFrameEquals(expected, result.orderBy("user_id"))
  }
}

Testing RDDs

For lower-level RDD operations:

class WordCountSpec extends AnyFlatSpec with Matchers with SharedSparkSession {

  "WordCount" should "count word frequencies" in {
    val lines = sc.parallelize(Seq(
      "hello world",
      "hello scala",
      "world of spark"
    ))

    val counts = WordCount.count(lines).collect().toMap

    counts("hello") should be(2)
    counts("world") should be(2)
    counts("scala") should be(1)
  }

  it should "handle empty input" in {
    val empty = sc.parallelize(Seq.empty[String])
    WordCount.count(empty).count() should be(0)
  }
}
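For reference, a minimal sketch of the WordCount.count function these tests assume. The object and its splitting rule (whitespace-separated, case-sensitive words) are hypothetical.

```scala
import org.apache.spark.rdd.RDD

// Hypothetical code under test for WordCountSpec
object WordCount {
  def count(lines: RDD[String]): RDD[(String, Int)] =
    lines
      .flatMap(_.split("\\s+"))   // split each line on runs of whitespace
      .filter(_.nonEmpty)          // a leading space would otherwise produce ""
      .map(word => (word, 1))
      .reduceByKey(_ + _)          // sum the 1s per word; this step shuffles
}
```

Because the function takes and returns an RDD, the spec can drive it with sc.parallelize and inspect the result with collect, exactly as above.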

Testing Spark SQL

For applications using Spark SQL directly:

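import org.scalatest.BeforeAndAfterEach

class SalesQuerySpec extends AnyFlatSpec with Matchers with SharedSparkSession
    with BeforeAndAfterEach {

  import spark.implicits._

  // beforeEach needs BeforeAndAfterEach; the view is re-registered before every test
  override def beforeEach(): Unit = {
    super.beforeEach()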
    // Register test tables
    val salesData = Seq(
      (1, "Electronics", 999.99),
      (2, "Books", 29.99),
      (3, "Electronics", 599.99)
    ).toDF("id", "category", "price")

    salesData.createOrReplaceTempView("sales")
  }

  "SalesQuery" should "find top categories by revenue" in {
    val result = spark.sql("""
      SELECT category, SUM(price) as revenue
      FROM sales
      GROUP BY category
      ORDER BY revenue DESC
    """)

    val topCategory = result.first().getString(0)
    topCategory should be("Electronics")
  }
}

Testing with Delta Lake or Iceberg

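If your pipeline writes to Delta or Iceberg tables, point reads and writes at temporary directories. The table format's runtime must be on the test classpath (for Delta, the delta-spark package, or delta-core in older releases), and Delta's documentation also configures the session with spark.sql.extensions = io.delta.sql.DeltaSparkSessionExtension and spark.sql.catalog.spark_catalog = org.apache.spark.sql.delta.catalog.DeltaCatalog. With that in place, the test itself looks like this: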

import java.nio.file.Files

class DeltaWriterSpec extends AnyFlatSpec with Matchers with SharedSparkSession {

  import spark.implicits._

  "DeltaWriter" should "write and read partitioned data" in {
    val tempDir = Files.createTempDirectory("delta-test").toString

    val data = Seq(
      ("2024-01-01", "product-1", 100.0),
      ("2024-01-02", "product-1", 200.0)
    ).toDF("date", "product_id", "amount")

    DeltaWriter.write(data, tempDir, partitionBy = "date")

    val read = spark.read.format("delta").load(tempDir)
    read.count() should be(2)

    val jan2 = read.filter($"date" === "2024-01-02")
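    // Look the value up by name: plain Parquet reads, for example, move partition columns to the end
    jan2.first().getAs[Double]("amount") should be(200.0)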
  }
}
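The test above never deletes its temporary directory. A minimal loan-pattern sketch, using only JDK NIO, creates the directory, hands it to the test body, and removes it recursively afterwards, even if the body throws. The TempDirs object and withTempDir helper are hypothetical names.

```scala
import java.nio.file.{Files, Path}
import scala.collection.mutable.ArrayBuffer

// Hypothetical test helper: a temp directory that lives only as long as `body`
object TempDirs {
  def withTempDir[A](prefix: String)(body: Path => A): A = {
    val dir = Files.createTempDirectory(prefix)
    try body(dir)
    finally deleteRecursively(dir)
  }

  private def deleteRecursively(root: Path): Unit = {
    // Collect every path under root (root included), then close the stream
    val paths = ArrayBuffer.empty[Path]
    val walk = Files.walk(root)
    try {
      val it = walk.iterator()
      while (it.hasNext) paths += it.next()
    } finally walk.close()

    // Deepest paths first, so files and subdirectories go before their parents
    paths.sortBy(p => -p.getNameCount).foreach(p => Files.deleteIfExists(p))
  }
}
```

In the spec, the body becomes withTempDir("delta-test") { dir => ... DeltaWriter.write(data, dir.toString, partitionBy = "date") ... }, and the Delta log and partition folders are gone once the test finishes.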

Speeding Up Tests

Slow Spark tests are usually fixable:

Share the SparkSession. Starting a new SparkSession per test takes 3-5 seconds. One session shared across all tests is critical.

Set shuffle partitions to 2. spark.sql.shuffle.partitions = 2 is the single biggest speedup for small test data.

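Use local mode. master("local[*]") runs the driver and executors in a single JVM with one worker thread per core, so there is no cluster to start and no network hop between them.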

Avoid writing to disk when possible. In-memory DataFrames are faster than writing to Parquet and reading back.

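Run test suites in parallel. sbt can run different suites concurrently; when they share a JVM, SharedSparkSession's getOrCreate hands them all the same session instead of starting one per suite. Tests within a single suite still run sequentially by default.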

// build.sbt: run test suites in parallel (sbt's default, shown explicitly)
Test / parallelExecution := true

Testing UDFs

Most user-defined functions wrap a pure Scala function, so test them as plain Scala functions first, then as UDFs:

import org.apache.spark.sql.functions.udf

// Test the function directly
def cleanPhone(phone: String): String =
  phone.replaceAll("[^0-9]", "")

// Pure function tests
"cleanPhone" should "remove non-digits" in {
  cleanPhone("(415) 555-0100") should be("4155550100")
  cleanPhone("+1-800-FLOWERS") should be("1800")
}

// Register and test as UDF (inside a suite that mixes in SharedSparkSession
// and imports spark.implicits._)
"cleanPhone UDF" should "work in DataFrame operations" in {
  val cleanPhoneUDF = udf(cleanPhone _)
  spark.udf.register("clean_phone", cleanPhoneUDF)

  val df = Seq(("Alice", "(415) 555-0100")).toDF("name", "phone")
  val result = df.withColumn("clean", cleanPhoneUDF($"phone"))
  result.first().getString(2) should be("4155550100")

  // The registered name is what SQL queries call
  spark.sql("SELECT clean_phone('(415) 555-0100')").first().getString(0) should be("4155550100")
}
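One case worth adding to the UDF tests: for a String parameter, Spark passes a null column value straight into the Scala function, so cleanPhone as written throws a NullPointerException on a null phone. A null-safe sketch, under the hypothetical name PhoneCleaning, maps null to null and keeps the pure-function tests unchanged:

```scala
// Hypothetical null-safe variant of cleanPhone: a null cell yields null instead of an NPE
object PhoneCleaning {
  def cleanPhone(phone: String): String =
    Option(phone).map(_.replaceAll("[^0-9]", "")).orNull
}
```

In a DataFrame test, Seq(("Bob", null: String)).toDF("name", "phone") is one way to produce the null row that exercises this path.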

Continuous Monitoring

Spark job testing in CI catches logic errors. For monitoring Spark-powered applications in production — checking that batch jobs complete, output data is fresh, and downstream APIs respond correctly — HelpMeTest provides health monitoring with configurable grace periods. If a Spark job that normally finishes in 2 hours hasn't heartbeated in 3, you get an alert.

Summary

Testing Spark applications is achievable with the right patterns:

  • Share SparkSession: start once, share across all tests in a suite
  • shuffle.partitions = 2: single biggest speedup for test data
  • local[*] mode: no cluster, no network, fast execution
  • spark-testing-base: DataFrameSuiteBase supplies DataFrame assertions; sort before assertDataFrameEquals, which compares rows in order
  • Test UDFs as pure functions first, then as registered UDFs
  • Temp directories for Delta/Iceberg write-read tests

With these practices, a comprehensive Spark test suite can run in under 2 minutes on a laptop — fast enough to run on every commit.
