Testing Apache Spark Applications with ScalaTest
Apache Spark applications are notoriously hard to test. The framework was designed for distributed execution at scale, but your tests need to run fast and locally. The good news: Spark runs in local mode, and with the right setup, you can have a test suite that gives real confidence without needing a cluster.
The Core Challenge
Spark tests are slow compared to typical unit tests, mostly because a SparkSession takes several seconds to start. The fix is to create one SparkSession per suite and reuse it in every test.
Setting Up a Shared SparkSession
The standard pattern is a SharedSparkSession trait that creates the session lazily, once per test suite (getOrCreate() reuses a session that already exists in the JVM):
import org.apache.spark.sql.SparkSession

trait SharedSparkSession {
  lazy val spark: SparkSession = SparkSession.builder()
    .appName("test")
    .master("local[*]")
    .config("spark.sql.shuffle.partitions", "2") // Critical for test speed
    .config("spark.default.parallelism", "2")
    .getOrCreate()

  lazy val sc = spark.sparkContext
}

Key configuration: spark.sql.shuffle.partitions = 2. The default is 200, which means every shuffle operation creates 200 tasks. In tests with small data, this makes things 10-50x slower than necessary.
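A quick way to confirm the setting took effect is to read it back from the running session. The sketch below is a standalone check, not part of the trait above; the object name SessionConfigCheck and the app name are made up for illustration, and it assumes Spark SQL is on the classpath.

```scala
import org.apache.spark.sql.SparkSession

object SessionConfigCheck {
  // Builds a local session configured like the shared test session and
  // returns the shuffle-partition setting that the session reports.
  def configuredShufflePartitions(): String = {
    val spark = SparkSession.builder()
      .appName("config-check") // hypothetical name, for illustration only
      .master("local[2]")
      .config("spark.sql.shuffle.partitions", "2")
      .getOrCreate()
    // Stop the session afterwards; only do this in a standalone check,
    // never inside a suite that shares the session with other tests.
    try spark.conf.get("spark.sql.shuffle.partitions")
    finally spark.stop()
  }

  def main(args: Array[String]): Unit =
    println(configuredShufflePartitions())
}
```

Inside a suite that mixes in SharedSparkSession, the same check would probably be a one-line assertion on spark.conf.get("spark.sql.shuffle.partitions"), without stopping the session.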
Writing DataFrame Tests
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class SalesTransformerSpec extends AnyFlatSpec
    with Matchers
    with SharedSparkSession {

  import spark.implicits._

  "SalesTransformer" should "compute revenue per region" in {
    val input = Seq(
      ("North", "Widget", 100.0, 5),
      ("South", "Gadget", 50.0, 10),
      ("North", "Gadget", 50.0, 3)
    ).toDF("region", "product", "price", "quantity")

    val result = SalesTransformer.revenueByRegion(input)
    result.count() should be(2)

    val northRevenue = result
      .filter($"region" === "North")
      .select($"total_revenue")
      .as[Double]
      .collect()
      .head
    northRevenue should be(650.0) // 100*5 + 50*3
  }

  it should "filter out cancelled orders" in {
    val input = Seq(
      ("ORD-001", "completed", 99.0),
      ("ORD-002", "cancelled", 49.0),
      ("ORD-003", "completed", 149.0)
    ).toDF("order_id", "status", "amount")

    val result = SalesTransformer.activeOrders(input)
    result.count() should be(2)
    result.filter($"status" === "cancelled").count() should be(0)
  }
}

Comparing DataFrames
Comparing DataFrames isn't straightforward — row order may differ, and schema types matter. A helper function saves significant boilerplate:
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

// Define this inside a suite (or trait) that mixes in Matchers, which provides `should`
def assertDataFrameEquals(
    actual: DataFrame,
    expected: DataFrame,
    checkRowOrder: Boolean = false
): Unit = {
  // Schema check: column names, types, and nullability must all match
  actual.schema should equal(expected.schema)

  if (checkRowOrder) {
    val actualRows = actual.collect()
    val expectedRows = expected.collect()
    actualRows should equal(expectedRows)
  } else {
    // Order-independent comparison: sort both sides by every column first
    val actualSorted = actual.orderBy(actual.columns.map(col): _*).collect()
    val expectedSorted = expected.orderBy(expected.columns.map(col): _*).collect()
    actualSorted should equal(expectedSorted)
  }
}

Or use the spark-testing-base library, which provides DataFrameSuiteBase with built-in assertion methods.
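One caveat with the hand-written helper: it compares values exactly, so a Double column produced by an aggregation can fail on rounding noise alone. The sketch below shows one way to compare rows with a tolerance. It models a row as a plain Seq[Any] so it runs without Spark; the names ApproxRows, valuesMatch, and rowsMatch are illustrative, not from any library. In a real test, Row.toSeq would supply these sequences.

```scala
object ApproxRows {
  // Doubles are compared within a tolerance; every other value must be equal.
  def valuesMatch(a: Any, b: Any, tol: Double): Boolean = (a, b) match {
    case (x: Double, y: Double) => math.abs(x - y) <= tol
    case _                      => a == b
  }

  // Two rows match when they have the same length and every cell matches.
  def rowsMatch(actual: Seq[Any], expected: Seq[Any], tol: Double): Boolean =
    actual.length == expected.length &&
      actual.zip(expected).forall { case (a, b) => valuesMatch(a, b, tol) }

  def main(args: Array[String]): Unit = {
    val actual   = Seq[Any]("North", 0.1 + 0.2)
    val expected = Seq[Any]("North", 0.3)

    println(actual == expected)                // false: 0.1 + 0.2 is not exactly 0.3
    println(rowsMatch(actual, expected, 1e-9)) // true: equal within the tolerance
  }
}
```

The tolerance is a judgment call per column; 1e-9 here is only a placeholder.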
Using spark-testing-base
spark-testing-base is a widely used library of Spark testing utilities:

// build.sbt
"com.holdenkarau" %% "spark-testing-base" % "3.3.1_1.4.5" % Test

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class ETLPipelineSpec extends AnyFlatSpec
    with Matchers
    with DataFrameSuiteBase {

  // AnyFlatSpec uses "subject" should "..." in { ... }; test("...") belongs to AnyFunSuite
  "ETLPipeline" should "transform raw events to sessions" in {
    val input = sqlContext.createDataFrame(Seq(
      ("user-1", 1704067200L, "login"),
      ("user-1", 1704067260L, "purchase"),
      ("user-2", 1704067200L, "login")
    )).toDF("user_id", "timestamp", "event")

    val expected = sqlContext.createDataFrame(Seq(
      ("user-1", 2),
      ("user-2", 1)
    )).toDF("user_id", "event_count")

    val result = ETLPipeline.buildSessions(input)

    // Compares schema and rows; sort both sides first if row order can vary
    assertDataFrameEquals(expected, result)
  }
}

Testing RDDs
For lower-level RDD operations:
class WordCountSpec extends AnyFlatSpec with Matchers with SharedSparkSession {

  "WordCount" should "count word frequencies" in {
    val lines = sc.parallelize(Seq(
      "hello world",
      "hello scala",
      "world of spark"
    ))

    val counts = WordCount.count(lines).collect().toMap
    counts("hello") should be(2)
    counts("world") should be(2)
    counts("scala") should be(1)
  }

  it should "handle empty input" in {
    val empty = sc.parallelize(Seq.empty[String])
    WordCount.count(empty).count() should be(0)
  }
}

Testing Spark SQL
For applications using Spark SQL directly:
import org.scalatest.BeforeAndAfterEach

class SalesQuerySpec extends AnyFlatSpec
    with Matchers
    with BeforeAndAfterEach // provides the beforeEach hook overridden below
    with SharedSparkSession {

  import spark.implicits._ // needed for toDF

  override def beforeEach(): Unit = {
    super.beforeEach()
    // Register test tables
    val salesData = Seq(
      (1, "Electronics", 999.99),
      (2, "Books", 29.99),
      (3, "Electronics", 599.99)
    ).toDF("id", "category", "price")
    salesData.createOrReplaceTempView("sales")
  }

  "SalesQuery" should "find top categories by revenue" in {
    val result = spark.sql("""
      SELECT category, SUM(price) as revenue
      FROM sales
      GROUP BY category
      ORDER BY revenue DESC
    """)

    val topCategory = result.first().getString(0)
    topCategory should be("Electronics")
  }
}

Testing with Delta Lake or Iceberg
If your pipeline writes to Delta or Iceberg tables, use temp directories:
import java.nio.file.Files

class DeltaWriterSpec extends AnyFlatSpec with Matchers with SharedSparkSession {

  import spark.implicits._ // needed for toDF and the $"..." column syntax

  "DeltaWriter" should "write and read partitioned data" in {
    val tempDir = Files.createTempDirectory("delta-test").toString
    val data = Seq(
      ("2024-01-01", "product-1", 100.0),
      ("2024-01-02", "product-1", 200.0)
    ).toDF("date", "product_id", "amount")

    DeltaWriter.write(data, tempDir, partitionBy = "date")

    val read = spark.read.format("delta").load(tempDir)
    read.count() should be(2)

    val jan2 = read.filter($"date" === "2024-01-02")
    // Read the value by name so the test does not depend on column position
    jan2.first().getAs[Double]("amount") should be(200.0)
  }
}

Speeding Up Tests
Slow Spark tests are usually fixable:
Share the SparkSession. Starting a new SparkSession per test takes 3-5 seconds. One session shared across all tests is critical.
Set shuffle partitions to 2. spark.sql.shuffle.partitions = 2 is the single biggest speedup for small test data.
Use local mode. master("local[*]") runs without network overhead.
Avoid writing to disk when possible. In-memory DataFrames are faster than writing to Parquet and reading back.
Run test suites in parallel. Different test files can run simultaneously, but tests within a file typically share the session and run sequentially.
// sbt config for parallel test suite execution
Test / parallelExecution := true

Testing UDFs
A UDF wraps an ordinary Scala function, so test the function directly first, then test it as a UDF:

import org.apache.spark.sql.functions.udf

// The function under test
def cleanPhone(phone: String): String =
  phone.replaceAll("[^0-9]", "")

// Pure function tests
"cleanPhone" should "remove non-digits" in {
  cleanPhone("(415) 555-0100") should be("4155550100")
  cleanPhone("+1-800-FLOWERS") should be("1800")
}

// Register and test as UDF
"cleanPhone UDF" should "work in DataFrame operations" in {
  val cleanPhoneUDF = udf(cleanPhone _)
  spark.udf.register("clean_phone", cleanPhoneUDF)

  val df = Seq(("Alice", "(415) 555-0100")).toDF("name", "phone")
  val result = df.withColumn("clean", cleanPhoneUDF($"phone"))
  result.first().getString(2) should be("4155550100")
}

Continuous Monitoring
Spark job testing in CI catches logic errors. For monitoring Spark-powered applications in production — checking that batch jobs complete, output data is fresh, and downstream APIs respond correctly — HelpMeTest provides health monitoring with configurable grace periods. If a Spark job that normally finishes in 2 hours hasn't heartbeated in 3, you get an alert.
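The grace-period idea itself is simple to express. The sketch below is only an illustration of that logic in plain Scala; it is not HelpMeTest's API, and the names HeartbeatCheck and isOverdue are hypothetical.

```scala
import java.time.{Duration, Instant}

object HeartbeatCheck {
  // True when more time has passed since the last heartbeat than the grace period allows.
  def isOverdue(lastHeartbeat: Instant, now: Instant, gracePeriod: Duration): Boolean =
    Duration.between(lastHeartbeat, now).compareTo(gracePeriod) > 0

  def main(args: Array[String]): Unit = {
    val now   = Instant.parse("2024-01-01T12:00:00Z")
    val grace = Duration.ofHours(3)

    // Last heartbeat 2 hours ago: still within the 3-hour grace period
    println(isOverdue(now.minus(Duration.ofHours(2)), now, grace)) // false
    // Last heartbeat 4 hours ago: past the grace period, so an alert is due
    println(isOverdue(now.minus(Duration.ofHours(4)), now, grace)) // true
  }
}
```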
Summary
Testing Spark applications is achievable with the right patterns:
- Share SparkSession: start once, share across all tests in a suite
- shuffle.partitions = 2: single biggest speedup for test data
- local[*] mode: no cluster, no network, fast execution
- spark-testing-base: DataFrameSuiteBase supplies assertDataFrameEquals and related assertions for comparing DataFrames
- Test UDFs as pure functions first, then as registered UDFs
- Temp directories for Delta/Iceberg write-read tests
With these practices, a comprehensive Spark test suite can run in under 2 minutes on a laptop — fast enough to run on every commit.