Testing patterns with PySpark


Introduction

I recently had to set up a PySpark pipeline that performed a number of data transformations (dimensionally reducing the data to focus on rollups of factors we wanted to target and streamline for downstream querying and analysis). Writing tests for this was not immediately intuitive. In the end, the pattern follows that of most pytest tests that you would see in a standard Python library, though. This post documents a simplified, generic example of a pipeline job and how to break apart step transformations and test them independent of the parent operation’s complete step flow.

Main function overview

The generic method will have a few key steps. First it will format a query, then it will submit the query (with spark.sql) and return a Spark DataFrame with the result. At that point, it will run a series transformations on the DataFrame and, finally, it will write the results to some parameterized destination (which in production might be an S3 bucket, for example).

We can describe the steps involved with the following pseudocode:

def main(params):
    format_query
    run_query
    transform_step_1
    transform_step_2
    write_result

Main function steps details

Fleshing this out, we can describe the format_query step with the following method:

def _create_query_str(query_param: str):
    """Format and return sql string template with filtering parameter values."""
    return """
        SELECT foo, bar, partition
        FROM example.table
        WHERE bazz >= '{}'
    """.format(query_param)

Here, we have a SQL query that can be formatted with parameters supplied to the main() method.

Next we have to make the query by submitting it with spark:

def _run_query(fomatted_sql_query: str):
    """Run sql query and return spark dataframe."""
    return pyspark.sql(fomatted_sql_query)

This step will be patched in unit tests to avoid actually querying the database. In its place, a fixture representing a subset of data that matches the database schema will be supplied instead.

Now we will provide two example transformation on the DataFrame. What these do is not super important - these steps are purely for demonstration.

# first example transformation step
def _transformation_example_one(df):
    return df.withColumn(
        "joined",
        sf.concat(
            sf.col("foo"),
            sf.lit("_"),
            sf.col("bar")))


# second example transformation step
def _transformation_example_two(df):
    return df.withColumn(
        "factored",
        sf.col("foo") * sf.col("bar")
    )

Fleshing out the main method

We can now revisit the main() method and show how all the example steps can be rolled together in a main() method workflow:

def main(query_param: str, save_location: str):
    # format and run query
    sql_query = _create_query_str(query_param)
    spark_df = _run_query(sql_query)

    # apply a series of operations
    spark_df = _transformation_example_one(spark_df)
    spark_df = _transformation_example_two(spark_df)

    # save/write operation
    (
        spark_df
        .repartition("partition")
        .write
            .partitionBy("partition")
            .mode("overwrite")
            .format("json")
            .option("compression", "gzip")
        .save(save_location)
    )

The save/write operation could probably also be broken out into a different step, too. But, I’ll leave at this for the sake of the example.

Testing overview

I will do two main tests. First, I will have tests that check the individual steps and make sure they behave as expected. Then I will have tests that check that each step integrates with its subsequent steps by running the whole main method and checking its results.

Testing steps

First, to check each step we can start by creating pytest fixtures that mock inputs and outputs from each step of our multi-step main method.

In this case, I create the following two:

from unittest.mock import patch


@pytest.fixture
def query_results_fixture_df():
    return spark.read.json("test/fixtures/query_results_sample.json")


@pytest.fixture
def query_results_fixture_stage_two_df():
    return spark.read.json("test/fixtures/query_results_sample_stage_two.json")

We can see how these are used by examining all the tests for the steps:

class TestUnitMethods:
    def test_create_query_str(self):
        assert "WHERE bazz >= 'abc'\n" in _create_query_str('abc')

    def test_transformation_example_one(self, query_results_fixture_df):
        res_df = _transformation_example_one(query_results_fixture_df)

        # run tests specific to the output state at this stage
        got_cols = set(res_df.columns)
        assert got_cols == set(["foo", "bar", "partition", "joined"])

        rpdf = res_df.toPandas()
        assert set(rpdf["joined"]) == set(["100_1","130_2","302_3","293_4","173_5","462_6"])

    def test_transformation_example_two(self, query_results_fixture_stage_two_df):
        res_df = _transformation_example_two(query_results_fixture_stage_two_df)

        # run tests specific to the output state at this stage
        got_cols = set(res_df.columns)
        assert got_cols == set(["foo", "bar", "partition", "joined", "factored"])

        rpdf = res_df.toPandas()
        assert set(rpdf["factored"]) == set([100,260,906,1172,865,2772])

What has been done is that, each step has been isolated with a mock for the input and a mock for the output having been shimmed. Then, we can compare the result produced with the result we expected by comparing the output fixture against the one generated.

In the case of the multiple steps, the output of one step can also be recycled to be the reference DataFrame for the subsequent step.

Testing the main method

Now we can move on to test the whole process combined in the main function. In this case, we can also test the write step since it’s an “output” of the main method, essentially.

class TestMainMethod:
    @patch("path.to.the._run_query")
    def test_integration(self, _run_query, query_results_fixture_df):
        # patch call to pyspark.sql to avoid actually submitting sql query
        _run_query.return_value = query_results_fixture_df

        # execute the whole main function and thus run all steps together
        temp_save_loc = "temp_test_spark_write_output_dir"
        query_param = "fizzbuzz"
        main(query_param, temp_save_loc)

        # TODO: load in output file from temp_save_loc and compare to expected

        # cleanup results
        shutil.rmtree(temp_save_loc)

Now, by parameterizing the write location, we can avoid writing to an external service like S3 and instead write to a temporary directory locally. See the “TODO” section for where that file could then be read in and compared to an example dataset to ensure that the output data produced from the main() method matched with what is expected.

Conclusion

I’ve noticed that unit testing may not be as “popular” with Spark applications because the set up is onerous but, with some careful method structuring, unit tests can be developed to ensure that transformation steps behave as expected. In addition, such steps help enforce that each method can safely expect certain columns and data presence and make future modifications to data transformations performed more safely.