From 824ce7d9a6e051cae24063074114f0a01d81c974 Mon Sep 17 00:00:00 2001
From: Xianjin YE
Date: Sat, 24 Feb 2024 20:34:47 +0800
Subject: [PATCH] fix: Add num_rows when building RecordBatch

---
 core/src/execution/datafusion/shuffle_writer.rs      | 7 +++++--
 core/src/execution/shuffle/row.rs                    | 9 +++++----
 .../scala/org/apache/comet/exec/CometExecSuite.scala | 1 +
 3 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/core/src/execution/datafusion/shuffle_writer.rs b/core/src/execution/datafusion/shuffle_writer.rs
index 2e17dcf0e..2032ae62d 100644
--- a/core/src/execution/datafusion/shuffle_writer.rs
+++ b/core/src/execution/datafusion/shuffle_writer.rs
@@ -272,10 +272,11 @@ impl PartitionBuffer {
 
         // active -> staging
         let active = std::mem::take(&mut self.active);
+        let num_rows = self.num_active_rows;
         self.num_active_rows = 0;
         mem_diff -= self.active_slots_mem_size as isize;
 
-        let frozen_batch = make_batch(self.schema.clone(), active)?;
+        let frozen_batch = make_batch(self.schema.clone(), active, num_rows)?;
 
         let frozen_capacity_old = self.frozen.capacity();
         let mut cursor = Cursor::new(&mut self.frozen);
@@ -1148,9 +1149,11 @@ fn make_dict_builder(datatype: &DataType, capacity: usize) -> Box<dyn ArrayBuilder> {
 fn make_batch(
     schema: SchemaRef,
     mut arrays: Vec<Box<dyn ArrayBuilder>>,
+    row_count: usize,
 ) -> ArrowResult<RecordBatch> {
     let columns = arrays.iter_mut().map(|array| array.finish()).collect();
-    RecordBatch::try_new(schema, columns)
+    let options = RecordBatchOptions::new().with_row_count(Option::from(row_count));
+    RecordBatch::try_new_with_options(schema, columns, &options)
 }
 
 /// Checksum algorithms for writing IPC bytes.
diff --git a/core/src/execution/shuffle/row.rs b/core/src/execution/shuffle/row.rs
index e24fbbee1..419ef9b4b 100644
--- a/core/src/execution/shuffle/row.rs
+++ b/core/src/execution/shuffle/row.rs
@@ -37,7 +37,7 @@ use arrow_array::{
         StructBuilder, TimestampMicrosecondBuilder,
     },
     types::Int32Type,
-    Array, ArrayRef, RecordBatch,
+    Array, ArrayRef, RecordBatch, RecordBatchOptions,
 };
 use arrow_schema::{DataType, Field, Schema, TimeUnit};
 use jni::sys::{jint, jlong};
@@ -3347,7 +3347,7 @@ pub fn process_sorted_row_partition(
         .zip(schema.iter())
         .map(|(builder, datatype)| builder_to_array(builder, datatype, prefer_dictionary_ratio))
         .collect();
-    let batch = make_batch(array_refs?);
+    let batch = make_batch(array_refs?, n);
 
     let mut frozen: Vec<u8> = vec![];
     let mut cursor = Cursor::new(&mut frozen);
@@ -3420,7 +3420,7 @@ fn builder_to_array(
     }
 }
 
-fn make_batch(arrays: Vec<ArrayRef>) -> RecordBatch {
+fn make_batch(arrays: Vec<ArrayRef>, row_count: usize) -> RecordBatch {
     let mut dict_id = 0;
     let fields = arrays
         .iter()
@@ -3441,5 +3441,6 @@ fn make_batch(arrays: Vec<ArrayRef>) -> RecordBatch {
         })
         .collect::<Vec<Field>>();
     let schema = Arc::new(Schema::new(fields));
-    RecordBatch::try_new(schema, arrays).unwrap()
+    let options = RecordBatchOptions::new().with_row_count(Option::from(row_count));
+    RecordBatch::try_new_with_options(schema, arrays, &options).unwrap()
 }
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index eb5a8e9eb..0b94f0a5c 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -373,6 +373,7 @@ class CometExecSuite extends CometTestBase {
     withParquetDataFrame((0 until 5).map(i => (i, i + 1))) { df =>
       assert(df.where("_1 IS NOT NULL").count() == 5)
       checkSparkAnswerAndOperator(df)
+      assert(df.select().limit(2).count() === 2)
     }
   }
 
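
Note (illustration, not part of the patch): the case this fixes is a
RecordBatch with zero columns, e.g. the `df.select().limit(2)` projection in
the test above, where arrow-rs has no column from which to infer the row
count. A minimal standalone sketch of the failure and the fix, assuming a
recent arrow-rs (the same `arrow_array`/`arrow_schema` crates the patch
touches):

    use std::sync::Arc;

    use arrow_array::{RecordBatch, RecordBatchOptions};
    use arrow_schema::{ArrowError, Schema};

    fn main() -> Result<(), ArrowError> {
        // A zero-column projection yields an empty schema, so there is no
        // column from which the row count could be inferred.
        let schema = Arc::new(Schema::empty());

        // RecordBatch::try_new(schema, vec![]) would error here ("must
        // either specify a row count or at least one column"); carrying the
        // count in RecordBatchOptions preserves it instead.
        let options = RecordBatchOptions::new().with_row_count(Some(2));
        let batch = RecordBatch::try_new_with_options(schema, vec![], &options)?;

        assert_eq!(batch.num_rows(), 2);
        assert_eq!(batch.num_columns(), 0);
        Ok(())
    }

Without the explicit count, a zero-column batch would round-trip through the
shuffle writer as zero rows, which is exactly what the new
`df.select().limit(2).count() === 2` assertion guards against.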