diff --git a/native/core/src/execution/shuffle/row.rs b/native/core/src/execution/shuffle/row.rs
index f9b201944..405f64216 100644
--- a/native/core/src/execution/shuffle/row.rs
+++ b/native/core/src/execution/shuffle/row.rs
@@ -3363,9 +3363,6 @@ pub fn process_sorted_row_partition(
         let codec = CompressionCodec::Zstd(1);
         written += write_ipc_compressed(&batch, &mut cursor, &codec, &ipc_time)?;
 
-        // TODO document this more - this is important
-        written += 8;
-
         if let Some(checksum) = &mut current_checksum {
             checksum.update(&mut cursor)?;
         }
diff --git a/native/core/src/execution/shuffle/shuffle_writer.rs b/native/core/src/execution/shuffle/shuffle_writer.rs
index f17d720ef..aae71abc4 100644
--- a/native/core/src/execution/shuffle/shuffle_writer.rs
+++ b/native/core/src/execution/shuffle/shuffle_writer.rs
@@ -1601,7 +1601,7 @@ pub fn write_ipc_compressed(
 
     timer.stop();
 
-    Ok(compressed_length as usize)
+    Ok((end_pos - start_pos) as usize)
 }
 
 pub fn read_ipc_compressed(bytes: &[u8]) -> Result<RecordBatch> {
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/CometBlockStoreShuffleReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/CometBlockStoreShuffleReader.scala
index 0c44d6215..9680e5f70 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/CometBlockStoreShuffleReader.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/CometBlockStoreShuffleReader.scala
@@ -80,7 +80,7 @@ class CometBlockStoreShuffleReader[K, C](
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    var currentReadIterator: ShuffleBatchDecoderIterator = null
+    var currentReadIterator: NativeBatchDecoderIterator = null
 
     // Closes last read iterator after the task is finished.
     // We need to close read iterator during iterating input streams,
@@ -92,14 +92,13 @@ class CometBlockStoreShuffleReader[K, C](
       }
     }
 
-    var currentDecoder: ShuffleBatchDecoderIterator = null
     val recordIter: Iterator[(Int, ColumnarBatch)] = fetchIterator
      .flatMap(blockIdAndStream => {
-        if (currentDecoder != null) {
-          currentDecoder.close()
+        if (currentReadIterator != null) {
+          currentReadIterator.close()
         }
-        currentDecoder = ShuffleBatchDecoderIterator(blockIdAndStream._2, context)
-        currentDecoder
+        currentReadIterator = NativeBatchDecoderIterator(blockIdAndStream._2, context)
+        currentReadIterator
      })
      .map(b => (0, b))
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/ShuffleBatchDecoderIterator.scala b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/NativeBatchDecoderIterator.scala
similarity index 98%
rename from spark/src/main/scala/org/apache/spark/sql/comet/shuffle/ShuffleBatchDecoderIterator.scala
rename to spark/src/main/scala/org/apache/spark/sql/comet/shuffle/NativeBatchDecoderIterator.scala
index a7b4018e5..62c9b0c20 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/ShuffleBatchDecoderIterator.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/NativeBatchDecoderIterator.scala
@@ -35,7 +35,7 @@ import org.apache.comet.vector.NativeUtil
  * native ShuffleWriterExec and then calls native code to decompress and decode the shuffle blocks
  * and use Arrow FFI to return the Arrow record batch.
  */
-case class ShuffleBatchDecoderIterator(var in: InputStream, taskContext: TaskContext)
+case class NativeBatchDecoderIterator(var in: InputStream, taskContext: TaskContext)
     extends Iterator[ColumnarBatch] {
   private val SPARK_LZ4_MAGIC = Array[Byte](76, 90, 52, 66, 108, 111, 99, 107) // "LZ4Block"
   private var nextBatch: Option[ColumnarBatch] = None
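Context for the byte-accounting change above: `write_ipc_compressed` now returns the total number of bytes it advanced the output by (including the 8-byte length prefix it writes), so the `written += 8` adjustment in `row.rs` is no longer needed. The sketch below is a minimal, hypothetical illustration of that pattern, not the actual Comet implementation; the function name `write_length_prefixed` and the header layout are assumptions made for the example.

```rust
use std::io::{Cursor, Seek, SeekFrom, Write};

// Hypothetical sketch: write a length-prefixed payload and report the
// TOTAL bytes written (header included) by measuring the cursor position
// before and after, rather than returning only the payload length.
fn write_length_prefixed(cursor: &mut Cursor<Vec<u8>>, payload: &[u8]) -> std::io::Result<usize> {
    let start_pos = cursor.seek(SeekFrom::Current(0))?;

    // 8-byte little-endian length header followed by the payload.
    cursor.write_all(&(payload.len() as u64).to_le_bytes())?;
    cursor.write_all(payload)?;

    let end_pos = cursor.seek(SeekFrom::Current(0))?;

    // Returning end_pos - start_pos means callers no longer have to
    // remember to add 8 for the header themselves.
    Ok((end_pos - start_pos) as usize)
}

fn main() -> std::io::Result<()> {
    let mut cursor = Cursor::new(Vec::new());
    let written = write_length_prefixed(&mut cursor, b"compressed bytes")?;
    assert_eq!(written, 8 + "compressed bytes".len());
    Ok(())
}
```

Measuring `end_pos - start_pos` keeps the header size an internal detail of the writer, so any caller that accumulates `written` gets a figure that matches the real file offsets.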