From 500895dfa3319e2da276218068d27372f26ea931 Mon Sep 17 00:00:00 2001
From: Himadri Pal
Date: Fri, 22 Nov 2024 05:52:45 -0800
Subject: [PATCH] feat: enable decimal to decimal cast of different precision
 and scale (#1086)

* enable decimal to decimal cast of different precision and scale

* add more test cases for negative scale and higher precision

* add check for compatibility for decimal to decimal

* fix code style

* Update spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Co-authored-by: Andy Grove

* fix the nit in comment

---------

Co-authored-by: himadripal
Co-authored-by: Andy Grove
---
 .../apache/comet/expressions/CometCast.scala  | 10 +++++--
 .../org/apache/comet/CometCastSuite.scala     | 28 +++++++++++++++++++
 .../org/apache/spark/sql/CometTestBase.scala  |  6 ++--
 3 files changed, 37 insertions(+), 7 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 11d6d049f..859cb13be 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -70,9 +70,13 @@ object CometCast {
         case _ => Unsupported
       }
 
-      case (_: DecimalType, _: DecimalType) =>
-        // https://github.com/apache/datafusion-comet/issues/375
-        Incompatible()
+      case (from: DecimalType, to: DecimalType) =>
+        if (to.precision < from.precision) {
+          // https://github.com/apache/datafusion/issues/13492
+          Incompatible(Some("Casting to smaller precision is not supported"))
+        } else {
+          Compatible()
+        }
       case (DataTypes.StringType, _) =>
         canCastFromString(toType, timeZoneId, evalMode)
       case (_, DataTypes.StringType) =>
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 01583077d..f8c1a8b09 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -892,6 +892,34 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("cast between decimals with different precision and scale") {
+    // cast between default Decimal(38, 18) to Decimal(6,2)
+    val values = Seq(BigDecimal("12345.6789"), BigDecimal("9876.5432"), BigDecimal("123.4567"))
+    val df = withNulls(values)
+      .toDF("b")
+      .withColumn("a", col("b").cast(DecimalType(6, 2)))
+    checkSparkAnswer(df)
+  }
+
+  test("cast between decimals with higher precision than source") {
+    // cast between Decimal(10, 2) to Decimal(10,4)
+    castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 4))
+  }
+
+  test("cast between decimals with negative precision") {
+    // cast to negative scale
+    checkSparkMaybeThrows(
+      spark.sql("select a, cast(a as DECIMAL(10,-4)) from t order by a")) match {
+      case (expected, actual) =>
+        assert(expected.contains("PARSE_SYNTAX_ERROR") === actual.contains("PARSE_SYNTAX_ERROR"))
+    }
+  }
+
+  test("cast between decimals with zero precision") {
+    // cast between Decimal(10, 2) to Decimal(10,0)
+    castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 0))
+  }
+
   private def generateFloats(): DataFrame = {
     withNulls(gen.generateFloats(dataSize)).toDF("a")
   }
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 1709cce61..213ec7efe 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -231,11 +231,9 @@ abstract class CometTestBase
       df: => DataFrame): (Option[Throwable], Option[Throwable]) = {
     var expected: Option[Throwable] = None
     withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-      val dfSpark = Dataset.ofRows(spark, df.logicalPlan)
-      expected = Try(dfSpark.collect()).failed.toOption
+      expected = Try(Dataset.ofRows(spark, df.logicalPlan).collect()).failed.toOption
     }
-    val dfComet = Dataset.ofRows(spark, df.logicalPlan)
-    val actual = Try(dfComet.collect()).failed.toOption
+    val actual = Try(Dataset.ofRows(spark, df.logicalPlan).collect()).failed.toOption
     (expected, actual)
   }
 
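For illustration, a minimal standalone sketch (not part of the patch above) of the behavior this change enables, assuming a local SparkSession with a Comet-enabled build on the classpath; the object name and sample values are assumptions, not code from this PR:

// Illustrative sketch only; assumes a Comet-enabled Spark build.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DecimalType

object DecimalCastExample extends App {
  val spark = SparkSession
    .builder()
    .master("local[1]")
    .config("spark.plugins", "org.apache.spark.CometPlugin")
    .getOrCreate()
  import spark.implicits._

  // Scala BigDecimal columns default to Decimal(38, 18).
  val df = Seq(BigDecimal("12345.6789"), BigDecimal("9876.5432")).toDF("a")

  // Equal-or-higher target precision is now Compatible and can run natively.
  df.withColumn("widened", col("a").cast(DecimalType(38, 20))).show()

  // Lower target precision stays Incompatible; Comet falls back to Spark
  // unless spark.comet.cast.allowIncompatible=true.
  df.withColumn("narrowed", col("a").cast(DecimalType(6, 2))).show()

  spark.stop()
}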