diff --git a/core/src/execution/shuffle/row.rs b/core/src/execution/shuffle/row.rs index 36a7b2424..d3efc16dc 100644 --- a/core/src/execution/shuffle/row.rs +++ b/core/src/execution/shuffle/row.rs @@ -1758,7 +1758,7 @@ pub(crate) fn append_columns( } _ => { return Err(CometError::Internal(format!( - "Unsupported type: {:?}", + "Unsupported map type: {:?}", field.data_type() ))) } @@ -3182,7 +3182,7 @@ fn make_builders( _ => { return Err(CometError::Internal(format!( - "Unsupported type: {:?}", + "Unsupported map type: {:?}", field.data_type() ))) } @@ -3255,7 +3255,7 @@ fn make_builders( } _ => { return Err(CometError::Internal(format!( - "Unsupported type: {:?}", + "Unsupported list type: {:?}", field.data_type() ))) } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 0a251d448..a365e7543 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1830,10 +1830,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case StructType(fields) => fields.forall(f => supportedDataType(f.dataType)) case ArrayType(ArrayType(_, _), _) => false // TODO: nested array is not supported + case ArrayType(MapType(_, _, _), _) => false // TODO: map array element is not supported case ArrayType(elementType, _) => supportedDataType(elementType) case MapType(MapType(_, _, _), _, _) => false // TODO: nested map is not supported case MapType(_, MapType(_, _, _), _) => false + case MapType(StructType(_), _, _) => false // TODO: struct map key/value is not supported + case MapType(_, StructType(_), _) => false + case MapType(ArrayType(_, _), _, _) => false // TODO: array map key/value is not supported + case MapType(_, ArrayType(_, _), _) => false case MapType(keyType, valueType, _) => supportedDataType(keyType) && supportedDataType(valueType) case _ => diff --git a/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala index ac5664911..beb6dc860 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala @@ -112,6 +112,91 @@ abstract class CometShuffleSuiteBase extends CometTestBase with AdaptiveSparkPla } } + test("columnar shuffle on array/struct map key/value") { + Seq("false", "true").foreach { execEnabled => + Seq(10, 201).foreach { numPartitions => + Seq("1.0", "10.0").foreach { ratio => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> execEnabled, + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) { + withParquetTable((0 until 50).map(i => (Map(Seq(i, i + 1) -> i), i + 1)), "tbl") { + val df = sql("SELECT * FROM tbl") + .filter($"_2" > 10) + .repartition(numPartitions, $"_1", $"_2") + .sortWithinPartitions($"_2") + + checkSparkAnswer(df) + // Array map key array element fallback to Spark shuffle for now + checkCometExchange(df, 0, false) + } + + withParquetTable((0 until 50).map(i => (Map(i -> Seq(i, i + 1)), i + 1)), "tbl") { + val df = sql("SELECT * FROM tbl") + .filter($"_2" > 10) + .repartition(numPartitions, $"_1", $"_2") + .sortWithinPartitions($"_2") + + checkSparkAnswer(df) + // Array map value array element fallback to Spark shuffle for now + checkCometExchange(df, 0, false) + } + + withParquetTable((0 until 50).map(i => (Map((i, i.toString) -> i), i + 1)), "tbl") { + val df = sql("SELECT * FROM tbl") + .filter($"_2" > 10) + .repartition(numPartitions, $"_1", $"_2") + .sortWithinPartitions($"_2") + + checkSparkAnswer(df) + // Struct map key array element fallback to Spark shuffle for now + checkCometExchange(df, 0, false) + } + + withParquetTable((0 until 50).map(i => (Map(i -> (i, i.toString)), i + 1)), "tbl") { + val df = sql("SELECT * FROM tbl") + .filter($"_2" > 10) + .repartition(numPartitions, $"_1", $"_2") + .sortWithinPartitions($"_2") + + checkSparkAnswer(df) + // Struct map value array element fallback to Spark shuffle for now + checkCometExchange(df, 0, false) + } + } + } + } + } + } + + test("columnar shuffle on map array element") { + Seq("false", "true").foreach { execEnabled => + Seq(10, 201).foreach { numPartitions => + Seq("1.0", "10.0").foreach { ratio => + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> execEnabled, + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) { + withParquetTable( + (0 until 50).map(i => ((Seq(Map(1 -> i)), Map(2 -> i), Map(3 -> i)), i + 1)), + "tbl") { + val df = sql("SELECT * FROM tbl") + .filter($"_2" > 10) + .repartition(numPartitions, $"_1", $"_2") + .sortWithinPartitions($"_2") + + checkSparkAnswer(df) + // Map array element fallback to Spark shuffle for now + checkCometExchange(df, 0, false) + } + } + } + } + } + } + test("RoundRobinPartitioning is supported by columnar shuffle") { withSQLConf( // AQE has `ShuffleStage` which is a leaf node which blocks