Skip to content

Commit

Permalink
[SPARK-50441][SQL] Fix parametrized identifiers not working when refe…
Browse files Browse the repository at this point in the history
…rencing CTEs

### What changes were proposed in this pull request?
Fix parametrized identifiers not working when referencing CTEs

### Why are the changes needed?
For a query:

`with t1 as (select 1) select * from identifier(:cte) using cte as "t1"`

the resolution fails because `BindParameters` can't resolve parameters because it waits for `ResolveIdentifierClause` to resolve `UnresolvedWithCTERelation`, but `ResolveIdentifierClause` can't resolve `UnresolvedWithCTERelation` until all `NamedParameters` in the plan are resolved.

Instead of delaying CTE resolution with `UnresolvedWithCTERelation`, we can remove node entirely and delay the resolution by keeping the original `PlanWithUnresolvedIdentifier` and moving the CTE resolution to its `planBuilder`.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added a new test to `ParametersSuite`

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #48994 from mihailotim-db/mihailotim-db/cte_identifer.

Authored-by: Mihailo Timotic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
mihailotim-db authored and cloud-fan committed Nov 29, 2024
1 parent 4b97e11 commit 3fab712
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1610,9 +1610,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
case s: Sort if !s.resolved || s.missingInput.nonEmpty =>
resolveReferencesInSort(s)

case u: UnresolvedWithCTERelations =>
UnresolvedWithCTERelations(this.apply(u.unresolvedPlan), u.cteRelations)

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString(conf.maxToStringFields)}")
q.mapExpressions(resolveExpressionByPlanChildren(_, q, includeLastResort = true))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,25 @@ object CTESubstitution extends Rule[LogicalPlan] {
resolvedCTERelations
}

private def resolveWithCTERelations(
table: String,
alwaysInline: Boolean,
cteRelations: Seq[(String, CTERelationDef)],
unresolvedRelation: UnresolvedRelation): LogicalPlan = {
cteRelations
.find(r => conf.resolver(r._1, table))
.map {
case (_, d) =>
if (alwaysInline) {
d.child
} else {
// Add a `SubqueryAlias` for hint-resolving rules to match relation names.
SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, d.isStreaming))
}
}
.getOrElse(unresolvedRelation)
}

private def substituteCTE(
plan: LogicalPlan,
alwaysInline: Boolean,
Expand All @@ -279,22 +298,20 @@ object CTESubstitution extends Rule[LogicalPlan] {
throw QueryCompilationErrors.timeTravelUnsupportedError(toSQLId(table))

case u @ UnresolvedRelation(Seq(table), _, _) =>
cteRelations.find(r => plan.conf.resolver(r._1, table)).map { case (_, d) =>
if (alwaysInline) {
d.child
} else {
// Add a `SubqueryAlias` for hint-resolving rules to match relation names.
SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, d.isStreaming))
}
}.getOrElse(u)
resolveWithCTERelations(table, alwaysInline, cteRelations, u)

case p: PlanWithUnresolvedIdentifier =>
// We must look up CTE relations first when resolving `UnresolvedRelation`s,
// but we can't do it here as `PlanWithUnresolvedIdentifier` is a leaf node
// and may produce `UnresolvedRelation` later.
// Here we wrap it with `UnresolvedWithCTERelations` so that we can
// delay the CTE relations lookup after `PlanWithUnresolvedIdentifier` is resolved.
UnresolvedWithCTERelations(p, cteRelations)
// and may produce `UnresolvedRelation` later. Instead, we delay CTE resolution
// by moving it to the planBuilder of the corresponding `PlanWithUnresolvedIdentifier`.
p.copy(planBuilder = (nameParts, children) => {
p.planBuilder.apply(nameParts, children) match {
case u @ UnresolvedRelation(Seq(table), _, _) =>
resolveWithCTERelations(table, alwaysInline, cteRelations, u)
case other => other
}
})

case other =>
// This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{AliasHelper, EvalHelper, Expression}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.{CTERelationRef, LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{UNRESOLVED_IDENTIFIER, UNRESOLVED_IDENTIFIER_WITH_CTE}
import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_IDENTIFIER
import org.apache.spark.sql.types.StringType

