Skip to content

Commit

Permalink
test: Copy Spark TPCDSQueryTestSuite to CometTPCDSQueryTestSuite (apa…
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya authored Jul 5, 2024
1 parent eff2897 commit 335146e
Show file tree
Hide file tree
Showing 3 changed files with 339 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql

import scala.util.control.NonFatal

import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.HiveResult.hiveResultString
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeCommandBase}
import org.apache.spark.sql.types.StructType

trait CometSQLQueryTestHelper {

private val notIncludedMsg = "[not included in comparison]"
private val clsName = this.getClass.getCanonicalName
protected val emptySchema: String = StructType(Seq.empty).catalogString

protected def replaceNotIncludedMsg(line: String): String = {
line
.replaceAll("#\\d+", "#x")
.replaceAll("plan_id=\\d+", "plan_id=x")
.replaceAll(s"Location.*$clsName/", s"Location $notIncludedMsg/{warehouse_dir}/")
.replaceAll(s"file:[^\\s,]*$clsName", s"file:$notIncludedMsg/{warehouse_dir}")
.replaceAll("Created By.*", s"Created By $notIncludedMsg")
.replaceAll("Created Time.*", s"Created Time $notIncludedMsg")
.replaceAll("Last Access.*", s"Last Access $notIncludedMsg")
.replaceAll("Partition Statistics\t\\d+", s"Partition Statistics\t$notIncludedMsg")
.replaceAll("\\*\\(\\d+\\) ", "*") // remove the WholeStageCodegen codegenStageIds
}

/** Executes a query and returns the result as (schema of the output, normalized output). */
protected def getNormalizedResult(session: SparkSession, sql: String): (String, Seq[String]) = {
// Returns true if the plan is supposed to be sorted.
def isSorted(plan: LogicalPlan): Boolean = plan match {
case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
case _: DescribeCommandBase | _: DescribeColumnCommand | _: DescribeRelation |
_: DescribeColumn =>
true
case PhysicalOperation(_, _, Sort(_, true, _)) => true
case _ => plan.children.iterator.exists(isSorted)
}

val df = session.sql(sql)
val schema = df.schema.catalogString
// Get answer, but also get rid of the #1234 expression ids that show up in explain plans
val answer = SQLExecution.withNewExecutionId(df.queryExecution, Some(sql)) {
hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg)
}

// If the output is not pre-sorted, sort it.
if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
}

