From f218870c07bb78097992046920bad14d2d5c740a Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Sat, 21 Dec 2024 09:07:00 -0700
Subject: [PATCH] reuse buffer

---
 native/core/src/execution/jni_api.rs          | 22 +++++--------
 native/core/src/execution/shuffle/mod.rs      |  2 +-
 .../src/execution/shuffle/shuffle_writer.rs   |  4 +--
 .../shuffle/ShuffleBatchDecoderIterator.scala | 33 ++++++++++++++-----
 4 files changed, 36 insertions(+), 25 deletions(-)

diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index e428fcff2..a90a91d2f 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -59,7 +59,7 @@ use jni::{
 use tokio::runtime::Runtime;
 
 use crate::execution::operators::ScanExec;
-use crate::execution::shuffle::read_ipc_compressed_zstd;
+use crate::execution::shuffle::read_ipc_compressed;
 use crate::execution::spark_plan::SparkPlan;
 use log::info;
 
@@ -95,7 +95,7 @@ struct ExecutionContext {
 /// Accept serialized query plan and return the address of the native query plan.
 /// # Safety
-/// This function is inheritly unsafe since it deals with raw pointers passed from JNI.
+/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
 #[no_mangle]
 pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
     e: JNIEnv,
@@ -295,7 +295,7 @@ fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(), CometError> {
 /// Accept serialized query plan and the addresses of Arrow Arrays from Spark,
 /// then execute the query. Return addresses of arrow vector.
 /// # Safety
-/// This function is inheritly unsafe since it deals with raw pointers passed from JNI.
+/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
 #[no_mangle]
 pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
     e: JNIEnv,
@@ -458,7 +458,7 @@ fn get_execution_context<'a>(id: i64) -> &'a mut ExecutionContext {
 /// Used by Comet shuffle external sorter to write sorted records to disk.
 /// # Safety
-/// This function is inheritly unsafe since it deals with raw pointers passed from JNI.
+/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
 #[no_mangle]
 pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative(
     e: JNIEnv,
@@ -546,7 +546,9 @@ pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative(
 
 #[no_mangle]
 /// Used by Comet native shuffle reader
-pub extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock(
+/// # Safety
+/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
+pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock(
     e: JNIEnv,
     _class: JClass,
     byte_array: jbyteArray,
@@ -559,13 +561,7 @@ pub extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock(
         let elements = unsafe { env.get_array_elements(&value_array, ReleaseMode::NoCopyBack)? };
         let raw_pointer = elements.as_ptr();
         let slice = unsafe { std::slice::from_raw_parts(raw_pointer, length as usize) };
-        let batch = read_ipc_compressed_zstd(slice)?;
-        Ok(prepare_output(
-            &mut env,
-            array_addrs,
-            schema_addrs,
-            batch,
-            false,
-        )?)
+        let batch = read_ipc_compressed(slice)?;
+        prepare_output(&mut env, array_addrs, schema_addrs, batch, false)
     })
 }
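Note on the framing (editor's sketch, not part of the patch): the slice handed to `read_ipc_compressed` is only the zstd-compressed Arrow IPC body. The JVM side has already consumed a 16-byte prefix, two little-endian longs holding the compressed length and the field count, which is also why the Rust test further down slices `&output[16..]`. A minimal Rust sketch of that layout; `split_frame` is a hypothetical helper:

```rust
/// Hypothetical helper (not in the patch): split a shuffle block frame into
/// its header fields and the compressed body that decodeShuffleBlock receives.
/// Layout inferred from ShuffleBatchDecoderIterator:
/// [compressed length: i64 LE][field count: i64 LE][zstd-compressed IPC stream]
fn split_frame(frame: &[u8]) -> (u64, u64, &[u8]) {
    let compressed_len = u64::from_le_bytes(frame[0..8].try_into().unwrap());
    let field_count = u64::from_le_bytes(frame[8..16].try_into().unwrap());
    let body = &frame[16..16 + compressed_len as usize];
    (compressed_len, field_count, body)
}
```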
diff --git a/native/core/src/execution/shuffle/mod.rs b/native/core/src/execution/shuffle/mod.rs
index b3080d14c..178aff1fa 100644
--- a/native/core/src/execution/shuffle/mod.rs
+++ b/native/core/src/execution/shuffle/mod.rs
@@ -20,5 +20,5 @@ mod map;
 pub mod row;
 mod shuffle_writer;
 pub use shuffle_writer::{
-    read_ipc_compressed_zstd, write_ipc_compressed, CompressionCodec, ShuffleWriterExec,
+    read_ipc_compressed, write_ipc_compressed, CompressionCodec, ShuffleWriterExec,
 };
diff --git a/native/core/src/execution/shuffle/shuffle_writer.rs b/native/core/src/execution/shuffle/shuffle_writer.rs
index cb8f0c628..24e5036ea 100644
--- a/native/core/src/execution/shuffle/shuffle_writer.rs
+++ b/native/core/src/execution/shuffle/shuffle_writer.rs
@@ -1604,7 +1604,7 @@ pub fn write_ipc_compressed(
     Ok(compressed_length as usize)
 }
 
-pub fn read_ipc_compressed_zstd(bytes: &[u8]) -> Result<RecordBatch> {
+pub fn read_ipc_compressed(bytes: &[u8]) -> Result<RecordBatch> {
     let decoder = zstd::Decoder::new(bytes)?;
     let mut reader = StreamReader::try_new(decoder, None)?;
     // TODO check for None
@@ -1675,7 +1675,7 @@ mod test {
         assert_eq!(40210, length);
 
         let ipc_without_length_prefix = &output[16..];
-        let batch2 = read_ipc_compressed_zstd(ipc_without_length_prefix).unwrap();
+        let batch2 = read_ipc_compressed(ipc_without_length_prefix).unwrap();
         assert_eq!(batch, batch2);
     }
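For context, a self-contained round-trip sketch of what `read_ipc_compressed` undoes, using only the `arrow` and `zstd` crates. This is illustrative, not the Comet implementation: the real writer also emits the 16-byte length/field-count prefix, which is omitted here:

```rust
use std::sync::Arc;

use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    // write: an Arrow IPC stream wrapped in a zstd frame
    let mut encoder = zstd::Encoder::new(Vec::new(), 1)?;
    {
        let mut writer = StreamWriter::try_new(&mut encoder, &schema)?;
        writer.write(&batch)?;
        writer.finish()?;
    }
    let bytes = encoder.finish()?;

    // read: mirrors read_ipc_compressed (zstd decode, then IPC stream decode)
    let decoder = zstd::Decoder::new(bytes.as_slice())?;
    let mut reader = StreamReader::try_new(decoder, None)?;
    let decoded = reader.next().expect("one batch")?;
    assert_eq!(batch, decoded);
    Ok(())
}
```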
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/ShuffleBatchDecoderIterator.scala b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/ShuffleBatchDecoderIterator.scala
index 3bb589ef6..00b4ec90a 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/ShuffleBatchDecoderIterator.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/shuffle/ShuffleBatchDecoderIterator.scala
@@ -22,18 +22,25 @@ package org.apache.spark.sql.comet.shuffle
 import java.io.{EOFException, InputStream}
 import java.nio.ByteBuffer
 import java.nio.ByteOrder
+import java.nio.channels.{Channels, ReadableByteChannel}
+
 import org.apache.spark.TaskContext
 import org.apache.spark.sql.vectorized.ColumnarBatch
+
 import org.apache.comet.Native
 import org.apache.comet.vector.NativeUtil
 
-import java.nio.channels.{Channels, ReadableByteChannel}
-
+/**
+ * This iterator wraps a Spark input stream that reads shuffle blocks generated by the Comet
+ * native ShuffleWriterExec, calls native code to decompress and decode each shuffle block, and
+ * uses Arrow FFI to return the decoded Arrow record batches.
+ */
 case class ShuffleBatchDecoderIterator(var in: InputStream, taskContext: TaskContext)
     extends Iterator[ColumnarBatch] {
   private var nextBatch: Option[ColumnarBatch] = None
   private var finished = false;
   private val longBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN)
+  private var messageBuf = new Array[Byte](8192 * 512)
   private val native = new Native()
   private val nativeUtil = new NativeUtil()
 
@@ -82,15 +89,14 @@ case class ShuffleBatchDecoderIterator(var in: InputStream, taskContext: TaskContext)
         val fieldCount = longBuf.getLong.toInt
 
         // read body
-        // TODO reuse buffer
-        val buffer = new Array[Byte](compressedLength.toInt)
-        fillBuffer(in, buffer)
+        ensureCapacity(compressedLength)
+        fillBuffer(in, messageBuf, compressedLength)
 
         // make native call to decode batch
         nextBatch = nativeUtil.getNextBatch(
           fieldCount,
           (arrayAddrs, schemaAddrs) => {
-            native.decodeShuffleBlock(buffer, arrayAddrs, schemaAddrs)
+            native.decodeShuffleBlock(messageBuf, arrayAddrs, schemaAddrs)
           })
 
         true
@@ -106,10 +112,19 @@
     }
   }
 
-  private def fillBuffer(in: InputStream, buffer: Array[Byte]): Unit = {
+  private def ensureCapacity(requiredSize: Int): Unit = {
+    if (messageBuf.length < requiredSize) {
+      val newSize = Math.max(messageBuf.length * 2, requiredSize)
+      val newBuffer = new Array[Byte](newSize)
+      Array.copy(messageBuf, 0, newBuffer, 0, messageBuf.length)
+      messageBuf = newBuffer
+    }
+  }
+
+  private def fillBuffer(in: InputStream, buffer: Array[Byte], len: Int): Unit = {
     var bytesRead = 0
-    while (bytesRead < buffer.length) {
-      val result = in.read(buffer, bytesRead, buffer.length - bytesRead)
+    while (bytesRead < len) {
+      val result = in.read(buffer, bytesRead, len - bytesRead)
       if (result == -1) throw new EOFException()
       bytesRead += result
     }
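The effect of the Scala change is that `messageBuf` is allocated once per iterator and grown geometrically on demand, so steady-state blocks are read with no per-block allocation. The same pattern as a Rust sketch (the `BlockReader` type and its names are hypothetical):

```rust
use std::io::{Read, Result};

/// Hypothetical reader illustrating the buffer-reuse pattern from the patch:
/// one scratch buffer, doubled whenever a block exceeds the current capacity.
struct BlockReader {
    message_buf: Vec<u8>,
}

impl BlockReader {
    fn read_block<R: Read>(&mut self, input: &mut R, len: usize) -> Result<&[u8]> {
        if self.message_buf.len() < len {
            // grow geometrically, mirroring ensureCapacity in the Scala code
            let new_len = (self.message_buf.len() * 2).max(len);
            self.message_buf.resize(new_len, 0);
        }
        // read_exact loops until `len` bytes arrive, like fillBuffer,
        // and fails with UnexpectedEof on a short stream
        input.read_exact(&mut self.message_buf[..len])?;
        Ok(&self.message_buf[..len])
    }
}
```

Doubling rather than growing to the exact required size keeps the number of reallocations logarithmic in the largest block seen, at the cost of up to 2x transient over-allocation, the same trade-off the Scala `ensureCapacity` makes.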