/**
Expand All @@ -30,18 +30,9 @@ import org.apache.spark.sql.types.StringType
object ResolveIdentifierClause extends Rule[LogicalPlan] with AliasHelper with EvalHelper {

override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
_.containsAnyPattern(UNRESOLVED_IDENTIFIER, UNRESOLVED_IDENTIFIER_WITH_CTE)) {
_.containsPattern(UNRESOLVED_IDENTIFIER)) {
case p: PlanWithUnresolvedIdentifier if p.identifierExpr.resolved && p.childrenResolved =>
p.planBuilder.apply(evalIdentifierExpr(p.identifierExpr), p.children)
case u @ UnresolvedWithCTERelations(p, cteRelations) =>
this.apply(p) match {
case u @ UnresolvedRelation(Seq(table), _, _) =>
cteRelations.find(r => plan.conf.resolver(r._1, table)).map { case (_, d) =>
// Add a `SubqueryAlias` for hint-resolving rules to match relation names.
SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, d.isStreaming))
}.getOrElse(u)
case other => other
}
case other =>
other.transformExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_IDENTIFIER)) {
case e: ExpressionWithUnresolvedIdentifier if e.identifierExpr.resolved =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, LeafExpression, Literal, MapFromArrays, MapFromEntries, SubqueryExpression, Unevaluable, VariableReference}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SupervisingCommand}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMAND, PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_IDENTIFIER_WITH_CTE, UNRESOLVED_WITH}
import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMAND, PARAMETER, PARAMETERIZED_QUERY, TreePattern, UNRESOLVED_WITH}
import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.types.DataType

Expand Down Expand Up @@ -189,7 +189,7 @@ object BindParameters extends ParameterizedQueryProcessor with QueryErrorsBase {
// We should wait for `CTESubstitution` to resolve CTE before binding parameters, as CTE
// relations are not children of `UnresolvedWith`.
case NameParameterizedQuery(child, argNames, argValues)
if !child.containsAnyPattern(UNRESOLVED_WITH, UNRESOLVED_IDENTIFIER_WITH_CTE) &&
if !child.containsPattern(UNRESOLVED_WITH) &&
argValues.forall(_.resolved) =>
if (argNames.length != argValues.length) {
throw SparkException.internalError(s"The number of argument names ${argNames.length} " +
Expand All @@ -200,7 +200,7 @@ object BindParameters extends ParameterizedQueryProcessor with QueryErrorsBase {
bind(child) { case NamedParameter(name) if args.contains(name) => args(name) }

case PosParameterizedQuery(child, args)
if !child.containsAnyPattern(UNRESOLVED_WITH, UNRESOLVED_IDENTIFIER_WITH_CTE) &&
if !child.containsPattern(UNRESOLVED_WITH) &&
args.forall(_.resolved) =>
val indexedArgs = args.zipWithIndex
checkArgs(indexedArgs.map(arg => (s"_${arg._2}", arg._1)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, LeafNode, LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
Expand Down Expand Up @@ -76,17 +76,6 @@ case class PlanWithUnresolvedIdentifier(
copy(identifierExpr, newChildren, planBuilder)
}

/**
* A logical plan placeholder which delays CTE resolution
* to moment when PlanWithUnresolvedIdentifier gets resolved
*/
case class UnresolvedWithCTERelations(
unresolvedPlan: LogicalPlan,
cteRelations: Seq[(String, CTERelationDef)])
extends UnresolvedLeafNode {
final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_IDENTIFIER_WITH_CTE)
}

/**
* An expression placeholder that holds the identifier clause string expression. It will be
* replaced by the actual expression with the evaluated identifier string.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ object TreePattern extends Enumeration {
val UNRESOLVED_FUNCTION: Value = Value
val UNRESOLVED_HINT: Value = Value
val UNRESOLVED_WINDOW_EXPRESSION: Value = Value
val UNRESOLVED_IDENTIFIER_WITH_CTE: Value = Value

// Unresolved Plan patterns (Alphabetically ordered)
val UNRESOLVED_FUNC: Value = Value
Expand Down
11 changes: 11 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -758,4 +758,15 @@ class ParametersSuite extends QueryTest with SharedSparkSession with PlanTest {
checkAnswer(spark.sql(query("?"), args = Array("tt1")), Row(1))
}
}

test("SPARK-50441: parameterized identifier referencing a CTE") {
def query(p: String): String = {
s"""
|WITH t1 AS (SELECT 1)
|SELECT * FROM IDENTIFIER($p)""".stripMargin
}

checkAnswer(spark.sql(query(":cte"), args = Map("cte" -> "t1")), Row(1))
checkAnswer(spark.sql(query("?"), args = Array("t1")), Row(1))
}
}

0 comments on commit 3fab712

Please sign in to comment.