minor: refactor decodeBatches to make private in broadcast exchange
andygrove committed Dec 22, 2024
1 parent ea6d205 commit 65f46a1
Showing 3 changed files with 25 additions and 67 deletions.
24 changes: 22 additions & 2 deletions spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -19,20 +19,24 @@
 
 package org.apache.spark.sql.comet
 
+import java.io.DataInputStream
+import java.nio.channels.Channels
 import java.util.UUID
 import java.util.concurrent.{Future, TimeoutException, TimeUnit}
 
 import scala.concurrent.{ExecutionContext, Promise}
 import scala.concurrent.duration.NANOSECONDS
 import scala.util.control.NonFatal
 
-import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
+import org.apache.spark.{broadcast, Partition, SparkContext, SparkEnv, TaskContext}
 import org.apache.spark.comet.shims.ShimCometBroadcastExchangeExec
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
@@ -299,7 +303,23 @@ class CometBatchRDD(
   override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
     val partition = split.asInstanceOf[CometBatchPartition]
     partition.value.value.toIterator
-      .flatMap(CometExec.decodeBatches(_, this.getClass.getSimpleName))
+      .flatMap(decodeBatches(_, this.getClass.getSimpleName))
   }
+
+  /**
+   * Decodes the byte arrays back to ColumnarBatchs and put them into buffer.
+   */
+  private def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
+    if (bytes.size == 0) {
+      return Iterator.empty
+    }
+
+    // use Spark's compression codec (LZ4 by default) and not Comet's compression
+    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+    val cbbis = bytes.toInputStream()
+    val ins = new DataInputStream(codec.compressedInputStream(cbbis))
+    // batches are in Arrow IPC format
+    new ArrowReaderIterator(Channels.newChannel(ins), source)
+  }
 }
 
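For reference, the decode path that the new private method encapsulates can be exercised on its own. The sketch below is illustrative rather than project code: the DecodeExample object and its decode helper are hypothetical names, and it assumes compilation under the org.apache.spark.sql.comet package (as the real class is) so that ArrowReaderIterator is accessible, plus a live SparkEnv.

package org.apache.spark.sql.comet

import java.io.DataInputStream
import java.nio.channels.Channels

import org.apache.spark.SparkEnv
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.io.ChunkedByteBuffer

object DecodeExample {

  // `bytes` is expected to hold Arrow IPC batches compressed with Spark's codec,
  // i.e. the same wire format the broadcast exchange serializes.
  def decode(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
    if (bytes.size == 0) {
      Iterator.empty
    } else {
      // decompress with Spark's codec (LZ4 by default), matching the encode side
      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
      val ins = new DataInputStream(codec.compressedInputStream(bytes.toInputStream()))
      // the decompressed stream is a sequence of Arrow IPC messages
      new ArrowReaderIterator(Channels.newChannel(ins), source)
    }
  }
}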
35 changes: 3 additions & 32 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -19,22 +19,20 @@
 
 package org.apache.spark.sql.comet
 
-import java.io.{ByteArrayOutputStream, DataInputStream}
-import java.nio.channels.Channels
+import java.io.ByteArrayOutputStream
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.io.CompressionCodec
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}
-import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec}
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.comet.plans.PartitioningPreservingUnaryExecNode
 import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
@@ -78,18 +76,6 @@ abstract class CometExec extends CometPlan {
   // outputPartitioning of SparkPlan, e.g., AQEShuffleReadExec.
   override def outputPartitioning: Partitioning = originalPlan.outputPartitioning
 
-  /**
-   * Executes the Comet operator and returns the result as an iterator of ColumnarBatch.
-   */
-  def executeColumnarCollectIterator(): (Long, Iterator[ColumnarBatch]) = {
-    val countsAndBytes = CometExec.getByteArrayRdd(this).collect()
-    val total = countsAndBytes.map(_._1).sum
-    val rows = countsAndBytes.iterator
-      .flatMap(countAndBytes =>
-        CometExec.decodeBatches(countAndBytes._2, this.getClass.getSimpleName))
-    (total, rows)
-  }
-
   protected def setSubqueries(planId: Long, sparkPlan: SparkPlan): Unit = {
     sparkPlan.children.foreach(setSubqueries(planId, _))
 
@@ -161,21 +147,6 @@ object CometExec {
       Utils.serializeBatches(iter)
     }
   }
-
-  /**
-   * Decodes the byte arrays back to ColumnarBatchs and put them into buffer.
-   */
-  def decodeBatches(bytes: ChunkedByteBuffer, source: String): Iterator[ColumnarBatch] = {
-    if (bytes.size == 0) {
-      return Iterator.empty
-    }
-
-    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
-    val cbbis = bytes.toInputStream()
-    val ins = new DataInputStream(codec.compressedInputStream(cbbis))
-
-    new ArrowReaderIterator(Channels.newChannel(ins), source)
-  }
 }
 
 /**
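With executeColumnarCollectIterator and the public CometExec.decodeBatches removed above, a caller that still needs to gather Comet results as ColumnarBatches on the driver has to reassemble the same steps itself. The sketch below is a hypothetical helper mirroring the removed method; it assumes CometExec.getByteArrayRdd(plan) continues to return an RDD of (row count, serialized batches) pairs and reuses the decode helper sketched earlier.

package org.apache.spark.sql.comet

import org.apache.spark.sql.vectorized.ColumnarBatch

object CollectExample {

  // Driver-side collect of a Comet plan's output: gather the per-partition
  // (rowCount, serializedBatches) pairs, sum the counts, and lazily decode the bytes.
  def collectBatches(plan: CometExec): (Long, Iterator[ColumnarBatch]) = {
    val countsAndBytes = CometExec.getByteArrayRdd(plan).collect()
    val totalRows = countsAndBytes.map(_._1).sum
    val batches = countsAndBytes.iterator.flatMap { case (_, bytes) =>
      DecodeExample.decode(bytes, plan.getClass.getSimpleName)
    }
    (totalRows, batches)
  }
}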
33 changes: 0 additions & 33 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -22,8 +22,6 @@ package org.apache.comet.exec
 import java.sql.Date
 import java.time.{Duration, Period}
 
-import scala.collection.JavaConverters._
-import scala.collection.mutable
 import scala.util.Random
 
 import org.scalactic.source.Position
@@ -462,37 +460,6 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
-  test("CometExec.executeColumnarCollectIterator can collect ColumnarBatch results") {
-    assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")
-    withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "true") {
-      withParquetTable((0 until 50).map(i => (i, i + 1)), "tbl") {
-        val df = sql("SELECT _1 + 1, _2 + 2 FROM tbl WHERE _1 > 3")
-
-        val nativeProject = find(df.queryExecution.executedPlan) {
-          case _: CometProjectExec => true
-          case _ => false
-        }.get.asInstanceOf[CometProjectExec]
-
-        val (rows, batches) = nativeProject.executeColumnarCollectIterator()
-        assert(rows == 46)
-
-        val column1 = mutable.ArrayBuffer.empty[Int]
-        val column2 = mutable.ArrayBuffer.empty[Int]
-
-        batches.foreach(batch => {
-          batch.rowIterator().asScala.foreach { row =>
-            assert(row.numFields == 2)
-            column1 += row.getInt(0)
-            column2 += row.getInt(1)
-          }
-        })
-
-        assert(column1.toArray.sorted === (4 until 50).map(_ + 1).toArray)
-        assert(column2.toArray.sorted === (5 until 51).map(_ + 2).toArray)
-      }
-    }
-  }
-
   test("scalar subquery") {
     val dataTypes =
       Seq(