From 0de3a5918788e6bf4ea0182eb7343e0331e6431b Mon Sep 17 00:00:00 2001 From: Andrew Coleman Date: Tue, 5 Nov 2024 14:03:56 +0000 Subject: [PATCH] fix(spark): incorrect conversion of expand relation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In the expand relation, the projection expressions are stored in a two dimensional array. The spark matrix needs to be transposed in order to map the expressions into substrait, and vice-versa. I hadn’t noticed this earlier. Also, the remap field should not be used because the outputs are defined directly in the projections array. Signed-off-by: Andrew Coleman --- core/src/main/java/io/substrait/relation/Expand.java | 7 ++++--- .../src/main/scala/io/substrait/spark/SparkExtension.scala | 2 ++ .../scala/io/substrait/spark/logical/ToLogicalPlan.scala | 7 +++---- .../scala/io/substrait/spark/logical/ToSubstraitRel.scala | 3 +-- .../scala/io/substrait/spark/SubstraitPlanTestBase.scala | 6 +++++- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/Expand.java b/core/src/main/java/io/substrait/relation/Expand.java index 9efff60af..7f88282ae 100644 --- a/core/src/main/java/io/substrait/relation/Expand.java +++ b/core/src/main/java/io/substrait/relation/Expand.java @@ -4,7 +4,6 @@ import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.util.List; -import java.util.stream.Stream; import org.immutables.value.Value; @Value.Enclosing @@ -18,7 +17,7 @@ public abstract class Expand extends SingleInputRel { public Type.Struct deriveRecordType() { Type.Struct initial = getInput().getRecordType(); return TypeCreator.of(initial.nullable()) - .struct(Stream.concat(initial.fields().stream(), Stream.of(TypeCreator.REQUIRED.I64))); + .struct(getFields().stream().map(ExpandField::getType)); } @Override @@ -52,7 +51,9 @@ public abstract static class SwitchingField implements ExpandField { public abstract List getDuplicates(); public Type getType() { - return getDuplicates().get(0).getType(); + var nullable = getDuplicates().stream().anyMatch(d -> d.getType().nullable()); + var type = getDuplicates().get(0).getType(); + return nullable ? TypeCreator.asNullable(type) : TypeCreator.asNotNullable(type); } public static ImmutableExpand.SwitchingField.Builder builder() { diff --git a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala index 53b5bfaaf..c470c7a42 100644 --- a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala +++ b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala @@ -34,6 +34,8 @@ object SparkExtension { private val EXTENSION_COLLECTION: SimpleExtension.ExtensionCollection = SimpleExtension.loadDefaults() + val COLLECTION: SimpleExtension.ExtensionCollection = EXTENSION_COLLECTION.merge(SparkImpls) + lazy val SparkScalarFunctions: Seq[SimpleExtension.ScalarFunctionVariant] = { val ret = new collection.mutable.ArrayBuffer[SimpleExtension.ScalarFunctionVariant]() ret.appendAll(EXTENSION_COLLECTION.scalarFunctions().asScala) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index daec2a5ed..68b15345a 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -277,14 +277,13 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] } // An output column is nullable if any of the projections can assign null to it - val types = projections.transpose.map(p => (p.head.dataType, p.exists(_.nullable))) - - val output = types + val output = projections + .map(p => (p.head.dataType, p.exists(_.nullable))) .zip(names) .map { case (t, name) => StructField(name, t._1, t._2) } .map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) - Expand(projections, output, child) + Expand(projections.transpose, output, child) } } diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index b93eaecbe..2827a9c30 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -290,7 +290,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { } override def visitExpand(p: Expand): relation.Rel = { - val fields = p.projections.map( + val fields = p.projections.transpose.map( proj => { relation.Expand.SwitchingField.builder .duplicates( @@ -302,7 +302,6 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { val names = p.output.map(_.name) relation.Expand.builder - .remap(relation.Rel.Remap.offset(p.child.output.size, names.size)) .fields(fields.asJava) .hint(Hint.builder.addAllOutputNames(names.asJava).build()) .input(visit(p.child)) diff --git a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala index 4fa9ec263..cbd7a151c 100644 --- a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala +++ b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala @@ -26,7 +26,7 @@ import io.substrait.debug.TreePrinter import io.substrait.extension.ExtensionCollector import io.substrait.plan.{Plan, PlanProtoConverter, ProtoPlanConverter} import io.substrait.proto -import io.substrait.relation.RelProtoConverter +import io.substrait.relation.{ProtoRelConverter, RelProtoConverter} import org.scalactic.Equality import org.scalactic.source.Position import org.scalatest.Succeeded @@ -93,6 +93,10 @@ trait SubstraitPlanTestBase { self: SharedSparkSession => require(logicalPlan2.resolved); val pojoRel2 = new ToSubstraitRel().visit(logicalPlan2) + val extensionCollector = new ExtensionCollector; + val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel) + new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto) + pojoRel2.shouldEqualPlainly(pojoRel) logicalPlan2 }