From 2f00ced5a4848543738b292d1ff5b25c8c83c225 Mon Sep 17 00:00:00 2001
From: Tristan Nixon
Date: Mon, 4 Nov 2024 21:40:14 -0800
Subject: [PATCH] cache contents of test data file

---
Review notes (text in this section is ignored by git am / git apply):

 * __loadTestData no longer falls through to open() after warning that the
   test data file is missing; a missing file yields an empty data set, as it
   did before this patch.
 * The parsed file is stored on the test class (type(self)) instead of the
   instance. unittest creates a fresh TestCase instance for every test
   method, so an instance attribute would never be reused.
 * The cached_property import added below is not used by anything in this
   patch; drop it if no follow-up change needs it.
 * The post-image blob hash on the "index" line predates these review fixes;
   regenerate the patch with git format-patch if 3-way application is needed.

 python/tests/base.py | 46 +++++++++++++++++++++++++++++++---------------
 1 file changed, 31 insertions(+), 15 deletions(-)

diff --git a/python/tests/base.py b/python/tests/base.py
index 8538a1ce..3d08b8f9 100644
--- a/python/tests/base.py
+++ b/python/tests/base.py
@@ -2,6 +2,7 @@
 import unittest
 import warnings
 from typing import Union, Optional
+from functools import cached_property
 
 import jsonref
 import pyspark.sql.functions as sfn
@@ -170,7 +171,10 @@ class SparkTest(unittest.TestCase):
 
     # Spark Session object
     spark = None
-    test_data = None
+
+    # test data
+    test_data_file = None
+    test_case_data = None
 
     @classmethod
     def setUpClass(cls) -> None:
@@ -208,10 +212,10 @@ def tearDownClass(cls) -> None:
         cls.spark.stop()
 
     def setUp(self) -> None:
-        self.test_data = self.__loadTestData(self.id())
+        self.test_case_data = self.__loadTestData(self.id())
 
     def tearDown(self) -> None:
-        del self.test_data
+        del self.test_case_data
 
     #
     # Utility Functions
@@ -219,7 +223,7 @@ def tearDown(self) -> None:
 
     def get_data_as_idf(self, name: str, convert_ts_col=True):
         df = self.get_data_as_sdf(name, convert_ts_col)
-        td = self.test_data[name]
+        td = self.test_case_data[name]
         idf = IntervalsDF(
             df,
             start_ts=td["start_ts"],
@@ -258,20 +262,32 @@ def __loadTestData(self, test_case_path: str) -> dict:
         """
         file_name, class_name, func_name = test_case_path.split(".")[-3:]
 
-        # find our test data file
-        test_data_file = self.__getTestDataFilePath(file_name)
-        if not os.path.isfile(test_data_file):
-            warnings.warn(f"Could not load test data file {test_data_file}")
-            return {}
+        # load the test data file if it hasn't been loaded yet
+        if self.test_data_file is None:
+            # find our test data file
+            test_data_filename = self.__getTestDataFilePath(file_name)
+            if os.path.isfile(test_data_filename):
+                # process the data file
+                with open(test_data_filename, "r") as f:
+                    file_data = jsonref.load(f)
+            else:
+                warnings.warn(f"Could not load test data file {test_data_filename}")
+                file_data = {}
+
+            # unittest builds a new instance per test method, so cache on the
+            # class; an instance attribute would be discarded after every test
+            type(self).test_data_file = file_data
+
+        # return the data if it exists
+        if class_name in self.test_data_file:
+            if func_name in self.test_data_file[class_name]:
+                return self.test_data_file[class_name][func_name]
 
-        # proces the data file
-        with open(test_data_file, "r") as f:
-            data_metadata_from_json = jsonref.load(f)
-            # return the data
-            return data_metadata_from_json[class_name][func_name]
+        # return empty dictionary if no data found
+        return {}
 
     def get_test_df_builder(self, name: str) -> TestDataFrameBuilder:
-        return TestDataFrameBuilder(self.spark, self.test_data[name])
+        return TestDataFrameBuilder(self.spark, self.test_case_data[name])
 
     #
     # Assertion Functions