
Commit db1649c
record decodeTime in reader
andygrove committed Dec 22, 2024
1 parent 44abab0 commit db1649c
Showing 6 changed files with 19 additions and 6 deletions.
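
The change threads a new "decodeTime" metric from the shuffle exchange operator down to the native batch decoder: a nano-timing SQLMetric is registered alongside the existing shuffle metrics, CometShuffleDependency carries it to the block store reader, and NativeBatchDecoderIterator wraps its native decode call in a System.nanoTime() measurement. A minimal sketch of that timing pattern follows, assuming Spark's SQLMetric is on the classpath; the timed helper and the doWork placeholder are illustrative, not part of this commit:

import org.apache.spark.sql.execution.metric.SQLMetric

// Hypothetical helper capturing the pattern the commit applies inline:
// time only the wrapped work and accumulate elapsed nanoseconds into the metric.
def timed[T](metric: SQLMetric)(doWork: => T): T = {
  val startTime = System.nanoTime()
  val result = doWork
  metric.add(System.nanoTime() - startTime)
  result
}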
@@ -136,6 +136,7 @@ object CometMetricNode {
       "mempool_time" -> SQLMetrics.createNanoTimingMetric(sc, "memory pool time"),
       "repart_time" -> SQLMetrics.createNanoTimingMetric(sc, "repartition time"),
       "ipc_time" -> SQLMetrics.createNanoTimingMetric(sc, "encoding and compression time"),
+      "decodeTime" -> SQLMetrics.createNanoTimingMetric(sc, "decoding and decompression time"),
       "spill_count" -> SQLMetrics.createMetric(sc, "number of spills"),
       "spilled_bytes" -> SQLMetrics.createMetric(sc, "spilled bytes"),
       "input_batches" -> SQLMetrics.createMetric(sc, "number of input batches"))
@@ -102,7 +102,8 @@ class CometBlockStoreShuffleReader[K, C](
         if (currentReadIterator != null) {
           currentReadIterator.close()
         }
-        currentReadIterator = NativeBatchDecoderIterator(blockIdAndStream._2, context)
+        currentReadIterator =
+          NativeBatchDecoderIterator(blockIdAndStream._2, context, dep.decodeTime)
         currentReadIterator
       })
       .map(b => (0, b))
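
The reader opens one NativeBatchDecoderIterator per shuffle block stream and closes the previous one before moving on, so passing dep.decodeTime here accumulates decode time from every block a task reads into a single metric.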
@@ -25,6 +25,7 @@ import org.apache.spark.{Aggregator, Partitioner, ShuffleDependency, SparkEnv}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.ShuffleWriteProcessor
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -39,7 +40,8 @@ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
     override val mapSideCombine: Boolean = false,
     override val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor,
     val shuffleType: ShuffleType = CometNativeShuffle,
-    val schema: Option[StructType] = None)
+    val schema: Option[StructType] = None,
+    val decodeTime: SQLMetric)
     extends ShuffleDependency[K, V, C](
       _rdd,
       partitioner,
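
Note that decodeTime is declared without a default value, unlike the surrounding parameters, so every construction site of CometShuffleDependency must now supply it explicitly; the call-site updates below and in the test suite follow from that.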
@@ -238,7 +238,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
       partitioner = new Partitioner {
         override def numPartitions: Int = outputPartitioning.numPartitions
         override def getPartition(key: Any): Int = key.asInstanceOf[Int]
-      })
+      },
+      decodeTime = metrics("decodeTime"))
     dependency
   }
 
@@ -435,7 +436,8 @@
       serializer,
       shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(writeMetrics),
       shuffleType = CometColumnarShuffle,
-      schema = Some(fromAttributes(outputAttributes)))
+      schema = Some(fromAttributes(outputAttributes)),
+      decodeTime = writeMetrics("decodeTime"))
 
     dependency
   }
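
Both shuffle paths are wired up: the CometNativeShuffle dependency takes the metric from the operator's metrics map, while the CometColumnarShuffle dependency looks up the same "decodeTime" key in writeMetrics.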
@@ -24,6 +24,7 @@ import java.nio.{ByteBuffer, ByteOrder}
 import java.nio.channels.{Channels, ReadableByteChannel}
 
 import org.apache.spark.TaskContext
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import org.apache.comet.Native
@@ -34,7 +35,10 @@ import org.apache.comet.vector.NativeUtil
  * native ShuffleWriterExec and then calls native code to decompress and decode the shuffle blocks
  * and use Arrow FFI to return the Arrow record batch.
  */
-case class NativeBatchDecoderIterator(var in: InputStream, taskContext: TaskContext)
+case class NativeBatchDecoderIterator(
+    var in: InputStream,
+    taskContext: TaskContext,
+    decodeTime: SQLMetric)
     extends Iterator[ColumnarBatch] {
   private val SPARK_LZ4_MAGIC = Array[Byte](76, 90, 52, 66, 108, 111, 99, 107) // "LZ4Block"
   private var nextBatch: Option[ColumnarBatch] = None
@@ -105,11 +109,13 @@ case class NativeBatchDecoderIterator(var in: InputStream, taskContext: TaskContext)
     while (buffer.hasRemaining && channel.read(buffer) >= 0) {}
 
     // make native call to decode batch
+    val startTime = System.nanoTime()
     nextBatch = nativeUtil.getNextBatch(
       fieldCount,
       (arrayAddrs, schemaAddrs) => {
         native.decodeShuffleBlock(buffer, arrayAddrs, schemaAddrs)
       })
+    decodeTime.add(System.nanoTime() - startTime)
 
     true
   }
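
The timed region covers only the native decode (nativeUtil.getNextBatch driving native.decodeShuffleBlock), not the channel read that fills the buffer beforehand, so blocking I/O is excluded from decodeTime. A hedged usage sketch, where stream, context, and metric stand in for values the caller already holds:

val it = NativeBatchDecoderIterator(stream, context, metric)
while (it.hasNext) {
  val batch = it.next() // a ColumnarBatch decoded and decompressed natively
  // ... consume the batch ...
}
it.close()
// metric.value now holds the total nanoseconds spent in native decode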
@@ -1132,7 +1132,8 @@ class CometShuffleManagerSuite extends CometTestBase {
       partitioner = new Partitioner {
         override def numPartitions: Int = 50
         override def getPartition(key: Any): Int = key.asInstanceOf[Int]
-      })
+      },
+      decodeTime = null)
 
     assert(CometShuffleManager.shouldBypassMergeSort(conf, dependency))
 
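
The suite passes decodeTime = null just to satisfy the new required parameter; that appears safe here because the shouldBypassMergeSort check under test never touches the metric.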
