cache contents of test data file
tnixon committed Nov 5, 2024
1 parent 6219879 commit 2f00ced
Showing 1 changed file with 27 additions and 15 deletions.
42 changes: 27 additions & 15 deletions python/tests/base.py
@@ -2,6 +2,7 @@
 import unittest
 import warnings
 from typing import Union, Optional
+from functools import cached_property
 
 import jsonref
 import pyspark.sql.functions as sfn
@@ -170,7 +171,10 @@ class SparkTest(unittest.TestCase):
 
     # Spark Session object
     spark = None
-    test_data = None
+
+    # test data
+    test_data_file = None
+    test_case_data = None
 
     @classmethod
     def setUpClass(cls) -> None:
@@ -208,18 +212,18 @@ def tearDownClass(cls) -> None:
         cls.spark.stop()
 
     def setUp(self) -> None:
-        self.test_data = self.__loadTestData(self.id())
+        self.test_case_data = self.__loadTestData(self.id())
 
     def tearDown(self) -> None:
-        del self.test_data
+        del self.test_case_data
 
     #
     # Utility Functions
     #
 
     def get_data_as_idf(self, name: str, convert_ts_col=True):
        df = self.get_data_as_sdf(name, convert_ts_col)
-        td = self.test_data[name]
+        td = self.test_case_data[name]
         idf = IntervalsDF(
             df,
             start_ts=td["start_ts"],
@@ -258,20 +262,28 @@ def __loadTestData(self, test_case_path: str) -> dict:
         """
         file_name, class_name, func_name = test_case_path.split(".")[-3:]
 
-        # find our test data file
-        test_data_file = self.__getTestDataFilePath(file_name)
-        if not os.path.isfile(test_data_file):
-            warnings.warn(f"Could not load test data file {test_data_file}")
-            return {}
+        # load the test data file if it hasn't been loaded yet
+        if self.test_data_file is None:
+            # find our test data file
+            test_data_filename = self.__getTestDataFilePath(file_name)
+            if not os.path.isfile(test_data_filename):
+                warnings.warn(f"Could not load test data file {test_data_filename}")
+                self.test_data_file = {}
+
+            # proces the data file
+            with open(test_data_filename, "r") as f:
+                self.test_data_file = jsonref.load(f)
+
+        # return the data if it exists
+        if class_name in self.test_data_file:
+            if func_name in self.test_data_file[class_name]:
+                return self.test_data_file[class_name][func_name]
 
-        # proces the data file
-        with open(test_data_file, "r") as f:
-            data_metadata_from_json = jsonref.load(f)
-        # return the data
-        return data_metadata_from_json[class_name][func_name]
+        # return empty dictionary if no data found
+        return {}
 
     def get_test_df_builder(self, name: str) -> TestDataFrameBuilder:
-        return TestDataFrameBuilder(self.spark, self.test_data[name])
+        return TestDataFrameBuilder(self.spark, self.test_case_data[name])
 
     #
     # Assertion Functions
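
For context on the change: per the commit message, the intent is to cache the parsed contents of the test-data JSON in test_data_file, so __loadTestData populates it on first use and afterwards indexes into the already-parsed dictionary instead of re-reading the file. Below is a minimal standalone sketch of that memoization pattern; the CachedJsonLoader name, the lookup method, and the use of the standard-library json module in place of jsonref are illustrative assumptions, not code from the repository.

    import json
    import os
    import warnings


    class CachedJsonLoader:
        """Illustrative sketch: parse a JSON file once, then serve lookups from memory."""

        def __init__(self, path: str):
            self._path = path
            self._cache = None  # parsed file contents, populated lazily

        def lookup(self, class_name: str, func_name: str) -> dict:
            # load and parse the file only on the first lookup
            if self._cache is None:
                if not os.path.isfile(self._path):
                    warnings.warn(f"Could not load test data file {self._path}")
                    self._cache = {}
                else:
                    with open(self._path, "r") as f:
                        self._cache = json.load(f)

            # return the requested entry if present, otherwise an empty dict
            return self._cache.get(class_name, {}).get(func_name, {})

With this shape, repeated lookups for different test cases reuse the same parsed dictionary rather than re-opening and re-parsing the file each time.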
