From 2e54548adcc9250bbad069a72d69b31eef1af40e Mon Sep 17 00:00:00 2001 From: Stuart Lynn Date: Mon, 27 Nov 2023 09:48:16 +0000 Subject: [PATCH] update python module and tests --- python/mosaic/api/aggregators.py | 9 +++++++-- python/mosaic/api/raster.py | 7 ++++--- python/test/test_raster_functions.py | 8 +++----- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/python/mosaic/api/aggregators.py b/python/mosaic/api/aggregators.py index e221d06ba..696893882 100644 --- a/python/mosaic/api/aggregators.py +++ b/python/mosaic/api/aggregators.py @@ -212,7 +212,9 @@ def rst_combineavg_agg(raster: ColumnOrName) -> Column: ) -def rst_derivedband_agg(raster: ColumnOrName, pythonFunc: ColumnOrName, funcName: ColumnOrName) -> Column: +def rst_derivedband_agg( + raster: ColumnOrName, pythonFunc: ColumnOrName, funcName: ColumnOrName +) -> Column: """ Returns the raster representing the aggregation of rasters using provided python function. @@ -228,5 +230,8 @@ def rst_derivedband_agg(raster: ColumnOrName, pythonFunc: ColumnOrName, funcName The resulting raster. """ return config.mosaic_context.invoke_function( - "rst_derivedband_agg", pyspark_to_java_column(raster), pyspark_to_java_column(pythonFunc), pyspark_to_java_column(funcName) + "rst_derivedband_agg", + pyspark_to_java_column(raster), + pyspark_to_java_column(pythonFunc), + pyspark_to_java_column(funcName), ) diff --git a/python/mosaic/api/raster.py b/python/mosaic/api/raster.py index caa4cde23..d72c7b157 100644 --- a/python/mosaic/api/raster.py +++ b/python/mosaic/api/raster.py @@ -148,7 +148,9 @@ def rst_combineavg(rasters: ColumnOrName) -> Column: ) -def rst_derivedband(raster: ColumnOrName, pythonFunc: ColumnOrName, funcName: ColumnOrName) -> Column: +def rst_derivedband( + raster: ColumnOrName, pythonFunc: ColumnOrName, funcName: ColumnOrName +) -> Column: """ Creates a new band by applying the given python function to the input rasters. The result is a raster tile. @@ -867,8 +869,7 @@ def rst_setsrid(raster: ColumnOrName, srid: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_setsrid", pyspark_to_java_column(raster), - pyspark_to_java_column(srid) + "rst_setsrid", pyspark_to_java_column(raster), pyspark_to_java_column(srid) ) diff --git a/python/test/test_raster_functions.py b/python/test/test_raster_functions.py index ccf0476a8..04d1a4de2 100644 --- a/python/test/test_raster_functions.py +++ b/python/test/test_raster_functions.py @@ -116,7 +116,7 @@ def test_raster_flatmap_functions(self): ) tessellate_result.write.format("noop").mode("overwrite").save() - self.assertEqual(tessellate_result.count(), 55) + self.assertEqual(tessellate_result.count(), 140) overlap_result = self.generate_singleband_raster_df().withColumn( "rst_to_overlapping_tiles", @@ -169,10 +169,8 @@ def test_netcdf(self): self.assertEqual(df.count(), 31) - grid_tiles = ( - df - .withColumn("tile", api.rst_setsrid("tile", lit(4326))) - .select(api.rst_tessellate("tile", lit(3)).alias("tile")) + grid_tiles = df.withColumn("tile", api.rst_setsrid("tile", lit(4326))).select( + api.rst_tessellate("tile", lit(3)).alias("tile") ) self.assertEqual(grid_tiles.count(), 4495)