/**
* This method handles exceptions occurred during query execution as they may need special care
* to become comparable to the expected output.
*
* @param result
* a function that returns a pair of schema and output
*/
protected def handleExceptions(result: => (String, Seq[String])): (String, Seq[String]) = {
try {
result
} catch {
case e: SparkThrowable with Throwable if e.getErrorClass != null =>
(emptySchema, Seq(e.getClass.getName, e.getMessage))
case a: AnalysisException =>
// Do not output the logical plan tree which contains expression IDs.
// Also implement a crude way of masking expression IDs in the error message
// with a generic pattern "###".
val msg = a.getMessage
(emptySchema, Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")))
case s: SparkException if s.getCause != null =>
// For a runtime exception, it is hard to match because its message contains
// information of stage, task ID, etc.
// To make result matching simpler, here we match the cause of the exception if it exists.
s.getCause match {
case e: SparkThrowable with Throwable if e.getErrorClass != null =>
(emptySchema, Seq(e.getClass.getName, e.getMessage))
case cause =>
(emptySchema, Seq(cause.getClass.getName, cause.getMessage))
}
case NonFatal(e) =>
// If there is an exception, put the exception class followed by the message.
(emptySchema, Seq(e.getClass.getName, e.getMessage))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class CometTPCDSQuerySuite
override val tpcdsQueries: Seq[String] =
tpcdsAllQueries.filterNot(excludedTpcdsQueries.contains)
}
with TPCDSQueryTestSuite
with CometTPCDSQueryTestSuite
with ShimCometTPCDSQuerySuite {
override def sparkConf: SparkConf = {
val conf = super.sparkConf
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql

import java.io.File
import java.nio.file.{Files, Paths}

import scala.collection.JavaConverters._

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.catalyst.util.{fileToString, resourceToString, stringToFile}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.TestSparkSession

/**
* Because we need to modify some methods of Spark `TPCDSQueryTestSuite` but they are private, we
* copy Spark `TPCDSQueryTestSuite`.
*/
class CometTPCDSQueryTestSuite extends QueryTest with TPCDSBase with CometSQLQueryTestHelper {

private val tpcdsDataPath = sys.env.get("SPARK_TPCDS_DATA")

private val regenGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1"

// To make output results deterministic
override protected def sparkConf: SparkConf = super.sparkConf
.set(SQLConf.SHUFFLE_PARTITIONS.key, "1")

protected override def createSparkSession: TestSparkSession = {
new TestSparkSession(new SparkContext("local[1]", this.getClass.getSimpleName, sparkConf))
}

// We use SF=1 table data here, so we cannot use SF=100 stats
protected override val injectStats: Boolean = false

if (tpcdsDataPath.nonEmpty) {
val nonExistentTables = tableNames.filterNot { tableName =>
Files.exists(Paths.get(s"${tpcdsDataPath.get}/$tableName"))
}
if (nonExistentTables.nonEmpty) {
fail(
s"Non-existent TPCDS table paths found in ${tpcdsDataPath.get}: " +
nonExistentTables.mkString(", "))
}
}

protected val baseResourcePath: String = {
// use the same way as `SQLQueryTestSuite` to get the resource path
getWorkspaceFilePath(
"sql",
"core",
"src",
"test",
"resources",
"tpcds-query-results").toFile.getAbsolutePath
}

override def createTable(
spark: SparkSession,
tableName: String,
format: String = "parquet",
options: scala.Seq[String]): Unit = {
spark.sql(s"""
|CREATE TABLE `$tableName` (${tableColumns(tableName)})
|USING $format
|LOCATION '${tpcdsDataPath.get}/$tableName'
|${options.mkString("\n")}
""".stripMargin)
}

private def runQuery(query: String, goldenFile: File, conf: Map[String, String]): Unit = {
// This is `sortMergeJoinConf != conf` in Spark, i.e., it sorts results for other joins
// than sort merge join. But in some queries DataFusion sort returns correct results
// in terms of required sorting columns, but the results are not same as Spark in terms of
// order of irrelevant columns. So, we need to sort the results for all joins.
val shouldSortResults = true
withSQLConf(conf.toSeq: _*) {
try {
val (schema, output) = handleExceptions(getNormalizedResult(spark, query))
val queryString = query.trim
val outputString = output.mkString("\n").replaceAll("\\s+$", "")
if (regenGoldenFiles) {
val goldenOutput = {
s"-- Automatically generated by ${getClass.getSimpleName}\n\n" +
"-- !query schema\n" +
schema + "\n" +
"-- !query output\n" +
outputString +
"\n"
}
val parent = goldenFile.getParentFile
if (!parent.exists()) {
assert(parent.mkdirs(), "Could not create directory: " + parent)
}
stringToFile(goldenFile, goldenOutput)
}

// Read back the golden file.
val (expectedSchema, expectedOutput) = {
val goldenOutput = fileToString(goldenFile)
val segments = goldenOutput.split("-- !query.*\n")

// query has 3 segments, plus the header
assert(
segments.size == 3,
s"Expected 3 blocks in result file but got ${segments.size}. " +
"Try regenerate the result files.")

(segments(1).trim, segments(2).replaceAll("\\s+$", ""))
}

val notMatchedSchemaOutput = if (schema == emptySchema) {
// There might be exception. See `handleExceptions`.
s"Schema did not match\n$queryString\nOutput/Exception: $outputString"
} else {
s"Schema did not match\n$queryString"
}

assertResult(expectedSchema, notMatchedSchemaOutput) {
schema
}
if (shouldSortResults) {
val expectSorted = expectedOutput
.split("\n")
.sorted
.map(_.trim)
.mkString("\n")
.replaceAll("\\s+$", "")
val outputSorted = output.sorted.map(_.trim).mkString("\n").replaceAll("\\s+$", "")
assertResult(expectSorted, s"Result did not match\n$queryString") {
outputSorted
}
} else {
assertResult(expectedOutput, s"Result did not match\n$queryString") {
outputString
}
}
} catch {
case e: Throwable =>
val configs = conf.map { case (k, v) =>
s"$k=$v"
}
throw new Exception(s"${e.getMessage}\nError using configs:\n${configs.mkString("\n")}")
}
}
}

val sortMergeJoinConf: Map[String, String] = Map(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.PREFER_SORTMERGEJOIN.key -> "true")

val broadcastHashJoinConf: Map[String, String] = Map(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10485760")

val shuffledHashJoinConf: Map[String, String] = Map(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
"spark.sql.join.forceApplyShuffledHashJoin" -> "true")

val allJoinConfCombinations: Seq[Map[String, String]] =
Seq(sortMergeJoinConf, broadcastHashJoinConf, shuffledHashJoinConf)

val joinConfs: Seq[Map[String, String]] = if (regenGoldenFiles) {
require(
!sys.env.contains("SPARK_TPCDS_JOIN_CONF"),
"'SPARK_TPCDS_JOIN_CONF' cannot be set together with 'SPARK_GENERATE_GOLDEN_FILES'")
Seq(sortMergeJoinConf)
} else {
sys.env
.get("SPARK_TPCDS_JOIN_CONF")
.map { s =>
val p = new java.util.Properties()
p.load(new java.io.StringReader(s))
Seq(p.asScala.toMap)
}
.getOrElse(allJoinConfCombinations)
}

assert(joinConfs.nonEmpty)
joinConfs.foreach(conf =>
require(
allJoinConfCombinations.contains(conf),
s"Join configurations [$conf] should be one of $allJoinConfCombinations"))

if (tpcdsDataPath.nonEmpty) {
tpcdsQueries.foreach { name =>
val queryString = resourceToString(
s"tpcds/$name.sql",
classLoader = Thread.currentThread().getContextClassLoader)
test(name) {
val goldenFile = new File(s"$baseResourcePath/v1_4", s"$name.sql.out")
joinConfs.foreach { conf =>
System.gc() // Workaround for GitHub Actions memory limitation, see also SPARK-37368
runQuery(queryString, goldenFile, conf)
}
}
}

tpcdsQueriesV2_7_0.foreach { name =>
val queryString = resourceToString(
s"tpcds-v2.7.0/$name.sql",
classLoader = Thread.currentThread().getContextClassLoader)
test(s"$name-v2.7") {
val goldenFile = new File(s"$baseResourcePath/v2_7", s"$name.sql.out")
joinConfs.foreach { conf =>
System.gc() // SPARK-37368
runQuery(queryString, goldenFile, conf)
}
}
}
} else {
ignore("skipped because env `SPARK_TPCDS_DATA` is not set") {}
}
}

0 comments on commit 335146e

Please sign in to comment.