Skip to content

Commit

Permalink
update python module and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sllynn committed Nov 27, 2023
1 parent 015965b commit 2e54548
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 10 deletions.
9 changes: 7 additions & 2 deletions python/mosaic/api/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,9 @@ def rst_combineavg_agg(raster: ColumnOrName) -> Column:
)


def rst_derivedband_agg(raster: ColumnOrName, pythonFunc: ColumnOrName, funcName: ColumnOrName) -> Column:
def rst_derivedband_agg(
raster: ColumnOrName, pythonFunc: ColumnOrName, funcName: ColumnOrName
) -> Column:
"""
Returns the raster representing the aggregation of rasters using provided python function.
Expand All @@ -228,5 +230,8 @@ def rst_derivedband_agg(raster: ColumnOrName, pythonFunc: ColumnOrName, funcName
The resulting raster.
"""
return config.mosaic_context.invoke_function(
"rst_derivedband_agg", pyspark_to_java_column(raster), pyspark_to_java_column(pythonFunc), pyspark_to_java_column(funcName)
"rst_derivedband_agg",
pyspark_to_java_column(raster),
pyspark_to_java_column(pythonFunc),
pyspark_to_java_column(funcName),
)
7 changes: 4 additions & 3 deletions python/mosaic/api/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ def rst_combineavg(rasters: ColumnOrName) -> Column:
)


def rst_derivedband(raster: ColumnOrName, pythonFunc: ColumnOrName, funcName: ColumnOrName) -> Column:
def rst_derivedband(
raster: ColumnOrName, pythonFunc: ColumnOrName, funcName: ColumnOrName
) -> Column:
"""
Creates a new band by applying the given python function to the input rasters.
The result is a raster tile.
Expand Down Expand Up @@ -867,8 +869,7 @@ def rst_setsrid(raster: ColumnOrName, srid: ColumnOrName) -> Column:
"""
return config.mosaic_context.invoke_function(
"rst_setsrid", pyspark_to_java_column(raster),
pyspark_to_java_column(srid)
"rst_setsrid", pyspark_to_java_column(raster), pyspark_to_java_column(srid)
)


Expand Down
8 changes: 3 additions & 5 deletions python/test/test_raster_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def test_raster_flatmap_functions(self):
)

tessellate_result.write.format("noop").mode("overwrite").save()
self.assertEqual(tessellate_result.count(), 55)
self.assertEqual(tessellate_result.count(), 140)

overlap_result = self.generate_singleband_raster_df().withColumn(
"rst_to_overlapping_tiles",
Expand Down Expand Up @@ -169,10 +169,8 @@ def test_netcdf(self):

self.assertEqual(df.count(), 31)

grid_tiles = (
df
.withColumn("tile", api.rst_setsrid("tile", lit(4326)))
.select(api.rst_tessellate("tile", lit(3)).alias("tile"))
grid_tiles = df.withColumn("tile", api.rst_setsrid("tile", lit(4326))).select(
api.rst_tessellate("tile", lit(3)).alias("tile")
)

self.assertEqual(grid_tiles.count(), 4495)

0 comments on commit 2e54548

Please sign in to comment.