From 26cbb7f8bc8e5f5f523adb2a88c1911f91bf8d71 Mon Sep 17 00:00:00 2001 From: sllynn Date: Wed, 13 Nov 2024 16:50:44 +0000 Subject: [PATCH] updated the logic of some of the statistical raster functions to work with small clipping outputs --- .../raster/gdal/MosaicRasterBandGDAL.scala | 16 ++++++-- .../raster/RST_MedianBehaviors.scala | 37 ++++++++++++++++++- .../expressions/raster/RST_MedianTest.scala | 11 +++++- 3 files changed, 58 insertions(+), 6 deletions(-) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala index ffd32e109..7f73802a8 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala @@ -148,9 +148,19 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) { */ def computeMinMax: Seq[Double] = { val minMaxVals = Array.fill[Double](2)(0) - Try(band.ComputeRasterMinMax(minMaxVals, 0)) - .map(_ => minMaxVals.toSeq) - .getOrElse(Seq(Double.NaN, Double.NaN)) + // will GDAL refuse to compute these stats? + if (band.GetXSize() == 1 || band.GetYSize() == 1) { + val validPixels = values.filter(_ != noDataValue) + if (validPixels.isEmpty) { + return Seq(Double.NaN, Double.NaN) + } else { + Seq(validPixels.min, validPixels.max) + } + } else { + Try(band.ComputeRasterMinMax(minMaxVals, 0)) + .map(_ => minMaxVals.toSeq) + .getOrElse(Seq(Double.NaN, Double.NaN)) + } } /** diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianBehaviors.scala index 1b99fbc6f..89608568d 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianBehaviors.scala @@ -9,7 +9,8 @@ import org.scalatest.matchers.should.Matchers._ trait RST_MedianBehaviors extends QueryTest { - def behavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + def largeAreaBehavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) mc.register() val sc = spark @@ -49,4 +50,38 @@ trait RST_MedianBehaviors extends QueryTest { } + def smallAreaBehavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val testRegion = "Polygon ((-8151524 1216659, -8151061 1216659, -8151061 1217123, -8151524 1217123, -8151524 1216659))" + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + rastersInMemory.createOrReplaceTempView("source") + + val df = rastersInMemory + .withColumn("tile", rst_clip($"tile", st_buffer(lit(testRegion), lit(-20)))) + .withColumn("result", rst_median($"tile")) + .select("result") + .select(explode($"result").as("result")) + + noException should be thrownBy spark.sql(""" + |select rst_median(tile) from source + |""".stripMargin) + + val result = df.as[Double].collect().max + + result should be > 0.0 + + } + + } diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianTest.scala index cfe270813..a1668ca3c 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianTest.scala @@ -22,11 +22,18 @@ class RST_MedianTest extends QueryTest with SharedSparkSessionGDAL with RST_Medi // These tests are not index system nor geometry API specific. // Only testing one pairing is sufficient. - test("Testing rst_median behavior with H3IndexSystem and JTS") { + test("Testing rst_median behavior with H3IndexSystem and JTS (tessellation case)") { noCodegen { assume(System.getProperty("os.name") == "Linux") - behavior(H3IndexSystem, JTS) + largeAreaBehavior(H3IndexSystem, JTS) } } + test("Testing rst_median behavior with H3IndexSystem and JTS (small area case)") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + smallAreaBehavior(H3IndexSystem, JTS) + } + } + }