diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala
index 7b6232eeb..b5dc51533 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala
@@ -44,6 +44,7 @@ case class NativeBatchDecoderIterator(
   private var nextBatch: Option[ColumnarBatch] = None
   private var finished = false;
   private val longBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN)
+  private var dataBuf: ByteBuffer = _
   private val native = new Native()
   private val nativeUtil = new NativeUtil()
 
@@ -104,16 +105,20 @@ case class NativeBatchDecoderIterator(
     val fieldCount = longBuf.getLong.toInt
 
     // read body
-    // TODO avoid allocating a new buffer for each batch
-    val buffer = ByteBuffer.allocateDirect(compressedLength - 8)
-    while (buffer.hasRemaining && channel.read(buffer) >= 0) {}
+    val bytesToRead = compressedLength - 8
+    if (dataBuf == null || dataBuf.capacity() < bytesToRead) {
+      dataBuf = ByteBuffer.allocateDirect(bytesToRead * 2)
+    }
+    dataBuf.clear()
+    dataBuf.limit(bytesToRead)
+    while (dataBuf.hasRemaining && channel.read(dataBuf) >= 0) {}
 
     // make native call to decode batch
     val startTime = System.nanoTime()
     nextBatch = nativeUtil.getNextBatch(
       fieldCount,
       (arrayAddrs, schemaAddrs) => {
-        native.decodeShuffleBlock(buffer, arrayAddrs, schemaAddrs)
+        native.decodeShuffleBlock(dataBuf, arrayAddrs, schemaAddrs)
       })
     decodeTime.add(System.nanoTime() - startTime)
 
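The patch resolves the TODO by replacing the per-batch ByteBuffer.allocateDirect call with a direct buffer cached on the iterator: the buffer is reallocated (with 2x headroom, to amortize reallocations across variably sized batches) only when an incoming block exceeds the current capacity, and clear()/limit() bound each read to the current block. This matters for direct buffers in particular, since their native memory is reclaimed only when the owning ByteBuffer object is garbage collected, so allocating one per batch creates off-heap memory pressure. The following is a minimal standalone sketch of that grow-and-reuse pattern; the ReusableBufferExample and readBlock names are illustrative, not from the patch, and the trailing flip() is for a JVM-side consumer, whereas the patch hands the buffer to the native decoder, which presumably reads from the buffer's base address.

import java.io.ByteArrayInputStream
import java.nio.ByteBuffer
import java.nio.channels.{Channels, ReadableByteChannel}

object ReusableBufferExample {
  // Direct buffer reused across reads; reallocated with 2x headroom
  // only when an incoming block is larger than the current capacity.
  private var dataBuf: ByteBuffer = _

  /** Read exactly `n` bytes from `channel` into the shared buffer. */
  def readBlock(channel: ReadableByteChannel, n: Int): ByteBuffer = {
    if (dataBuf == null || dataBuf.capacity() < n) {
      // Over-allocate so subsequent, similarly sized blocks fit
      // without another allocation.
      dataBuf = ByteBuffer.allocateDirect(n * 2)
    }
    dataBuf.clear()   // reset position to 0, limit to capacity
    dataBuf.limit(n)  // accept only `n` bytes for this block
    while (dataBuf.hasRemaining && channel.read(dataBuf) >= 0) {}
    dataBuf.flip()    // prepare for reading by a JVM-side consumer
    dataBuf
  }

  def main(args: Array[String]): Unit = {
    val payload = Array.tabulate[Byte](16)(_.toByte)
    val channel = Channels.newChannel(new ByteArrayInputStream(payload))
    val buf = readBlock(channel, payload.length)
    println(s"read ${buf.remaining()} bytes, capacity ${buf.capacity()}")
  }
}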