diff --git a/spark/src/test/scala/org/apache/spark/sql/CometSQLQueryTestHelper.scala b/spark/src/test/scala/org/apache/spark/sql/CometSQLQueryTestHelper.scala new file mode 100644 index 000000000..bf5ed4396 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/CometSQLQueryTestHelper.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql + +import scala.util.control.NonFatal + +import org.apache.spark.{SparkException, SparkThrowable} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.HiveResult.hiveResultString +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeCommandBase} +import org.apache.spark.sql.types.StructType + +trait CometSQLQueryTestHelper { + + private val notIncludedMsg = "[not included in comparison]" + private val clsName = this.getClass.getCanonicalName + protected val emptySchema: String = StructType(Seq.empty).catalogString + + protected def replaceNotIncludedMsg(line: String): String = { + line + .replaceAll("#\\d+", "#x") + .replaceAll("plan_id=\\d+", "plan_id=x") + .replaceAll(s"Location.*$clsName/", s"Location $notIncludedMsg/{warehouse_dir}/") + .replaceAll(s"file:[^\\s,]*$clsName", s"file:$notIncludedMsg/{warehouse_dir}") + .replaceAll("Created By.*", s"Created By $notIncludedMsg") + .replaceAll("Created Time.*", s"Created Time $notIncludedMsg") + .replaceAll("Last Access.*", s"Last Access $notIncludedMsg") + .replaceAll("Partition Statistics\t\\d+", s"Partition Statistics\t$notIncludedMsg") + .replaceAll("\\*\\(\\d+\\) ", "*") // remove the WholeStageCodegen codegenStageIds + } + + /** Executes a query and returns the result as (schema of the output, normalized output). */ + protected def getNormalizedResult(session: SparkSession, sql: String): (String, Seq[String]) = { + // Returns true if the plan is supposed to be sorted. + def isSorted(plan: LogicalPlan): Boolean = plan match { + case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false + case _: DescribeCommandBase | _: DescribeColumnCommand | _: DescribeRelation | + _: DescribeColumn => + true + case PhysicalOperation(_, _, Sort(_, true, _)) => true + case _ => plan.children.iterator.exists(isSorted) + } + + val df = session.sql(sql) + val schema = df.schema.catalogString + // Get answer, but also get rid of the #1234 expression ids that show up in explain plans + val answer = SQLExecution.withNewExecutionId(df.queryExecution, Some(sql)) { + hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg) + } + + // If the output is not pre-sorted, sort it. + if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) + } + + /** + * This method handles exceptions occurred during query execution as they may need special care + * to become comparable to the expected output. + * + * @param result + * a function that returns a pair of schema and output + */ + protected def handleExceptions(result: => (String, Seq[String])): (String, Seq[String]) = { + try { + result + } catch { + case e: SparkThrowable with Throwable if e.getErrorClass != null => + (emptySchema, Seq(e.getClass.getName, e.getMessage)) + case a: AnalysisException => + // Do not output the logical plan tree which contains expression IDs. + // Also implement a crude way of masking expression IDs in the error message + // with a generic pattern "###". + val msg = a.getMessage + (emptySchema, Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x"))) + case s: SparkException if s.getCause != null => + // For a runtime exception, it is hard to match because its message contains + // information of stage, task ID, etc. + // To make result matching simpler, here we match the cause of the exception if it exists. + s.getCause match { + case e: SparkThrowable with Throwable if e.getErrorClass != null => + (emptySchema, Seq(e.getClass.getName, e.getMessage)) + case cause => + (emptySchema, Seq(cause.getClass.getName, cause.getMessage)) + } + case NonFatal(e) => + // If there is an exception, put the exception class followed by the message. + (emptySchema, Seq(e.getClass.getName, e.getMessage)) + } + } +} diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala index 53186b131..3e0f64522 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQuerySuite.scala @@ -145,7 +145,7 @@ class CometTPCDSQuerySuite override val tpcdsQueries: Seq[String] = tpcdsAllQueries.filterNot(excludedTpcdsQueries.contains) } - with TPCDSQueryTestSuite + with CometTPCDSQueryTestSuite with ShimCometTPCDSQuerySuite { override def sparkConf: SparkConf = { val conf = super.sparkConf diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQueryTestSuite.scala b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQueryTestSuite.scala new file mode 100644 index 000000000..c2b853515 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/CometTPCDSQueryTestSuite.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql + +import java.io.File +import java.nio.file.{Files, Paths} + +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.catalyst.util.{fileToString, resourceToString, stringToFile} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.TestSparkSession + +/** + * Because we need to modify some methods of Spark `TPCDSQueryTestSuite` but they are private, we + * copy Spark `TPCDSQueryTestSuite`. + */ +class CometTPCDSQueryTestSuite extends QueryTest with TPCDSBase with CometSQLQueryTestHelper { + + private val tpcdsDataPath = sys.env.get("SPARK_TPCDS_DATA") + + private val regenGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1" + + // To make output results deterministic + override protected def sparkConf: SparkConf = super.sparkConf + .set(SQLConf.SHUFFLE_PARTITIONS.key, "1") + + protected override def createSparkSession: TestSparkSession = { + new TestSparkSession(new SparkContext("local[1]", this.getClass.getSimpleName, sparkConf)) + } + + // We use SF=1 table data here, so we cannot use SF=100 stats + protected override val injectStats: Boolean = false + + if (tpcdsDataPath.nonEmpty) { + val nonExistentTables = tableNames.filterNot { tableName => + Files.exists(Paths.get(s"${tpcdsDataPath.get}/$tableName")) + } + if (nonExistentTables.nonEmpty) { + fail( + s"Non-existent TPCDS table paths found in ${tpcdsDataPath.get}: " + + nonExistentTables.mkString(", ")) + } + } + + protected val baseResourcePath: String = { + // use the same way as `SQLQueryTestSuite` to get the resource path + getWorkspaceFilePath( + "sql", + "core", + "src", + "test", + "resources", + "tpcds-query-results").toFile.getAbsolutePath + } + + override def createTable( + spark: SparkSession, + tableName: String, + format: String = "parquet", + options: scala.Seq[String]): Unit = { + spark.sql(s""" + |CREATE TABLE `$tableName` (${tableColumns(tableName)}) + |USING $format + |LOCATION '${tpcdsDataPath.get}/$tableName' + |${options.mkString("\n")} + """.stripMargin) + } + + private def runQuery(query: String, goldenFile: File, conf: Map[String, String]): Unit = { + // This is `sortMergeJoinConf != conf` in Spark, i.e., it sorts results for other joins + // than sort merge join. But in some queries DataFusion sort returns correct results + // in terms of required sorting columns, but the results are not same as Spark in terms of + // order of irrelevant columns. So, we need to sort the results for all joins. + val shouldSortResults = true + withSQLConf(conf.toSeq: _*) { + try { + val (schema, output) = handleExceptions(getNormalizedResult(spark, query)) + val queryString = query.trim + val outputString = output.mkString("\n").replaceAll("\\s+$", "") + if (regenGoldenFiles) { + val goldenOutput = { + s"-- Automatically generated by ${getClass.getSimpleName}\n\n" + + "-- !query schema\n" + + schema + "\n" + + "-- !query output\n" + + outputString + + "\n" + } + val parent = goldenFile.getParentFile + if (!parent.exists()) { + assert(parent.mkdirs(), "Could not create directory: " + parent) + } + stringToFile(goldenFile, goldenOutput) + } + + // Read back the golden file. + val (expectedSchema, expectedOutput) = { + val goldenOutput = fileToString(goldenFile) + val segments = goldenOutput.split("-- !query.*\n") + + // query has 3 segments, plus the header + assert( + segments.size == 3, + s"Expected 3 blocks in result file but got ${segments.size}. " + + "Try regenerate the result files.") + + (segments(1).trim, segments(2).replaceAll("\\s+$", "")) + } + + val notMatchedSchemaOutput = if (schema == emptySchema) { + // There might be exception. See `handleExceptions`. + s"Schema did not match\n$queryString\nOutput/Exception: $outputString" + } else { + s"Schema did not match\n$queryString" + } + + assertResult(expectedSchema, notMatchedSchemaOutput) { + schema + } + if (shouldSortResults) { + val expectSorted = expectedOutput + .split("\n") + .sorted + .map(_.trim) + .mkString("\n") + .replaceAll("\\s+$", "") + val outputSorted = output.sorted.map(_.trim).mkString("\n").replaceAll("\\s+$", "") + assertResult(expectSorted, s"Result did not match\n$queryString") { + outputSorted + } + } else { + assertResult(expectedOutput, s"Result did not match\n$queryString") { + outputString + } + } + } catch { + case e: Throwable => + val configs = conf.map { case (k, v) => + s"$k=$v" + } + throw new Exception(s"${e.getMessage}\nError using configs:\n${configs.mkString("\n")}") + } + } + } + + val sortMergeJoinConf: Map[String, String] = Map( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.PREFER_SORTMERGEJOIN.key -> "true") + + val broadcastHashJoinConf: Map[String, String] = Map( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10485760") + + val shuffledHashJoinConf: Map[String, String] = Map( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + "spark.sql.join.forceApplyShuffledHashJoin" -> "true") + + val allJoinConfCombinations: Seq[Map[String, String]] = + Seq(sortMergeJoinConf, broadcastHashJoinConf, shuffledHashJoinConf) + + val joinConfs: Seq[Map[String, String]] = if (regenGoldenFiles) { + require( + !sys.env.contains("SPARK_TPCDS_JOIN_CONF"), + "'SPARK_TPCDS_JOIN_CONF' cannot be set together with 'SPARK_GENERATE_GOLDEN_FILES'") + Seq(sortMergeJoinConf) + } else { + sys.env + .get("SPARK_TPCDS_JOIN_CONF") + .map { s => + val p = new java.util.Properties() + p.load(new java.io.StringReader(s)) + Seq(p.asScala.toMap) + } + .getOrElse(allJoinConfCombinations) + } + + assert(joinConfs.nonEmpty) + joinConfs.foreach(conf => + require( + allJoinConfCombinations.contains(conf), + s"Join configurations [$conf] should be one of $allJoinConfCombinations")) + + if (tpcdsDataPath.nonEmpty) { + tpcdsQueries.foreach { name => + val queryString = resourceToString( + s"tpcds/$name.sql", + classLoader = Thread.currentThread().getContextClassLoader) + test(name) { + val goldenFile = new File(s"$baseResourcePath/v1_4", s"$name.sql.out") + joinConfs.foreach { conf => + System.gc() // Workaround for GitHub Actions memory limitation, see also SPARK-37368 + runQuery(queryString, goldenFile, conf) + } + } + } + + tpcdsQueriesV2_7_0.foreach { name => + val queryString = resourceToString( + s"tpcds-v2.7.0/$name.sql", + classLoader = Thread.currentThread().getContextClassLoader) + test(s"$name-v2.7") { + val goldenFile = new File(s"$baseResourcePath/v2_7", s"$name.sql.out") + joinConfs.foreach { conf => + System.gc() // SPARK-37368 + runQuery(queryString, goldenFile, conf) + } + } + } + } else { + ignore("skipped because env `SPARK_TPCDS_DATA` is not set") {} + } +}