forked from apache/datafusion-comet
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test: Copy Spark TPCDSQueryTestSuite to CometTPCDSQueryTestSuite (apa…
Showing 3 changed files with 339 additions and 1 deletion.
108 changes: 108 additions & 0 deletions
spark/src/test/scala/org/apache/spark/sql/CometSQLQueryTestHelper.scala
@@ -0,0 +1,108 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql

import scala.util.control.NonFatal

import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.HiveResult.hiveResultString
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeCommandBase}
import org.apache.spark.sql.types.StructType

trait CometSQLQueryTestHelper {

  private val notIncludedMsg = "[not included in comparison]"
  private val clsName = this.getClass.getCanonicalName
  protected val emptySchema: String = StructType(Seq.empty).catalogString

  protected def replaceNotIncludedMsg(line: String): String = {
    line
      .replaceAll("#\\d+", "#x")
      .replaceAll("plan_id=\\d+", "plan_id=x")
      .replaceAll(s"Location.*$clsName/", s"Location $notIncludedMsg/{warehouse_dir}/")
      .replaceAll(s"file:[^\\s,]*$clsName", s"file:$notIncludedMsg/{warehouse_dir}")
      .replaceAll("Created By.*", s"Created By $notIncludedMsg")
      .replaceAll("Created Time.*", s"Created Time $notIncludedMsg")
      .replaceAll("Last Access.*", s"Last Access $notIncludedMsg")
      .replaceAll("Partition Statistics\t\\d+", s"Partition Statistics\t$notIncludedMsg")
      .replaceAll("\\*\\(\\d+\\) ", "*") // remove the WholeStageCodegen codegenStageIds
  }

  /** Executes a query and returns the result as (schema of the output, normalized output). */
  protected def getNormalizedResult(session: SparkSession, sql: String): (String, Seq[String]) = {
    // Returns true if the plan is supposed to be sorted.
    def isSorted(plan: LogicalPlan): Boolean = plan match {
      case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
      case _: DescribeCommandBase | _: DescribeColumnCommand | _: DescribeRelation |
          _: DescribeColumn =>
        true
      case PhysicalOperation(_, _, Sort(_, true, _)) => true
      case _ => plan.children.iterator.exists(isSorted)
    }

    val df = session.sql(sql)
    val schema = df.schema.catalogString
    // Get answer, but also get rid of the #1234 expression ids that show up in explain plans
    val answer = SQLExecution.withNewExecutionId(df.queryExecution, Some(sql)) {
      hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg)
    }

    // If the output is not pre-sorted, sort it.
    if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
  }
  /**
   * This method handles exceptions that occur during query execution, as they may need special
   * care to become comparable to the expected output.
   *
   * @param result
   *   a function that returns a pair of schema and output
   */
  protected def handleExceptions(result: => (String, Seq[String])): (String, Seq[String]) = {
    try {
      result
    } catch {
      case e: SparkThrowable with Throwable if e.getErrorClass != null =>
        (emptySchema, Seq(e.getClass.getName, e.getMessage))
      case a: AnalysisException =>
        // Do not output the logical plan tree, which contains expression IDs.
        // Also implement a crude way of masking expression IDs in the error message
        // with the generic pattern "#x".
        val msg = a.getMessage
        (emptySchema, Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")))
      case s: SparkException if s.getCause != null =>
        // For a runtime exception, it is hard to match because its message contains
        // information about the stage, task ID, etc.
        // To make result matching simpler, here we match the cause of the exception if it exists.
        s.getCause match {
          case e: SparkThrowable with Throwable if e.getErrorClass != null =>
            (emptySchema, Seq(e.getClass.getName, e.getMessage))
          case cause =>
            (emptySchema, Seq(cause.getClass.getName, cause.getMessage))
        }
      case NonFatal(e) =>
        // If there is an exception, put the exception class followed by the message.
        (emptySchema, Seq(e.getClass.getName, e.getMessage))
    }
  }
}
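The normalization in replaceNotIncludedMsg is what keeps golden-file comparison stable across runs: expression IDs, plan IDs, codegen stage IDs, and machine-specific paths are rewritten to fixed placeholders. The following is a minimal, self-contained sketch of that masking idea; it is plain Scala, independent of Spark, and the sample explain-plan line is hypothetical:

object NormalizationSketch extends App {
  // Same masking idea as replaceNotIncludedMsg: strip run-specific tokens
  // so that two runs of the same query produce identical, comparable output.
  def normalize(line: String): String =
    line
      .replaceAll("#\\d+", "#x") // expression IDs: ss_item_sk#42 -> ss_item_sk#x
      .replaceAll("plan_id=\\d+", "plan_id=x") // adaptive plan IDs
      .replaceAll("\\*\\(\\d+\\) ", "*") // WholeStageCodegen stage IDs: "*(2) " -> "*"

  // Hypothetical explain-plan fragment, for illustration only.
  val sample = "*(2) HashAggregate(keys=[ss_item_sk#42], functions=[sum(ss_net_paid#57)], plan_id=7)"
  println(normalize(sample))
  // prints: *HashAggregate(keys=[ss_item_sk#x], functions=[sum(ss_net_paid#x)], plan_id=x)
}

The real helper additionally masks table locations, creation and access timestamps, and partition statistics, as the code above shows.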
230 changes: 230 additions & 0 deletions
spark/src/test/scala/org/apache/spark/sql/CometTPCDSQueryTestSuite.scala
@@ -0,0 +1,230 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql

import java.io.File
import java.nio.file.{Files, Paths}

import scala.collection.JavaConverters._

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.catalyst.util.{fileToString, resourceToString, stringToFile}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.TestSparkSession

/**
 * Because we need to modify some methods of Spark `TPCDSQueryTestSuite` but they are private, we
 * copy Spark `TPCDSQueryTestSuite`.
 */
class CometTPCDSQueryTestSuite extends QueryTest with TPCDSBase with CometSQLQueryTestHelper {

  private val tpcdsDataPath = sys.env.get("SPARK_TPCDS_DATA")

  private val regenGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1"

  // To make output results deterministic
  override protected def sparkConf: SparkConf = super.sparkConf
    .set(SQLConf.SHUFFLE_PARTITIONS.key, "1")

  protected override def createSparkSession: TestSparkSession = {
    new TestSparkSession(new SparkContext("local[1]", this.getClass.getSimpleName, sparkConf))
  }

  // We use SF=1 table data here, so we cannot use SF=100 stats
  protected override val injectStats: Boolean = false

  if (tpcdsDataPath.nonEmpty) {
    val nonExistentTables = tableNames.filterNot { tableName =>
      Files.exists(Paths.get(s"${tpcdsDataPath.get}/$tableName"))
    }
    if (nonExistentTables.nonEmpty) {
      fail(
        s"Non-existent TPCDS table paths found in ${tpcdsDataPath.get}: " +
          nonExistentTables.mkString(", "))
    }
  }

  protected val baseResourcePath: String = {
    // use the same way as `SQLQueryTestSuite` to get the resource path
    getWorkspaceFilePath(
      "sql",
      "core",
      "src",
      "test",
      "resources",
      "tpcds-query-results").toFile.getAbsolutePath
  }

  override def createTable(
      spark: SparkSession,
      tableName: String,
      format: String = "parquet",
      options: scala.Seq[String]): Unit = {
    spark.sql(s"""
         |CREATE TABLE `$tableName` (${tableColumns(tableName)})
         |USING $format
         |LOCATION '${tpcdsDataPath.get}/$tableName'
         |${options.mkString("\n")}
       """.stripMargin)
  }

  private def runQuery(query: String, goldenFile: File, conf: Map[String, String]): Unit = {
    // In Spark this is `sortMergeJoinConf != conf`, i.e., results are sorted only for join
    // strategies other than sort merge join. But for some queries DataFusion's sort returns
    // results that are correct with respect to the required sort columns, yet differ from
    // Spark's in the order of the remaining columns. So we sort the results for all joins.
    val shouldSortResults = true
    withSQLConf(conf.toSeq: _*) {
      try {
        val (schema, output) = handleExceptions(getNormalizedResult(spark, query))
        val queryString = query.trim
        val outputString = output.mkString("\n").replaceAll("\\s+$", "")
        if (regenGoldenFiles) {
          val goldenOutput = {
            s"-- Automatically generated by ${getClass.getSimpleName}\n\n" +
              "-- !query schema\n" +
              schema + "\n" +
              "-- !query output\n" +
              outputString +
              "\n"
          }
          val parent = goldenFile.getParentFile
          if (!parent.exists()) {
            assert(parent.mkdirs(), "Could not create directory: " + parent)
          }
          stringToFile(goldenFile, goldenOutput)
        }

        // Read back the golden file.
        val (expectedSchema, expectedOutput) = {
          val goldenOutput = fileToString(goldenFile)
          val segments = goldenOutput.split("-- !query.*\n")

          // query has 3 segments, plus the header
          assert(
            segments.size == 3,
            s"Expected 3 blocks in result file but got ${segments.size}. " +
              "Try regenerating the result files.")

          (segments(1).trim, segments(2).replaceAll("\\s+$", ""))
        }

        val notMatchedSchemaOutput = if (schema == emptySchema) {
          // There might be an exception. See `handleExceptions`.
          s"Schema did not match\n$queryString\nOutput/Exception: $outputString"
        } else {
          s"Schema did not match\n$queryString"
        }

        assertResult(expectedSchema, notMatchedSchemaOutput) {
          schema
        }
        if (shouldSortResults) {
          val expectSorted = expectedOutput
            .split("\n")
            .sorted
            .map(_.trim)
            .mkString("\n")
            .replaceAll("\\s+$", "")
          val outputSorted = output.sorted.map(_.trim).mkString("\n").replaceAll("\\s+$", "")
          assertResult(expectSorted, s"Result did not match\n$queryString") {
            outputSorted
          }
        } else {
          assertResult(expectedOutput, s"Result did not match\n$queryString") {
            outputString
          }
        }
      } catch {
        case e: Throwable =>
          val configs = conf.map { case (k, v) =>
            s"$k=$v"
          }
          throw new Exception(s"${e.getMessage}\nError using configs:\n${configs.mkString("\n")}")
      }
    }
  }

  val sortMergeJoinConf: Map[String, String] = Map(
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.PREFER_SORTMERGEJOIN.key -> "true")

  val broadcastHashJoinConf: Map[String, String] = Map(
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10485760")

  val shuffledHashJoinConf: Map[String, String] = Map(
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    "spark.sql.join.forceApplyShuffledHashJoin" -> "true")

  val allJoinConfCombinations: Seq[Map[String, String]] =
    Seq(sortMergeJoinConf, broadcastHashJoinConf, shuffledHashJoinConf)

  val joinConfs: Seq[Map[String, String]] = if (regenGoldenFiles) {
    require(
      !sys.env.contains("SPARK_TPCDS_JOIN_CONF"),
      "'SPARK_TPCDS_JOIN_CONF' cannot be set together with 'SPARK_GENERATE_GOLDEN_FILES'")
    Seq(sortMergeJoinConf)
  } else {
    sys.env
      .get("SPARK_TPCDS_JOIN_CONF")
      .map { s =>
        val p = new java.util.Properties()
        p.load(new java.io.StringReader(s))
        Seq(p.asScala.toMap)
      }
      .getOrElse(allJoinConfCombinations)
  }

  assert(joinConfs.nonEmpty)
  joinConfs.foreach(conf =>
    require(
      allJoinConfCombinations.contains(conf),
      s"Join configurations [$conf] should be one of $allJoinConfCombinations"))

  if (tpcdsDataPath.nonEmpty) {
    tpcdsQueries.foreach { name =>
      val queryString = resourceToString(
        s"tpcds/$name.sql",
        classLoader = Thread.currentThread().getContextClassLoader)
      test(name) {
        val goldenFile = new File(s"$baseResourcePath/v1_4", s"$name.sql.out")
        joinConfs.foreach { conf =>
          System.gc() // Workaround for GitHub Actions memory limitation, see also SPARK-37368
          runQuery(queryString, goldenFile, conf)
        }
      }
    }

    tpcdsQueriesV2_7_0.foreach { name =>
      val queryString = resourceToString(
        s"tpcds-v2.7.0/$name.sql",
        classLoader = Thread.currentThread().getContextClassLoader)
      test(s"$name-v2.7") {
        val goldenFile = new File(s"$baseResourcePath/v2_7", s"$name.sql.out")
        joinConfs.foreach { conf =>
          System.gc() // SPARK-37368
          runQuery(queryString, goldenFile, conf)
        }
      }
    }
  } else {
    ignore("skipped because env `SPARK_TPCDS_DATA` is not set") {}
  }
}
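For reference, the golden files that runQuery writes and reads back have a fixed three-part layout: a generated header, a "-- !query schema" block, and a "-- !query output" block, and the split on the "-- !query.*\n" pattern above relies on that layout producing exactly three segments. Below is a small, self-contained sketch of that parsing; the inlined golden-file body is hypothetical and only its format comes from the suite:

object GoldenFileSketch extends App {
  // Hypothetical golden-file content in the format written by runQuery above.
  val golden =
    """-- Automatically generated by CometTPCDSQueryTestSuite
      |
      |-- !query schema
      |struct<count(1):bigint>
      |-- !query output
      |42
      |""".stripMargin

  // Same split as in runQuery: segment 0 is the header, 1 the schema, 2 the output.
  val segments = golden.split("-- !query.*\n")
  assert(segments.size == 3, s"Expected 3 blocks but got ${segments.size}")
  println(segments(1).trim) // struct<count(1):bigint>
  println(segments(2).replaceAll("\\s+$", "")) // 42
}

Regeneration of these files is gated by SPARK_GENERATE_GOLDEN_FILES=1, and the suite only registers tests when SPARK_TPCDS_DATA points at the TPC-DS table data, as the code above shows.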