Creating Component Tests for Spark Applications

One of the main engineering challenges faced by the Empathy.co Data Team is creating robust tests for our Spark applications. Since these applications, like any others, are constantly evolving, we needed a way to ensure that changes would not break the code: a guarantee that the output of our jobs would remain the same when the code is refactored or when the input data schema changes.

The biggest hurdle here was determining how to create component tests that check two key boxes:

  1. The aggregated results are as expected.
  2. The output columns of our DataFrame are the same and the schema remains unbroken.

Our Spark project architecture is based on the Spring Boot framework. We use it to handle the arguments passed to our applications via environment variables or command-line arguments. For that reason, all of our jobs follow the same structure, using configuration beans to read the configuration properties. Every job implements the run method of the Spring Boot ApplicationRunner interface. So, the final objective is to test the run method of our applications to ensure the aggregation results are returned as expected. Individual methods can also be tested with unit tests, but that is out of the scope of this blog post and won't be covered here.

Let’s imagine that the structure of one of the Spark Jobs we want to test is the following:

@SpringBootApplication(
 exclude = Array(classOf[MongoAutoConfiguration]),
 scanBasePackages = Array[String]("my.configuration.package")
)

@ConfigurationProperties("job-name")
class MyBatch @Autowired() (
   @BeanProperty var sparkConfig: SparkConfig,
   @BeanProperty var inputConfig: InputConfig
) extends Serializable
 with ApplicationRunner {

 @BeanProperty
 var outputPath: String = _

 override def run(args: ApplicationArguments): Unit = {

   val inputDf = readParquet(
     sparkConfig.sparkSession,
     inputConfig.path,
     inputConfig.getStartDate,
     inputConfig.getEndDate
   )

   val parquetDF = letTheJobPerformItsJob() // [...]

   if (!Option(outputPath).forall(_.isEmpty)) {
     parquetDF.save(
       fullOutputPath,
       partitionKeys = Seq(RawSchemaConstants.Yyyy, RawSchemaConstants.Mm)
     )
   }
 }
}
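
For context, a job like this is launched through a standard Spring Boot entry point. The snippet below is only an illustrative sketch (it is not taken from our codebase); it shows how Spring Boot ends up wiring the configuration beans and invoking the run method:

import org.springframework.boot.SpringApplication

// Illustrative sketch: Spring Boot builds the application context, injects the
// configuration beans into MyBatch, and then calls its run method with the
// parsed application arguments.
object MyBatchApplication {
 def main(args: Array[String]): Unit = {
   SpringApplication.run(classOf[MyBatch], args: _*)
 }
}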

Next, we will be using the Mockito framework to write the test code. Let’s start the test by creating the Spark session:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit
import org.mockito.ArgumentCaptor
import org.mockito.ArgumentMatchers.any
import org.mockito.ArgumentMatchers.anyString
import org.mockito.ArgumentMatchers.isNull
import org.mockito.MockitoSugar
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.springframework.boot.DefaultApplicationArguments

class MyBatchTest extends AnyFlatSpec with Matchers with MockitoSugar {

 val spark = SparkSession.builder
   .master("local[*]")
   .getOrCreate()

 import spark.implicits._

}

All of our Spark batch jobs follow the same working order: first, the source DataFrame is read from a bucket; then, we perform some aggregations; and finally, the results are saved to a different bucket, to a MongoDB collection, or both.

The next step is to define the input DataFrame that would be read from the source bucket during a real execution. It needs to contain the data required to test the different scenarios that could arise (such as corner cases); its content depends on the logic of each test.

To make the code more readable, we first define some aliases (SV, SFV, Sinks) for objects that group the constant values we use repeatedly (a sketch of these objects follows the snippet below). Also, the code shown here is a reduced version of the source DataFrame; for readability, not all of its rows are included. We could define several DataFrames inside different tests to check individual scenarios, and a larger one to test the overall aggregation with all the scenarios together, but all of them follow this same procedure:

val testData = DataFrameGenerator.generateSampleDataFrame(
 SV.Year,
 SV.Month,
 SV.Day,
 Seq(
   (SV.St1, Sinks.Q, SFV.LangEngScopeDesktop),
   (SV.St1, Sinks.C, SFV.LangEngScopeDesktop),
   (SV.St1, Sinks.C, SFV.LangEngScopeDesktop),
   [...]
 ).toDF(
   SourceColumnNames.Instance,
   SourceColumnNames.Sink,
   SourceColumnNames.Filters
 )
)
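
The grouping objects themselves are project-specific and not shown in this post; purely as a hypothetical illustration, they could look something like this (the names come from the snippet above, but the values below are made up):

// Hypothetical sketch only: the real constants used by our jobs are different.
object SampleValues {         // aliased as SV
  val Year = 2022
  val Month = 1
  val Day = 1
  val St1 = "instance1"
}

object SampleSinks {          // aliased as Sinks
  val Q = "query"
  val C = "click"
}

object SampleFiltersValues {  // aliased as SFV
  val LangEngScopeDesktop = "lang:en;scope:desktop"
  val EmptyFilterStr = ""
  val LangEngStr = "lang:en"
}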

We also define an empty map for the output that would be written to MongoDB, so we can mock that method; we are not interested in writing any results to MongoDB during the test.

val emptyMongoInstanceDFs = Map[WriteConfig, DataFrame]()

After that, we need to create a mock object for the methods that read the source data and write the results, which in our case live in an implicit class called DataFrameMethods that enriches the Spark DataFrame class. We also define a Mockito ArgumentCaptor object that will allow us to capture the result of the Spark job, which is passed to one of the methods that performs the write operation. This code goes inside the test body:

"My Batch" should "work" in {

 val dfMethodsMock = mock[DataFrameMethods]
 val dataFrameCaptor = ArgumentCaptor.forClass(classOf[DataFrame])

}
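
The implicit class itself is not the focus of this post, but to give an idea of what is being mocked, a simplified sketch (with the save signature inferred from the call in the job above) could look like this:

// Simplified, illustrative sketch of the implicit class that enriches the Spark
// DataFrame with a save method; the real implementation contains more logic.
object DataFrame {
 implicit class DataFrameMethods(df: org.apache.spark.sql.DataFrame) {
   def save(path: String, partitionKeys: Seq[String]): Unit =
     df.write
       .partitionBy(partitionKeys: _*)
       .parquet(path)
 }
}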

Now, we instantiate the job we want to test and create a spy object to mock some methods:

val myBatch = new MyBatch(
 SparkTestConfig.getSparkConfig,
 SparkTestConfig.getInputConfig
)
val spyBatch = spy(myBatch)

The input beans that the batch receives as arguments (Spark config, input config, etc.) can be created as shown below. Our batch receives more input beans, such as the output parameters, but for the sake of simplicity they are omitted from this snippet:

object SparkTestConfig {

 def getSparkConfig: SparkConfig = {
   val sparkConfig = new SparkConfig()
   sparkConfig.useLocalMaster = true
   sparkConfig.construct()
   sparkConfig
 }

 def getInputConfig: InputConfig = {
   val inputConfig = new InputConfig()
   inputConfig.path = "inputPath"
   inputConfig.timeZone = "UTC"
   inputConfig.date = "2022-01-01"
   inputConfig
 }
}

The next step is to mock the methods that read the source data and write the results. The read method should return the input DataFrame that we previously created, while the parquet write method is set to do nothing. The method that generates the documents written to MongoDB is also mocked so that nothing is written to any database; we only want to capture the argument with the resulting DataFrame when the method is called.

doReturn(testData)
 .when(spyBatch)
 .readParquet(
   any[SparkSession],
   ArgumentMatchers.eq("inputPath"),
   ArgumentMatchers.eq(ZonedDateTime.of(LocalDate.parse("2022-01-01"), LocalTime.MIN, ZoneId.of("UTC"))),
   ArgumentMatchers.eq(ZonedDateTime.of(LocalDate.parse("2022-01-02"), LocalTime.MIN, ZoneId.of("UTC")))
 )

doNothing
 .when(dfMethodsMock)
 .save(ArgumentMatchers.eq("outputPath"), ArgumentMatchers.eq(Seq("yyyy", "mm")))

doReturn(emptyMongoInstanceDFs)
 .when(spyBatch)
 .generateInstanceDFs(any[DataFrame], isNull[String], any[DatabaseConfig])

Here comes the most interesting part of the test: the batch code is executed by calling the run method, and the resulting DataFrame is captured with the argument captor from the call to the method that generates the documents written to MongoDB. We mocked this method so it returns an empty collection, but we still need to capture its input argument to get the Spark job results. In this case, we are using the method that saves the result into a MongoDB collection, but we could just as well use the one that writes the result to a bucket in parquet format.

spyBatch.run(new DefaultApplicationArguments(""))

verify(spyBatch)
 .generateInstanceDFs(
   dataFrameCaptor.capture(),
   any(),
   any[DatabaseConfig]
 )

val result: DataFrame = dataFrameCaptor.getValue.asInstanceOf[DataFrame].cache()

The last part of the test consists of checking the result. There are several ways to perform this check: one consists of collecting the results and comparing the rows in the driver (this alternative is sketched after the assertions below). Although we are working with small DataFrames, we prefer to do this in a distributed way, comparing the DataFrames without collecting the data into the Spark driver.

So, let’s define the expected DataFrame. The key point here is to create the DataFrame with the columns in the same order they are aggregated by the job; otherwise, the check will fail:


val expectedData = Seq(
 (SV.St1, SFV.EmptyFilterStr, 1, 1, 3),
 (SV.St1, SFV.LangEngStr, 1, 2, 1),
 [...]
).toDF(
 FinalColumnNames.Instance,
 FinalColumnNames.Filters,
 FinalColumnNames.QueryCount,
 [...]
)

We check to see that the total number of rows matches and that the DataFrames are equal by subtracting one from the other and checking that the results are empty:

assert(10 === result.count())
assert(expectedData.except(result).isEmpty)
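
For comparison, the collect-based alternative mentioned above would look roughly like the sketch below. It brings both DataFrames into the driver, which is acceptable for test-sized data but is precisely what we chose to avoid:

// Collect-based alternative (sketch): bring the rows into the driver and
// compare them as plain Scala collections.
val resultRows = result.collect().toSet
val expectedRows = expectedData.collect().toSet
assert(resultRows === expectedRows)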

We know that our jobs do not produce duplicated rows (exactly equal rows), so the except-based check is enough to ensure that the expected and resulting DataFrames are the same. With repeated rows, we would need additional checks, such as counting how many times each row is repeated (sketched after the next snippet). We can also subtract the DataFrames in the inverse order and perform the same check to be fully confident in the results of the test.

assert(result.except(expectedData).isEmpty)
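
If a job could produce exactly repeated rows, one possible approach (a sketch, assuming org.apache.spark.sql.functions.col is imported) is to count the occurrences of each distinct row on both sides and compare the counted DataFrames instead:

// Duplicate-aware comparison (sketch): append an occurrence count to each
// distinct row and subtract the counted DataFrames in both directions.
val resultCounted = result.groupBy(result.columns.map(col): _*).count()
val expectedCounted = expectedData.groupBy(expectedData.columns.map(col): _*).count()
assert(resultCounted.except(expectedCounted).isEmpty)
assert(expectedCounted.except(resultCounted).isEmpty)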

To summarize, these are the steps to follow for Spark application component testing:

  1. Mock the beans required by the application and instantiate the Spark session and the Spark application.
  2. Mock the read method so that it returns a DataFrame defined within the test, and the write method so that it does not perform any write operation.
  3. Create a spy of the application object so that the read and write (result) methods can be mocked.
  4. Create an argument captor to retrieve the result of the Spark application, which is passed via parameter to the write method.
  5. Run the code and capture the resulting DataFrame.
  6. Define the expected DataFrame with the columns defined in the same order as calculated by the Spark application.
  7. Compare the expected result with the DataFrame that was retrieved using the argument captor.

Altogether, the test code looks like this:

import com.eb.data.batch.{SampleValues => SV}
import com.eb.data.batch.{SampleFiltersValues => SFV}
import com.eb.data.batch.{SampleSinks => Sinks}
import com.eb.data.batch.config.DatabaseConfig
import com.eb.data.batch.io.df.DataFrame.DataFrameMethods
import com.eb.data.batch.DataFrameGenerator
import com.eb.data.batch.FinalColumnNames
import com.eb.data.batch.SourceColumnNames
import com.eb.data.batch.SparkTestConfig
import com.eb.data.batch.UDFUtils
import com.mongodb.spark.config.WriteConfig
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit
import org.mockito.ArgumentCaptor
import org.mockito.ArgumentMatchers
import org.mockito.ArgumentMatchers.any
import org.mockito.ArgumentMatchers.anyString
import org.mockito.ArgumentMatchers.isNull
import org.mockito.MockitoSugar
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.springframework.boot.DefaultApplicationArguments

import java.time.{LocalDate, LocalTime, ZoneId, ZonedDateTime}

class MyBatchTest extends AnyFlatSpec with Matchers with MockitoSugar {

 val spark = SparkSession.builder
   .master("local[*]")
   .getOrCreate()

 import spark.implicits._

 val testData = DataFrameGenerator.generateSampleDataFrame(
   SV.Year,
   SV.Month,
   SV.Day,
   Seq(
     (SV.St1, Sinks.Q, SFV.LangEngScopeDesktop),
     (SV.St1, Sinks.C, SFV.LangEngScopeDesktop),
     (SV.St1, Sinks.C, SFV.LangEngScopeDesktop),
     [...]
   ).toDF(
     SourceColumnNames.Instance,
     SourceColumnNames.Sink,
     SourceColumnNames.Filters
   )
 )

 val emptyMongoInstanceDFs = Map[WriteConfig, DataFrame]()

 "My batch" should "work" in {

   val dfMethodsMock = mock[DataFrameMethods]
   val dataFrameCaptor = ArgumentCaptor.forClass(classOf[DataFrame])

   val myBatch = new MyBatch(
     SparkTestConfig.getSparkConfig,
     SparkTestConfig.getInputConfig
   )
   val spyBatch = spy(myBatch)

    doReturn(testData)
      .when(spyBatch)
      .readParquet(
        any[SparkSession],
        ArgumentMatchers.eq("inputPath"),
        ArgumentMatchers.eq(ZonedDateTime.of(LocalDate.parse("2022-01-01"), LocalTime.MIN, ZoneId.of("UTC"))),
        ArgumentMatchers.eq(ZonedDateTime.of(LocalDate.parse("2022-01-02"), LocalTime.MIN, ZoneId.of("UTC")))
      )

    doNothing
      .when(dfMethodsMock)
      .save(ArgumentMatchers.eq("outputPath"), ArgumentMatchers.eq(Seq("yyyy", "mm")))

   doReturn(emptyMongoInstanceDFs)
     .when(spyBatch)
     .generateInstanceDFs(any[DataFrame], isNull[String], any[DatabaseConfig])

   spyBatch.run(new DefaultApplicationArguments(""))

   verify(spyBatch)
     .generateInstanceDFs(
       dataFrameCaptor.capture(),
       any(),
       any[DatabaseConfig]
     )

   val result: DataFrame = dataFrameCaptor.getValue.asInstanceOf[DataFrame].cache()

   val expectedData = Seq(
     (SV.St1, SFV.EmptyFilterStr, 1, 1, 3),
     (SV.St1, SFV.LangEngStr, 1, 2, 1),
     [...]
    ).toDF(
      FinalColumnNames.Instance,
      FinalColumnNames.Filters,
      FinalColumnNames.QueryCount,
      [...]
    )

   assert(10 === result.count())
   assert(expectedData.except(result).isEmpty)
   assert(result.except(expectedData).isEmpty)
 }
}

In a follow-up blog post, we will explore how to improve the efficiency of the Spark session creation and reduce the execution time when performing a battery of tests during the same execution. For now, we hope this serves as a helpful guide to performing component testing in Spark. As always, if you have any questions or comments, please feel free to reach out!