diff --git a/core/src/execution/datafusion/shuffle_writer.rs b/core/src/execution/datafusion/shuffle_writer.rs
index 2032ae62d..fc15facb4 100644
--- a/core/src/execution/datafusion/shuffle_writer.rs
+++ b/core/src/execution/datafusion/shuffle_writer.rs
@@ -315,6 +315,9 @@ fn slot_size(len: usize, data_type: &DataType) -> usize {
             // TODO: this is not accurate, but should be good enough for now
             slot_size(len, key_type.as_ref()) + slot_size(len / 10, value_type.as_ref())
         }
+        // TODO: this is not accurate, but should be good enough for now
+        DataType::Binary => len * 100 + len * 4,
+        DataType::LargeBinary => len * 100 + len * 8,
         DataType::FixedSizeBinary(s) => len * (*s as usize),
         DataType::Timestamp(_, _) => len * 8,
         dt => unimplemented!(
@@ -521,6 +524,8 @@ fn append_columns(
         {
             append_string_dict!(key_type)
         }
+        DataType::Binary => append!(Binary),
+        DataType::LargeBinary => append!(LargeBinary),
         DataType::FixedSizeBinary(_) => append_unwrap!(FixedSizeBinary),
         t => unimplemented!(
             "{}",
@@ -1275,3 +1280,36 @@ impl RecordBatchStream for EmptyStream {
         self.schema.clone()
     }
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_slot_size() {
+        let batch_size = 1usize;
+        // not inclusive of all supported types, but enough to test the function
+        let supported_primitive_types = [
+            DataType::Int32,
+            DataType::Int64,
+            DataType::UInt32,
+            DataType::UInt64,
+            DataType::Float32,
+            DataType::Float64,
+            DataType::Boolean,
+            DataType::Utf8,
+            DataType::LargeUtf8,
+            DataType::Binary,
+            DataType::LargeBinary,
+            DataType::FixedSizeBinary(16),
+        ];
+        let expected_slot_size = [4, 8, 4, 8, 4, 8, 1, 104, 108, 104, 108, 16];
+        supported_primitive_types
+            .iter()
+            .zip(expected_slot_size.iter())
+            .for_each(|(data_type, expected)| {
+                let slot_size = slot_size(batch_size, data_type);
+                assert_eq!(slot_size, *expected);
+            })
+    }
+}
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala
index beb6dc860..acd424acd 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometShuffleSuite.scala
@@ -1169,6 +1169,22 @@ abstract class CometShuffleSuiteBase extends CometTestBase with AdaptiveSparkPla
     }
   }
 
+  test("fix: comet native shuffle with binary data") {
+    withSQLConf(
+      CometConf.COMET_EXEC_ENABLED.key -> "true",
+      CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+      CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "false") {
+      withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") {
+        val df = sql("SELECT cast(cast(_1 as STRING) as BINARY) as binary, _2 FROM tbl")
+
+        val shuffled = df.repartition(1, $"binary")
+
+        checkCometExchange(shuffled, 1, true)
+        checkSparkAnswer(shuffled)
+      }
+    }
+  }
+
   test("Comet shuffle metrics") {
     withSQLConf(
       CometConf.COMET_EXEC_ENABLED.key -> "true",
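
For reference, the new `DataType::Binary` and `DataType::LargeBinary` arms estimate each value at roughly 100 bytes of payload plus one offset per slot (4 bytes for `Binary`'s i32 offsets, 8 bytes for `LargeBinary`'s i64 offsets), which is where the expected values 104 and 108 in `test_slot_size` come from. Below is a minimal standalone sketch of that arithmetic; `binary_slot_size` is a hypothetical helper, and the 100-byte payload figure is the patch's own rough guess (per its TODO), not a measured average:

    // Rough per-slot memory estimate for variable-length binary columns,
    // mirroring the heuristic added to `slot_size` above: assume ~100 bytes
    // of payload per value plus one offset of `offset_width` bytes per value.
    fn binary_slot_size(len: usize, offset_width: usize) -> usize {
        len * 100 + len * offset_width
    }

    fn main() {
        // Binary uses i32 offsets (4 bytes); LargeBinary uses i64 offsets (8 bytes).
        assert_eq!(binary_slot_size(1, 4), 104); // Binary expectation in test_slot_size
        assert_eq!(binary_slot_size(1, 8), 108); // LargeBinary expectation
    }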