Commit

reuse buffer
andygrove committed Dec 21, 2024
1 parent f7c9407 commit f218870
Showing 4 changed files with 36 additions and 25 deletions.
22 changes: 9 additions & 13 deletions native/core/src/execution/jni_api.rs
@@ -59,7 +59,7 @@ use jni::{
use tokio::runtime::Runtime;

use crate::execution::operators::ScanExec;
use crate::execution::shuffle::read_ipc_compressed_zstd;
use crate::execution::shuffle::read_ipc_compressed;
use crate::execution::spark_plan::SparkPlan;
use log::info;

@@ -95,7 +95,7 @@ struct ExecutionContext {

/// Accept serialized query plan and return the address of the native query plan.
/// # Safety
/// This function is inheritly unsafe since it deals with raw pointers passed from JNI.
/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
e: JNIEnv,
@@ -295,7 +295,7 @@ fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(), CometEr
/// Accept serialized query plan and the addresses of Arrow Arrays from Spark,
/// then execute the query. Return addresses of arrow vector.
/// # Safety
/// This function is inheritly unsafe since it deals with raw pointers passed from JNI.
/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
e: JNIEnv,
@@ -458,7 +458,7 @@ fn get_execution_context<'a>(id: i64) -> &'a mut ExecutionContext {

/// Used by Comet shuffle external sorter to write sorted records to disk.
/// # Safety
/// This function is inheritly unsafe since it deals with raw pointers passed from JNI.
/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
#[no_mangle]
pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative(
e: JNIEnv,
@@ -546,7 +546,9 @@ pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative(

#[no_mangle]
/// Used by Comet native shuffle reader
pub extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock(
/// # Safety
/// This function is inherently unsafe since it deals with raw pointers passed from JNI.
pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock(
e: JNIEnv,
_class: JClass,
byte_array: jbyteArray,
@@ -559,13 +561,7 @@ pub extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock(
let elements = unsafe { env.get_array_elements(&value_array, ReleaseMode::NoCopyBack)? };
let raw_pointer = elements.as_ptr();
let slice = unsafe { std::slice::from_raw_parts(raw_pointer, length as usize) };
let batch = read_ipc_compressed_zstd(slice)?;
Ok(prepare_output(
&mut env,
array_addrs,
schema_addrs,
batch,
false,
)?)
let batch = read_ipc_compressed(slice)?;
prepare_output(&mut env, array_addrs, schema_addrs, batch, false)
})
}
2 changes: 1 addition & 1 deletion native/core/src/execution/shuffle/mod.rs
@@ -20,5 +20,5 @@ mod map;
pub mod row;
mod shuffle_writer;
pub use shuffle_writer::{
read_ipc_compressed_zstd, write_ipc_compressed, CompressionCodec, ShuffleWriterExec,
read_ipc_compressed, write_ipc_compressed, CompressionCodec, ShuffleWriterExec,
};
4 changes: 2 additions & 2 deletions native/core/src/execution/shuffle/shuffle_writer.rs
@@ -1604,7 +1604,7 @@ pub fn write_ipc_compressed<W: Write + Seek>(
Ok(compressed_length as usize)
}

pub fn read_ipc_compressed_zstd(bytes: &[u8]) -> Result<RecordBatch> {
pub fn read_ipc_compressed(bytes: &[u8]) -> Result<RecordBatch> {
let decoder = zstd::Decoder::new(bytes)?;
let mut reader = StreamReader::try_new(decoder, None)?;
// TODO check for None
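
The hunk above only shows the start of read_ipc_compressed. For context, the sketch below round-trips a RecordBatch through a zstd-compressed Arrow IPC stream, which is the format this function consumes. It is a minimal, self-contained example assuming the arrow and zstd crates Comet already depends on; it deliberately omits the length/field-count prefix that the shuffle format adds (the test below strips 16 bytes before decoding) and is not the project's exact implementation.

use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build a small batch to round-trip.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
    )?;

    // Encode: an Arrow IPC stream written through a zstd encoder.
    let mut encoder = zstd::Encoder::new(Vec::new(), 1)?;
    {
        let mut writer = StreamWriter::try_new(&mut encoder, &schema)?;
        writer.write(&batch)?;
        writer.finish()?;
    }
    let compressed = encoder.finish()?;

    // Decode: a zstd decoder feeding an Arrow IPC StreamReader,
    // the same composition read_ipc_compressed uses above.
    let decoder = zstd::Decoder::new(compressed.as_slice())?;
    let mut reader = StreamReader::try_new(decoder, None)?;
    let decoded = reader.next().expect("stream contains one batch")?;
    assert_eq!(batch, decoded);
    Ok(())
}

The decode half (zstd::Decoder wrapped by StreamReader) is exactly the composition the renamed function relies on; the compression level and column contents here are arbitrary.
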
@@ -1675,7 +1675,7 @@ mod test {
assert_eq!(40210, length);

let ipc_without_length_prefix = &output[16..];
let batch2 = read_ipc_compressed_zstd(ipc_without_length_prefix).unwrap();
let batch2 = read_ipc_compressed(ipc_without_length_prefix).unwrap();
assert_eq!(batch, batch2);
}

@@ -22,18 +22,25 @@ package org.apache.spark.sql.comet.shuffle
import java.io.{EOFException, InputStream}
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.channels.{Channels, ReadableByteChannel}

import org.apache.spark.TaskContext
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.Native
import org.apache.comet.vector.NativeUtil

import java.nio.channels.{Channels, ReadableByteChannel}

/**
* This iterator wraps a Spark input stream that is reading shuffle blocks generated by the Comet
* native ShuffleWriterExec and then calls native code to decompress and decode the shuffle blocks
* and use Arrow FFI to return the Arrow record batch.
*/
case class ShuffleBatchDecoderIterator(var in: InputStream, taskContext: TaskContext)
extends Iterator[ColumnarBatch] {
private var nextBatch: Option[ColumnarBatch] = None
private var finished = false;
private val longBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN)
private var messageBuf = new Array[Byte](8192 * 512)
private val native = new Native()
private val nativeUtil = new NativeUtil()

@@ -82,15 +89,14 @@ case class ShuffleBatchDecoderIterator(var in: InputStream, taskContext: TaskCon
val fieldCount = longBuf.getLong.toInt

// read body
// TODO reuse buffer
val buffer = new Array[Byte](compressedLength.toInt)
fillBuffer(in, buffer)
ensureCapacity(compressedLength)
fillBuffer(in, messageBuf, compressedLength)

// make native call to decode batch
nextBatch = nativeUtil.getNextBatch(
fieldCount,
(arrayAddrs, schemaAddrs) => {
native.decodeShuffleBlock(buffer, arrayAddrs, schemaAddrs)
native.decodeShuffleBlock(messageBuf, arrayAddrs, schemaAddrs)
})

true
@@ -106,10 +112,19 @@ case class ShuffleBatchDecoderIterator(var in: InputStream, taskContext: TaskCon
}
}

private def fillBuffer(in: InputStream, buffer: Array[Byte]): Unit = {
private def ensureCapacity(requiredSize: Int): Unit = {
if (messageBuf.length < requiredSize) {
val newSize = Math.max(messageBuf.length * 2, requiredSize)
val newBuffer = new Array[Byte](newSize)
Array.copy(messageBuf, 0, newBuffer, 0, messageBuf.length)
messageBuf = newBuffer
}
}

private def fillBuffer(in: InputStream, buffer: Array[Byte], len: Int): Unit = {
var bytesRead = 0
while (bytesRead < buffer.length) {
val result = in.read(buffer, bytesRead, buffer.length - bytesRead)
while (bytesRead < len) {
val result = in.read(buffer, bytesRead, len - bytesRead)
if (result == -1) throw new EOFException()
bytesRead += result
}
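
For illustration, here is a hedged Rust sketch of the same framing and buffer-reuse idea the Scala iterator implements: read an 8-byte little-endian compressed length and an 8-byte little-endian field count, grow a reused buffer only when needed, then read exactly that many payload bytes. BlockReader, its initial capacity, and its end-of-stream convention are illustrative and not part of the Comet codebase.

use std::io::{self, Read};

/// Sketch of the block framing the iterator above consumes: an 8-byte
/// little-endian compressed length, an 8-byte little-endian field count,
/// then `length` bytes of zstd-compressed Arrow IPC data. The payload
/// buffer is reused across calls and only ever grows.
struct BlockReader {
    buf: Vec<u8>,
}

impl BlockReader {
    fn new() -> Self {
        // Initial capacity is arbitrary here; the Scala side starts at 8192 * 512.
        BlockReader { buf: vec![0u8; 8192] }
    }

    /// Reads one framed block, returning (field_count, payload bytes).
    /// Returns Ok(None) on a clean end of stream (a simplification of the
    /// iterator's EOF handling).
    fn next_block<R: Read>(&mut self, input: &mut R) -> io::Result<Option<(usize, &[u8])>> {
        let mut prefix = [0u8; 8];
        match input.read_exact(&mut prefix) {
            Ok(()) => {}
            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
            Err(e) => return Err(e),
        }
        let compressed_len = u64::from_le_bytes(prefix) as usize;
        input.read_exact(&mut prefix)?;
        let field_count = u64::from_le_bytes(prefix) as usize;

        // Grow (never shrink) the reused buffer, then fill exactly
        // `compressed_len` bytes, mirroring ensureCapacity/fillBuffer above.
        if self.buf.len() < compressed_len {
            self.buf.resize(compressed_len.max(self.buf.len() * 2), 0);
        }
        input.read_exact(&mut self.buf[..compressed_len])?;
        Ok(Some((field_count, &self.buf[..compressed_len])))
    }
}

The point of the commit is visible in the growth rule: one long-lived buffer is resized at most a handful of times instead of a fresh array being allocated for every shuffle block.
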