From 6a585c0837b05c6d3e7ae3c887ebed8233cb0ccb Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 26 Nov 2024 22:50:58 -0800 Subject: [PATCH 01/17] feat: Replace Remap function variants with a single ChangeMapping interface. --- plan/builders.go | 17 ++++ plan/common.go | 4 + plan/plan.go | 10 +++ plan/plan_builder_test.go | 174 ++++++++++++++++++++++++++++++++++++-- plan/relations.go | 119 +++++++++++++++++++++++++- plan/relations_test.go | 6 ++ 6 files changed, 320 insertions(+), 10 deletions(-) diff --git a/plan/builders.go b/plan/builders.go index d969849..c2dbc47 100644 --- a/plan/builders.go +++ b/plan/builders.go @@ -71,13 +71,18 @@ type Builder interface { // from the output. Project(input Rel, exprs ...expr.Expression) (*ProjectRel, error) + // Deprecated: Use Project(...).ChangeMapping() instead. ProjectRemap(input Rel, remap []int32, exprs ...expr.Expression) (*ProjectRel, error) + // Deprecated: Use AggregateColumns(...).ChangeMapping() instead. AggregateColumnsRemap(input Rel, remap []int32, measures []AggRelMeasure, groupByCols ...int32) (*AggregateRel, error) AggregateColumns(input Rel, measures []AggRelMeasure, groupByCols ...int32) (*AggregateRel, error) + // Deprecated: Use AggregateExprs(...).ChangeMapping() instead. AggregateExprsRemap(input Rel, remap []int32, measures []AggRelMeasure, groups ...[]expr.Expression) (*AggregateRel, error) AggregateExprs(input Rel, measures []AggRelMeasure, groups ...[]expr.Expression) (*AggregateRel, error) + // Deprecated: Use CreateTableAsSelect(...).ChangeMapping() instead. CreateTableAsSelectRemap(input Rel, remap []int32, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error) CreateTableAsSelect(input Rel, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error) + // Deprecated: Use Cross(...).ChangeMapping() instead. CrossRemap(left, right Rel, remap []int32) (*CrossRel, error) Cross(left, right Rel) (*CrossRel, error) // FetchRemap constructs a fetch relation providing an offset (skipping some @@ -85,19 +90,27 @@ type Builder interface { // rows). If count is FETCH_COUNT_ALL_RECORDS (-1) all records will be // returned. Similar to Fetch but allows for reordering and restricting the // returned columns. + // + // Deprecated: Use Fetch(...).ChangeMapping() instead. FetchRemap(input Rel, offset, count int64, remap []int32) (*FetchRel, error) // Fetch constructs a fetch relation providing an offset (skipping some number of // rows) and a count (restricting output to a maximum number of rows). If count // is FETCH_COUNT_ALL_RECORDS (-1) all records will be returned. Fetch(input Rel, offset, count int64) (*FetchRel, error) + // Deprecated: Use Filter(...).ChangeMapping() instead. FilterRemap(input Rel, condition expr.Expression, remap []int32) (*FilterRel, error) Filter(input Rel, condition expr.Expression) (*FilterRel, error) + // Deprecated: Use JoinAndFilter(...).ChangeMapping() instead. JoinAndFilterRemap(left, right Rel, condition, postJoinFilter expr.Expression, joinType JoinType, remap []int32) (*JoinRel, error) + // Deprecated: Use Fetch(...).ChangeMapping() instead. JoinAndFilter(left, right Rel, condition, postJoinFilter expr.Expression, joinType JoinType) (*JoinRel, error) + // Deprecated: Use Join(...).ChangeMapping() instead. JoinRemap(left, right Rel, condition expr.Expression, joinType JoinType, remap []int32) (*JoinRel, error) Join(left, right Rel, condition expr.Expression, joinType JoinType) (*JoinRel, error) + // Deprecated: Use NamedScan(...).ChangeMapping() instead. NamedScanRemap(tableName []string, schema types.NamedStruct, remap []int32) (*NamedTableReadRel, error) NamedScan(tableName []string, schema types.NamedStruct) *NamedTableReadRel + // Deprecated: Use NamedWriteMap(...).ChangeMapping() instead. NamedWriteRemap(input Rel, op WriteOp, tableName []string, schema types.NamedStruct, remap []int32) (*NamedTableWriteRel, error) // NamedInsert inserts data from the input relation into a named table. NamedInsert(input Rel, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error) @@ -105,12 +118,16 @@ type Builder interface { // provided input relation, which typically includes conditions that filter // the rows to delete. NamedDelete(input Rel, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error) + // Deprecated: Use VirtualTable(...).ChangeMapping() instead. VirtualTableRemap(fields []string, remap []int32, values ...expr.StructLiteralValue) (*VirtualTableReadRel, error) VirtualTable(fields []string, values ...expr.StructLiteralValue) (*VirtualTableReadRel, error) + // Deprecated: Use VirtualTableFromExpr(...).ChangeMapping() instead. VirtualTableFromExprRemap(fieldNames []string, remap []int32, values ...expr.VirtualTableExpressionValue) (*VirtualTableReadRel, error) VirtualTableFromExpr(fieldNames []string, values ...expr.VirtualTableExpressionValue) (*VirtualTableReadRel, error) + // Deprecated: Use Sort(...).ChangeMapping() instead. SortRemap(input Rel, remap []int32, sorts ...expr.SortField) (*SortRel, error) Sort(input Rel, sorts ...expr.SortField) (*SortRel, error) + // Deprecated: Use Set(...).ChangeMapping() instead. SetRemap(op SetOp, remap []int32, inputs ...Rel) (*SetRel, error) Set(op SetOp, inputs ...Rel) (*SetRel, error) diff --git a/plan/common.go b/plan/common.go index 3f77db8..1a9914f 100644 --- a/plan/common.go +++ b/plan/common.go @@ -49,6 +49,10 @@ func (rc *RelCommon) remap(initial types.RecordType) types.RecordType { func (rc *RelCommon) OutputMapping() []int32 { return rc.mapping } +func (rc *RelCommon) ClearMapping() { + rc.mapping = nil +} + func (rc *RelCommon) GetAdvancedExtension() *extensions.AdvancedExtension { return rc.advExtension } diff --git a/plan/plan.go b/plan/plan.go index 7bead89..14d6f4b 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -271,6 +271,16 @@ type Rel interface { // result should be 3 columns consisting of the 5th, 2nd and 1st // output columns from the underlying relation. OutputMapping() []int32 + // ClearMapping resets the mapping for this relation. + ClearMapping() + // ChangeMapping modifies the current relation by applying the provided + // mapping to the current relation. Typically used to remove any unneeded + // columns or provide them in a different order. If there already is a + // mapping on this relation, this provides mapping over the current mapping. + // + // If any column numbers specified are outside the currently available input + // range an error is returned and the mapping is left unchanged. + ChangeMapping(mapping []int32) error // directOutputSchema returns the output record type of the underlying // relation as a struct type. Mapping is not applied. directOutputSchema() types.RecordType diff --git a/plan/plan_builder_test.go b/plan/plan_builder_test.go index 8afffc6..1703dc2 100644 --- a/plan/plan_builder_test.go +++ b/plan/plan_builder_test.go @@ -74,13 +74,27 @@ func TestBasicEmitPlan(t *testing.T) { func TestEmitEmptyPlan(t *testing.T) { b := plan.NewBuilderDefault() - root, err := b.NamedScanRemap([]string{"test"}, - baseSchema, []int32{}) + root := b.NamedScan([]string{"test"}, baseSchema) + err := root.ChangeMapping([]int32{}) + require.NoError(t, err) + _, err = b.Plan(root, []string{}) + require.NoError(t, err) + + b = plan.NewBuilderDefault() + root = b.NamedScan([]string{"test"}, baseSchema) + err = root.ChangeMapping([]int32{}) require.NoError(t, err) - p, err := b.Plan(root, []string{}) + _, err = b.Plan(root, []string{}) require.NoError(t, err) - assert.Equal(t, "NSTRUCT<>", p.GetRoots()[0].RecordType().String()) + b = plan.NewBuilderDefault() + root = b.NamedScan([]string{"test"}, baseSchema) + err = root.ChangeMapping([]int32{1, 0}) + require.NoError(t, err) + p, err := b.Plan(root, []string{"a", "b"}) + require.NoError(t, err) + + assert.Equal(t, "NSTRUCT", p.GetRoots()[0].RecordType().String()) protoPlan, err := p.ToProto() require.NoError(t, err) @@ -93,13 +107,29 @@ func TestEmitEmptyPlan(t *testing.T) { func TestBuildEmitOutOfRangePlan(t *testing.T) { b := plan.NewBuilderDefault() - root, err := b.NamedScanRemap([]string{"test"}, + _, err := b.NamedScanRemap([]string{"test"}, baseSchema, []int32{2}) - assert.Nil(t, root) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") + + b = plan.NewBuilderDefault() + root := b.NamedScan([]string{"test"}, baseSchema) + err = root.ChangeMapping([]int32{2}) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") } +func TestMappingOfMapping(t *testing.T) { + b := plan.NewBuilderDefault() + ns := b.NamedScan([]string{"test"}, baseSchema) + err := ns.ChangeMapping([]int32{1, 0}) + assert.NoError(t, err) + assert.Equal(t, "struct", ns.RecordType().String()) + err = ns.ChangeMapping([]int32{1}) + assert.NoError(t, err) + assert.Equal(t, "struct", ns.RecordType().String()) +} + func checkRoundTrip(t *testing.T, expectedJSON string, p *plan.Plan) { t.Helper() protoPlan, err := p.ToProto() @@ -261,20 +291,47 @@ func TestAggregateRelErrors(t *testing.T) { assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + acr, err := b.AggregateColumns(scan, nil, 0) + assert.NoError(t, err) + err = acr.ChangeMapping([]int32{-1, 5}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") + ref, _ := b.RootFieldRef(scan, 0) _, err = b.AggregateExprsRemap(scan, []int32{5, -1}, nil, []expr.Expression{ref}) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + ref, _ = b.RootFieldRef(scan, 0) + ae, err := b.AggregateExprs(scan, nil, []expr.Expression{ref}) + assert.NoError(t, err) + err = ae.ChangeMapping([]int32{5, -1}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") + _, err = b.AggregateExprsRemap(scan, []int32{1}, nil, []expr.Expression{ref}) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + ae, err = b.AggregateExprs(scan, nil, []expr.Expression{ref}) + assert.NoError(t, err) + err = ae.ChangeMapping([]int32{1}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") + _, err = b.AggregateExprsRemap(scan, []int32{0}, nil, []expr.Expression{ref}) assert.NoError(t, err) - _, err = b.AggregateColumnsRemap(scan, []int32{0}, nil, 0) + ae, err = b.AggregateExprs(scan, nil, []expr.Expression{ref}) + assert.NoError(t, err) + err = ae.ChangeMapping([]int32{0}) assert.NoError(t, err) + _, err = b.AggregateColumnsRemap(scan, []int32{0}, nil, 0) + assert.NoError(t, err) + ae, err = b.AggregateColumns(scan, nil, 0) + assert.NoError(t, err) + err = ae.ChangeMapping([]int32{0}) + assert.NoError(t, err) } func TestCrossRel(t *testing.T) { @@ -365,13 +422,31 @@ func TestCrossRelErrors(t *testing.T) { assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + c, err := b.Cross(left, right) + assert.NoError(t, err) + err = c.ChangeMapping([]int32{-1}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") + _, err = b.CrossRemap(left, right, []int32{5}) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + c, err = b.Cross(left, right) + assert.NoError(t, err) + err = c.ChangeMapping([]int32{5}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") + // Output is length 2 + 2 _, err = b.CrossRemap(left, right, []int32{2, 3}) assert.NoError(t, err) + + // Output is length 2 + 2 + c, err = b.Cross(left, right) + assert.NoError(t, err) + err = c.ChangeMapping([]int32{2, 3}) + assert.NoError(t, err) } func TestFetchRel(t *testing.T) { @@ -431,6 +506,9 @@ func TestFetchRel(t *testing.T) { assert.Equal(t, "NSTRUCT", p.GetRoots()[0].RecordType().String()) checkRoundTrip(t, expectedJSON, p) + + err = fetch.ChangeMapping([]int32{0}) + assert.NoError(t, err) } func TestFetchRelErrors(t *testing.T) { @@ -453,9 +531,22 @@ func TestFetchRelErrors(t *testing.T) { assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + f, err := b.Fetch(scan, 0, 0) + assert.NoError(t, err) + err = f.ChangeMapping([]int32{-1}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") + _, err = b.FetchRemap(scan, 0, 0, []int32{2}) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + + f, err = b.Fetch(scan, 0, 0) + assert.NoError(t, err) + err = f.ChangeMapping([]int32{2}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") + } func TestFilterRelation(t *testing.T) { @@ -513,6 +604,9 @@ func TestFilterRelation(t *testing.T) { assert.Equal(t, "NSTRUCT", p.GetRoots()[0].RecordType().String()) checkRoundTrip(t, expectedJSON, p) + + err = filter.ChangeMapping([]int32{0}) + assert.NoError(t, err) } func TestFilterRelationErrors(t *testing.T) { @@ -547,9 +641,21 @@ func TestFilterRelationErrors(t *testing.T) { assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + f, err := b.Filter(scan, refBool) + assert.NoError(t, err) + err = f.ChangeMapping([]int32{-1}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") + _, err = b.FilterRemap(scan, refBool, []int32{3}) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + + f, err = b.Filter(scan, refBool) + assert.NoError(t, err) + err = f.ChangeMapping([]int32{3}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") } func TestJoinRelOutputRecordTypes(t *testing.T) { @@ -764,10 +870,22 @@ func TestJoinRelationError(t *testing.T) { assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + j, err := b.Join(left, right, goodcond, plan.JoinTypeInner) + assert.NoError(t, err) + err = j.ChangeMapping([]int32{-1}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") + _, err = b.JoinRemap(left, right, goodcond, plan.JoinTypeLeftAnti, []int32{2}) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + j, err = b.Join(left, right, goodcond, plan.JoinTypeLeftAnti) + assert.NoError(t, err) + err = j.ChangeMapping([]int32{2}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") + _, err = b.JoinAndFilter(left, right, goodcond, badcond, plan.JoinTypeInner) assert.ErrorIs(t, err, substraitgo.ErrInvalidArg) assert.ErrorContains(t, err, "post join filter must be either nil or yield a boolean, not string") @@ -992,6 +1110,13 @@ func TestSortRelationErrors(t *testing.T) { assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + fields, _ = b.SortFields(scan, 1, 0) + s, err := b.Sort(scan, fields...) + assert.NoError(t, err) + err = s.ChangeMapping([]int32{-1}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") + _, err = b.Sort(nil) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "input Relation must not be nil") @@ -1003,6 +1128,12 @@ func TestSortRelationErrors(t *testing.T) { _, err = b.SortRemap(scan, []int32{3}, fields...) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + + sortRel, err := b.Sort(scan, fields...) + assert.NoError(t, err) + err = sortRel.ChangeMapping([]int32{3}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") } func TestProjectExpressions(t *testing.T) { @@ -1300,12 +1431,29 @@ func TestProjectErrors(t *testing.T) { assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + p, err := b.Project(scan, ref) + assert.NoError(t, err) + err = p.ChangeMapping([]int32{-1}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") + _, err = b.ProjectRemap(scan, []int32{3}, ref) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + p, err = b.Project(scan, ref) + assert.NoError(t, err) + err = p.ChangeMapping([]int32{3}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") + _, err = b.ProjectRemap(scan, []int32{2}, ref) assert.NoError(t, err, "Expected expression mapping to be in-bounds") + + p, err = b.Project(scan, ref) + assert.NoError(t, err) + err = p.ChangeMapping([]int32{2}) + assert.NoError(t, err, "Expected expression mapping to be in-bounds") } func TestSetRelations(t *testing.T) { @@ -1514,7 +1662,19 @@ func TestSetRelErrors(t *testing.T) { assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + s, err := b.Set(plan.SetOpMinusMultiset, scan1, scan2) + assert.NoError(t, err) + err = s.ChangeMapping([]int32{-1}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") + _, err = b.SetRemap(plan.SetOpMinusMultiset, []int32{3}, scan1, scan2) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") + + s, err = b.Set(plan.SetOpMinusMultiset, scan1, scan2) + assert.NoError(t, err) + err = s.ChangeMapping([]int32{3}) + assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) + assert.ErrorContains(t, err, "output mapping index out of range") } diff --git a/plan/relations.go b/plan/relations.go index 6a716bd..163b3b7 100644 --- a/plan/relations.go +++ b/plan/relations.go @@ -152,6 +152,35 @@ func (b *baseReadRel) updateFilters(filters []expr.Expression) { b.filter, b.bestEffortFilter = filters[0], filters[1] } +type MappableRel interface { + RecordType() types.RecordType + OutputMapping() []int32 +} + +func ChangeMapping(r MappableRel, mapping []int32) ([]int32, error) { + nOutput := r.RecordType().FieldCount() + oldMapping := r.OutputMapping() + newMapping := make([]int32, 0, len(mapping)) + for _, idx := range mapping { + if idx < 0 || idx >= nOutput { + return nil, errOutputMappingOutOfRange + } + if len(oldMapping) > 0 { + newMapping = append(newMapping, oldMapping[idx]) + } + } + if len(oldMapping) > 0 { + return newMapping, nil + } + return mapping, nil +} + +func (b *baseReadRel) ChangeMapping(mapping []int32) error { + newMapping, err := ChangeMapping(b, mapping) + b.mapping = newMapping + return err +} + // NamedTableReadRel is a named scan of a base table. The list of strings // that make up the names are to represent namespacing (e.g. mydb.mytable). // This assumes a shared catalog between systems exchanging a message. @@ -589,6 +618,12 @@ func (p *ProjectRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInput return &proj, nil } +func (p *ProjectRel) ChangeMapping(mapping []int32) error { + newMapping, err := ChangeMapping(p, mapping) + p.mapping = newMapping + return err +} + var defFilter = expr.NewPrimitiveLiteral(true, false) type JoinType = proto.JoinRel_JoinType @@ -741,6 +776,12 @@ func (j *JoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs . return &join, nil } +func (j *JoinRel) ChangeMapping(mapping []int32) error { + newMapping, err := ChangeMapping(j, mapping) + j.mapping = newMapping + return err +} + // CrossRel is a cartesian product relational operator of two tables. type CrossRel struct { RelCommon @@ -802,6 +843,12 @@ func (c *CrossRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (R return c.Copy(newInputs...) } +func (c *CrossRel) ChangeMapping(mapping []int32) error { + newMapping, err := ChangeMapping(c, mapping) + c.mapping = newMapping + return err +} + // FetchRel is a relational operator representing LIMIT/OFFSET or // TOP type semantics. type FetchRel struct { @@ -868,6 +915,12 @@ func (f *FetchRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (R return f.Copy(newInputs...) } +func (f *FetchRel) ChangeMapping(mapping []int32) error { + newMapping, err := ChangeMapping(f, mapping) + f.mapping = newMapping + return err +} + type AggRelMeasure struct { measure *expr.AggregateFunction filter expr.Expression @@ -1016,6 +1069,12 @@ func (ar *AggregateRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newIn return &aggregate, nil } +func (ar *AggregateRel) ChangeMapping(mapping []int32) error { + newMapping, err := ChangeMapping(ar, mapping) + ar.mapping = newMapping + return err +} + // SortRel is an ORDER BY relational operator, describing a base relation, // it includes a list of fields to sort on. type SortRel struct { @@ -1097,6 +1156,12 @@ func (sr *SortRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs return &sort, nil } +func (sr *SortRel) ChangeMapping(mapping []int32) error { + newMapping, err := ChangeMapping(sr, mapping) + sr.mapping = newMapping + return err +} + // FilterRel is a relational operator capturing simple filters ( // as in the WHERE clause of a SQL query). type FilterRel struct { @@ -1168,6 +1233,12 @@ func (fr *FilterRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInput return &filter, nil } +func (fr *FilterRel) ChangeMapping(mapping []int32) error { + newMapping, err := ChangeMapping(fr, mapping) + fr.mapping = newMapping + return err +} + type SetOp = proto.SetRel_SetOp const ( @@ -1241,6 +1312,12 @@ func (s *SetRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (Rel return s.Copy(newInputs...) } +func (s *SetRel) ChangeMapping(mapping []int32) error { + newMapping, err := ChangeMapping(s, mapping) + s.mapping = newMapping + return err +} + // ExtensionSingleRel is a stub to support extensions with a single input. type ExtensionSingleRel struct { RelCommon @@ -1298,6 +1375,12 @@ func (es *ExtensionSingleRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs return es.Copy(newInputs...) } +func (es *ExtensionSingleRel) ChangeMapping(mapping []int32) error { + newMapping, err := ChangeMapping(es, mapping) + es.mapping = newMapping + return err +} + // ExtensionLeafRel is a stub to support extensions with zero inputs. type ExtensionLeafRel struct { RelCommon @@ -1342,6 +1425,12 @@ func (el *ExtensionLeafRel) CopyWithExpressionRewrite(_ RewriteFunc, _ ...Rel) ( return el, nil } +func (el *ExtensionLeafRel) ChangeMapping(mapping []int32) error { + newMapping, err := ChangeMapping(el, mapping) + el.mapping = newMapping + return err +} + // ExtensionMultiRel is a stub to support extensions with multiple inputs. type ExtensionMultiRel struct { RelCommon @@ -1398,6 +1487,12 @@ func (em *ExtensionMultiRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs return em.Copy(newInputs...) } +func (em *ExtensionMultiRel) ChangeMapping(mapping []int32) error { + newMapping, err := ChangeMapping(em, mapping) + em.mapping = newMapping + return err +} + type HashMergeJoinType int8 const ( @@ -1516,6 +1611,12 @@ func (hr *HashJoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInp return &join, nil } +func (hr *HashJoinRel) ChangeMapping(mapping []int32) error { + newMapping, err := ChangeMapping(hr, mapping) + hr.mapping = newMapping + return err +} + // MergeJoinRel represents a join done by taking advantage of two sets // that are sorted on the join keys. This allows the join operation to // be done in a streaming fashion. @@ -1620,6 +1721,12 @@ func (mr *MergeJoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newIn return &merge, nil } +func (mr *MergeJoinRel) ChangeMapping(mapping []int32) error { + newMapping, err := ChangeMapping(mr, mapping) + mr.mapping = newMapping + return err +} + type WriteOp = proto.WriteRel_WriteOp const ( @@ -1670,9 +1777,9 @@ func (wr *NamedTableWriteRel) RecordType() types.RecordType { return wr.remap(wr.directOutputSchema()) } -func (n *NamedTableWriteRel) Names() []string { return n.names } -func (n *NamedTableWriteRel) NamedTableAdvancedExtension() *extensions.AdvancedExtension { - return n.advExtension +func (wr *NamedTableWriteRel) Names() []string { return wr.names } +func (wr *NamedTableWriteRel) NamedTableAdvancedExtension() *extensions.AdvancedExtension { + return wr.advExtension } func (wr *NamedTableWriteRel) TableSchema() types.NamedStruct { @@ -1731,6 +1838,12 @@ func (wr *NamedTableWriteRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs return wr.Copy(newInputs...) } +func (wr *NamedTableWriteRel) ChangeMapping(mapping []int32) error { + newMapping, err := ChangeMapping(wr, mapping) + wr.mapping = newMapping + return err +} + var ( _ Rel = (*NamedTableReadRel)(nil) _ Rel = (*VirtualTableReadRel)(nil) diff --git a/plan/relations_test.go b/plan/relations_test.go index 57bc8ca..ed5929a 100644 --- a/plan/relations_test.go +++ b/plan/relations_test.go @@ -382,6 +382,12 @@ func (f *fakeRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs . panic("unused") } +func (f *fakeRel) ChangeMapping(mapping []int32) error { + newMapping, err := ChangeMapping(f, mapping) + f.mapping = newMapping + return err +} + func TestProjectRecordType(t *testing.T) { var rel ProjectRel rel.input = &fakeRel{outputType: *types.NewRecordTypeFromTypes( From 3ebd6d3abc9a0905ca1d6fa1fe11f37883663aa8 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 27 Nov 2024 16:38:13 -0800 Subject: [PATCH 02/17] added more tests --- plan/relations_test.go | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/plan/relations_test.go b/plan/relations_test.go index ed5929a..79158a2 100644 --- a/plan/relations_test.go +++ b/plan/relations_test.go @@ -393,12 +393,13 @@ func TestProjectRecordType(t *testing.T) { rel.input = &fakeRel{outputType: *types.NewRecordTypeFromTypes( []types.Type{&types.Int64Type{}, &types.Int64Type{}})} - rel.mapping = nil + rel.ClearMapping() expected := *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}, &types.Int64Type{}}) result := rel.RecordType() assert.Equal(t, expected, result) - rel.mapping = []int32{0} + err := rel.ChangeMapping([]int32{0}) + assert.NoError(t, err) expected = *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}}) result = rel.RecordType() assert.Equal(t, expected, result) @@ -409,12 +410,13 @@ func TestExtensionSingleRecordType(t *testing.T) { rel.input = &fakeRel{outputType: *types.NewRecordTypeFromTypes( []types.Type{&types.Int64Type{}, &types.Int64Type{}})} - rel.mapping = nil + rel.ClearMapping() expected := *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}, &types.Int64Type{}}) result := rel.RecordType() assert.Equal(t, expected, result) - rel.mapping = []int32{0} + err := rel.ChangeMapping([]int32{0}) + assert.NoError(t, err) expected = *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}}) result = rel.RecordType() assert.Equal(t, expected, result) @@ -427,13 +429,14 @@ func TestHashJoinRecordType(t *testing.T) { rel.right = &fakeRel{outputType: *types.NewRecordTypeFromTypes( []types.Type{&types.StringType{}, &types.StringType{}})} - rel.mapping = nil + rel.ClearMapping() expected := *types.NewRecordTypeFromTypes( []types.Type{&types.Int64Type{}, &types.Int64Type{}, &types.StringType{}, &types.StringType{}}) result := rel.RecordType() assert.Equal(t, expected, result) - rel.mapping = []int32{0} + err := rel.ChangeMapping([]int32{0}) + assert.NoError(t, err) expected = *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}}) result = rel.RecordType() assert.Equal(t, expected, result) @@ -446,13 +449,14 @@ func TestMergeJoinRecordType(t *testing.T) { rel.right = &fakeRel{outputType: *types.NewRecordTypeFromTypes( []types.Type{&types.StringType{}, &types.StringType{}})} - rel.mapping = nil + rel.ClearMapping() expected := *types.NewRecordTypeFromTypes( []types.Type{&types.Int64Type{}, &types.Int64Type{}, &types.StringType{}, &types.StringType{}}) result := rel.RecordType() assert.Equal(t, expected, result) - rel.mapping = []int32{0} + err := rel.ChangeMapping([]int32{0}) + assert.NoError(t, err) expected = *types.NewRecordTypeFromTypes( []types.Type{&types.Int64Type{}}) result = rel.RecordType() From 771ccecb0a8bf36591316881c1addeb6001c7526 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 27 Nov 2024 16:43:46 -0800 Subject: [PATCH 03/17] added tests for uncommon relations --- plan/relations_test.go | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/plan/relations_test.go b/plan/relations_test.go index 79158a2..c01e1ab 100644 --- a/plan/relations_test.go +++ b/plan/relations_test.go @@ -422,6 +422,30 @@ func TestExtensionSingleRecordType(t *testing.T) { assert.Equal(t, expected, result) } +func TestExtensionLeafRecordType(t *testing.T) { + var rel ExtensionLeafRel + + rel.ClearMapping() + expected := *types.NewRecordTypeFromTypes(nil) + result := rel.RecordType() + assert.Equal(t, expected, result) + + err := rel.ChangeMapping([]int32{0}) + assert.ErrorContains(t, err, "output mapping index out of range") +} + +func TestExtensionMultiRecordType(t *testing.T) { + var rel ExtensionMultiRel + + rel.ClearMapping() + expected := *types.NewRecordTypeFromTypes(nil) + result := rel.RecordType() + assert.Equal(t, expected, result) + + err := rel.ChangeMapping([]int32{0}) + assert.ErrorContains(t, err, "output mapping index out of range") +} + func TestHashJoinRecordType(t *testing.T) { var rel HashJoinRel rel.left = &fakeRel{outputType: *types.NewRecordTypeFromTypes( From 6587591cdc08f3bcfa6b7d945df4c634d9b43dad Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 27 Nov 2024 16:51:36 -0800 Subject: [PATCH 04/17] and NamedTableWriteRel --- plan/relations_test.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/plan/relations_test.go b/plan/relations_test.go index c01e1ab..c3db33e 100644 --- a/plan/relations_test.go +++ b/plan/relations_test.go @@ -486,3 +486,18 @@ func TestMergeJoinRecordType(t *testing.T) { result = rel.RecordType() assert.Equal(t, expected, result) } + +func TestNamedTableWriteRecordType(t *testing.T) { + var rel NamedTableWriteRel + rel.input = &fakeRel{outputType: *types.NewRecordTypeFromTypes( + []types.Type{&types.Int64Type{}, &types.StringType{}})} + rel.outputMode = proto.WriteRel_OUTPUT_MODE_MODIFIED_RECORDS + + rel.ClearMapping() + expected := *types.NewRecordTypeFromTypes(nil) + result := rel.RecordType() + assert.Equal(t, expected, result) + + err := rel.ChangeMapping([]int32{0}) + assert.ErrorContains(t, err, "output mapping index out of range") +} From 24a310ff8862821292a14693c3a881821b4a12ab Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 2 Dec 2024 09:11:32 -0800 Subject: [PATCH 05/17] make OutputMapping not point to the internal copy --- plan/common.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/plan/common.go b/plan/common.go index 1a9914f..8226483 100644 --- a/plan/common.go +++ b/plan/common.go @@ -47,7 +47,12 @@ func (rc *RelCommon) remap(initial types.RecordType) types.RecordType { return *types.NewRecordTypeFromTypes(outTypes) } -func (rc *RelCommon) OutputMapping() []int32 { return rc.mapping } +func (rc *RelCommon) OutputMapping() []int32 { + // Make a copy of the output mapping to prevent accidental modification. + mapCopy := make([]int32, len(rc.mapping)) + copy(mapCopy, rc.mapping) + return mapCopy +} func (rc *RelCommon) ClearMapping() { rc.mapping = nil From 42125cd9d09e412e10395cd4e1251814f81a713f Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 2 Dec 2024 15:58:36 -0800 Subject: [PATCH 06/17] added a test for mapping preservation upon an error --- plan/plan_builder_test.go | 5 +++++ plan/relations.go | 4 +++- plan/relations_test.go | 4 +--- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/plan/plan_builder_test.go b/plan/plan_builder_test.go index 1703dc2..0adde9a 100644 --- a/plan/plan_builder_test.go +++ b/plan/plan_builder_test.go @@ -96,6 +96,11 @@ func TestEmitEmptyPlan(t *testing.T) { assert.Equal(t, "NSTRUCT", p.GetRoots()[0].RecordType().String()) + // Verify the mapping remains the same after receiving an error. + err = root.ChangeMapping([]int32{-1}) + require.Error(t, err) + assert.Equal(t, "NSTRUCT", p.GetRoots()[0].RecordType().String()) + protoPlan, err := p.ToProto() require.NoError(t, err) diff --git a/plan/relations.go b/plan/relations.go index 163b3b7..e9e3933 100644 --- a/plan/relations.go +++ b/plan/relations.go @@ -157,13 +157,15 @@ type MappableRel interface { OutputMapping() []int32 } +// ChangeMapping implements the core functionality of ChangeMapping for relations. +// It returns the relation's existing mapping on an error to ease being called. func ChangeMapping(r MappableRel, mapping []int32) ([]int32, error) { nOutput := r.RecordType().FieldCount() oldMapping := r.OutputMapping() newMapping := make([]int32, 0, len(mapping)) for _, idx := range mapping { if idx < 0 || idx >= nOutput { - return nil, errOutputMappingOutOfRange + return r.OutputMapping(), errOutputMappingOutOfRange } if len(oldMapping) > 0 { newMapping = append(newMapping, oldMapping[idx]) diff --git a/plan/relations_test.go b/plan/relations_test.go index c3db33e..fea10c0 100644 --- a/plan/relations_test.go +++ b/plan/relations_test.go @@ -383,9 +383,7 @@ func (f *fakeRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs . } func (f *fakeRel) ChangeMapping(mapping []int32) error { - newMapping, err := ChangeMapping(f, mapping) - f.mapping = newMapping - return err + panic("unused") } func TestProjectRecordType(t *testing.T) { From 91c1cf1fc5236fc2af6bf5e7db459750d1670623 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 2 Dec 2024 17:14:26 -0800 Subject: [PATCH 07/17] change ChangeMapping to use varargs instead of a slice --- plan/plan.go | 2 +- plan/plan_builder_test.go | 62 +++++++++++++++++++-------------------- plan/relations.go | 33 +++++++++++---------- plan/relations_test.go | 16 +++++----- 4 files changed, 58 insertions(+), 55 deletions(-) diff --git a/plan/plan.go b/plan/plan.go index 14d6f4b..f207df7 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -280,7 +280,7 @@ type Rel interface { // // If any column numbers specified are outside the currently available input // range an error is returned and the mapping is left unchanged. - ChangeMapping(mapping []int32) error + ChangeMapping(mapping ...int32) error // directOutputSchema returns the output record type of the underlying // relation as a struct type. Mapping is not applied. directOutputSchema() types.RecordType diff --git a/plan/plan_builder_test.go b/plan/plan_builder_test.go index 0adde9a..f3f9ff3 100644 --- a/plan/plan_builder_test.go +++ b/plan/plan_builder_test.go @@ -74,22 +74,22 @@ func TestBasicEmitPlan(t *testing.T) { func TestEmitEmptyPlan(t *testing.T) { b := plan.NewBuilderDefault() - root := b.NamedScan([]string{"test"}, baseSchema) - err := root.ChangeMapping([]int32{}) + root, err := b.NamedScanRemap([]string{"test"}, + baseSchema, []int32{}) require.NoError(t, err) _, err = b.Plan(root, []string{}) require.NoError(t, err) b = plan.NewBuilderDefault() root = b.NamedScan([]string{"test"}, baseSchema) - err = root.ChangeMapping([]int32{}) + err = root.ChangeMapping() require.NoError(t, err) _, err = b.Plan(root, []string{}) require.NoError(t, err) b = plan.NewBuilderDefault() root = b.NamedScan([]string{"test"}, baseSchema) - err = root.ChangeMapping([]int32{1, 0}) + err = root.ChangeMapping(1, 0) require.NoError(t, err) p, err := b.Plan(root, []string{"a", "b"}) require.NoError(t, err) @@ -97,7 +97,7 @@ func TestEmitEmptyPlan(t *testing.T) { assert.Equal(t, "NSTRUCT", p.GetRoots()[0].RecordType().String()) // Verify the mapping remains the same after receiving an error. - err = root.ChangeMapping([]int32{-1}) + err = root.ChangeMapping(-1) require.Error(t, err) assert.Equal(t, "NSTRUCT", p.GetRoots()[0].RecordType().String()) @@ -119,7 +119,7 @@ func TestBuildEmitOutOfRangePlan(t *testing.T) { b = plan.NewBuilderDefault() root := b.NamedScan([]string{"test"}, baseSchema) - err = root.ChangeMapping([]int32{2}) + err = root.ChangeMapping(2) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") } @@ -127,10 +127,10 @@ func TestBuildEmitOutOfRangePlan(t *testing.T) { func TestMappingOfMapping(t *testing.T) { b := plan.NewBuilderDefault() ns := b.NamedScan([]string{"test"}, baseSchema) - err := ns.ChangeMapping([]int32{1, 0}) + err := ns.ChangeMapping(1, 0) assert.NoError(t, err) assert.Equal(t, "struct", ns.RecordType().String()) - err = ns.ChangeMapping([]int32{1}) + err = ns.ChangeMapping(1) assert.NoError(t, err) assert.Equal(t, "struct", ns.RecordType().String()) } @@ -298,7 +298,7 @@ func TestAggregateRelErrors(t *testing.T) { acr, err := b.AggregateColumns(scan, nil, 0) assert.NoError(t, err) - err = acr.ChangeMapping([]int32{-1, 5}) + err = acr.ChangeMapping(-1, 5) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -310,7 +310,7 @@ func TestAggregateRelErrors(t *testing.T) { ref, _ = b.RootFieldRef(scan, 0) ae, err := b.AggregateExprs(scan, nil, []expr.Expression{ref}) assert.NoError(t, err) - err = ae.ChangeMapping([]int32{5, -1}) + err = ae.ChangeMapping(5, -1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -320,7 +320,7 @@ func TestAggregateRelErrors(t *testing.T) { ae, err = b.AggregateExprs(scan, nil, []expr.Expression{ref}) assert.NoError(t, err) - err = ae.ChangeMapping([]int32{1}) + err = ae.ChangeMapping(1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -328,14 +328,14 @@ func TestAggregateRelErrors(t *testing.T) { assert.NoError(t, err) ae, err = b.AggregateExprs(scan, nil, []expr.Expression{ref}) assert.NoError(t, err) - err = ae.ChangeMapping([]int32{0}) + err = ae.ChangeMapping(0) assert.NoError(t, err) _, err = b.AggregateColumnsRemap(scan, []int32{0}, nil, 0) assert.NoError(t, err) ae, err = b.AggregateColumns(scan, nil, 0) assert.NoError(t, err) - err = ae.ChangeMapping([]int32{0}) + err = ae.ChangeMapping(0) assert.NoError(t, err) } @@ -429,7 +429,7 @@ func TestCrossRelErrors(t *testing.T) { c, err := b.Cross(left, right) assert.NoError(t, err) - err = c.ChangeMapping([]int32{-1}) + err = c.ChangeMapping(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -439,7 +439,7 @@ func TestCrossRelErrors(t *testing.T) { c, err = b.Cross(left, right) assert.NoError(t, err) - err = c.ChangeMapping([]int32{5}) + err = c.ChangeMapping(5) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -450,7 +450,7 @@ func TestCrossRelErrors(t *testing.T) { // Output is length 2 + 2 c, err = b.Cross(left, right) assert.NoError(t, err) - err = c.ChangeMapping([]int32{2, 3}) + err = c.ChangeMapping(2, 3) assert.NoError(t, err) } @@ -512,7 +512,7 @@ func TestFetchRel(t *testing.T) { checkRoundTrip(t, expectedJSON, p) - err = fetch.ChangeMapping([]int32{0}) + err = fetch.ChangeMapping(0) assert.NoError(t, err) } @@ -538,7 +538,7 @@ func TestFetchRelErrors(t *testing.T) { f, err := b.Fetch(scan, 0, 0) assert.NoError(t, err) - err = f.ChangeMapping([]int32{-1}) + err = f.ChangeMapping(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -548,7 +548,7 @@ func TestFetchRelErrors(t *testing.T) { f, err = b.Fetch(scan, 0, 0) assert.NoError(t, err) - err = f.ChangeMapping([]int32{2}) + err = f.ChangeMapping(2) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -610,7 +610,7 @@ func TestFilterRelation(t *testing.T) { checkRoundTrip(t, expectedJSON, p) - err = filter.ChangeMapping([]int32{0}) + err = filter.ChangeMapping(0) assert.NoError(t, err) } @@ -648,7 +648,7 @@ func TestFilterRelationErrors(t *testing.T) { f, err := b.Filter(scan, refBool) assert.NoError(t, err) - err = f.ChangeMapping([]int32{-1}) + err = f.ChangeMapping(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -658,7 +658,7 @@ func TestFilterRelationErrors(t *testing.T) { f, err = b.Filter(scan, refBool) assert.NoError(t, err) - err = f.ChangeMapping([]int32{3}) + err = f.ChangeMapping(3) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") } @@ -877,7 +877,7 @@ func TestJoinRelationError(t *testing.T) { j, err := b.Join(left, right, goodcond, plan.JoinTypeInner) assert.NoError(t, err) - err = j.ChangeMapping([]int32{-1}) + err = j.ChangeMapping(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -887,7 +887,7 @@ func TestJoinRelationError(t *testing.T) { j, err = b.Join(left, right, goodcond, plan.JoinTypeLeftAnti) assert.NoError(t, err) - err = j.ChangeMapping([]int32{2}) + err = j.ChangeMapping(2) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -1118,7 +1118,7 @@ func TestSortRelationErrors(t *testing.T) { fields, _ = b.SortFields(scan, 1, 0) s, err := b.Sort(scan, fields...) assert.NoError(t, err) - err = s.ChangeMapping([]int32{-1}) + err = s.ChangeMapping(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -1136,7 +1136,7 @@ func TestSortRelationErrors(t *testing.T) { sortRel, err := b.Sort(scan, fields...) assert.NoError(t, err) - err = sortRel.ChangeMapping([]int32{3}) + err = sortRel.ChangeMapping(3) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") } @@ -1438,7 +1438,7 @@ func TestProjectErrors(t *testing.T) { p, err := b.Project(scan, ref) assert.NoError(t, err) - err = p.ChangeMapping([]int32{-1}) + err = p.ChangeMapping(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -1448,7 +1448,7 @@ func TestProjectErrors(t *testing.T) { p, err = b.Project(scan, ref) assert.NoError(t, err) - err = p.ChangeMapping([]int32{3}) + err = p.ChangeMapping(3) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -1457,7 +1457,7 @@ func TestProjectErrors(t *testing.T) { p, err = b.Project(scan, ref) assert.NoError(t, err) - err = p.ChangeMapping([]int32{2}) + err = p.ChangeMapping(2) assert.NoError(t, err, "Expected expression mapping to be in-bounds") } @@ -1669,7 +1669,7 @@ func TestSetRelErrors(t *testing.T) { s, err := b.Set(plan.SetOpMinusMultiset, scan1, scan2) assert.NoError(t, err) - err = s.ChangeMapping([]int32{-1}) + err = s.ChangeMapping(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -1679,7 +1679,7 @@ func TestSetRelErrors(t *testing.T) { s, err = b.Set(plan.SetOpMinusMultiset, scan1, scan2) assert.NoError(t, err) - err = s.ChangeMapping([]int32{3}) + err = s.ChangeMapping(3) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") } diff --git a/plan/relations.go b/plan/relations.go index e9e3933..2726276 100644 --- a/plan/relations.go +++ b/plan/relations.go @@ -160,6 +160,9 @@ type MappableRel interface { // ChangeMapping implements the core functionality of ChangeMapping for relations. // It returns the relation's existing mapping on an error to ease being called. func ChangeMapping(r MappableRel, mapping []int32) ([]int32, error) { + if len(mapping) == 0 { + return []int32{}, nil + } nOutput := r.RecordType().FieldCount() oldMapping := r.OutputMapping() newMapping := make([]int32, 0, len(mapping)) @@ -177,7 +180,7 @@ func ChangeMapping(r MappableRel, mapping []int32) ([]int32, error) { return mapping, nil } -func (b *baseReadRel) ChangeMapping(mapping []int32) error { +func (b *baseReadRel) ChangeMapping(mapping ...int32) error { newMapping, err := ChangeMapping(b, mapping) b.mapping = newMapping return err @@ -620,7 +623,7 @@ func (p *ProjectRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInput return &proj, nil } -func (p *ProjectRel) ChangeMapping(mapping []int32) error { +func (p *ProjectRel) ChangeMapping(mapping ...int32) error { newMapping, err := ChangeMapping(p, mapping) p.mapping = newMapping return err @@ -778,7 +781,7 @@ func (j *JoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs . return &join, nil } -func (j *JoinRel) ChangeMapping(mapping []int32) error { +func (j *JoinRel) ChangeMapping(mapping ...int32) error { newMapping, err := ChangeMapping(j, mapping) j.mapping = newMapping return err @@ -845,7 +848,7 @@ func (c *CrossRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (R return c.Copy(newInputs...) } -func (c *CrossRel) ChangeMapping(mapping []int32) error { +func (c *CrossRel) ChangeMapping(mapping ...int32) error { newMapping, err := ChangeMapping(c, mapping) c.mapping = newMapping return err @@ -917,7 +920,7 @@ func (f *FetchRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (R return f.Copy(newInputs...) } -func (f *FetchRel) ChangeMapping(mapping []int32) error { +func (f *FetchRel) ChangeMapping(mapping ...int32) error { newMapping, err := ChangeMapping(f, mapping) f.mapping = newMapping return err @@ -1071,7 +1074,7 @@ func (ar *AggregateRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newIn return &aggregate, nil } -func (ar *AggregateRel) ChangeMapping(mapping []int32) error { +func (ar *AggregateRel) ChangeMapping(mapping ...int32) error { newMapping, err := ChangeMapping(ar, mapping) ar.mapping = newMapping return err @@ -1158,7 +1161,7 @@ func (sr *SortRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs return &sort, nil } -func (sr *SortRel) ChangeMapping(mapping []int32) error { +func (sr *SortRel) ChangeMapping(mapping ...int32) error { newMapping, err := ChangeMapping(sr, mapping) sr.mapping = newMapping return err @@ -1235,7 +1238,7 @@ func (fr *FilterRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInput return &filter, nil } -func (fr *FilterRel) ChangeMapping(mapping []int32) error { +func (fr *FilterRel) ChangeMapping(mapping ...int32) error { newMapping, err := ChangeMapping(fr, mapping) fr.mapping = newMapping return err @@ -1314,7 +1317,7 @@ func (s *SetRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (Rel return s.Copy(newInputs...) } -func (s *SetRel) ChangeMapping(mapping []int32) error { +func (s *SetRel) ChangeMapping(mapping ...int32) error { newMapping, err := ChangeMapping(s, mapping) s.mapping = newMapping return err @@ -1377,7 +1380,7 @@ func (es *ExtensionSingleRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs return es.Copy(newInputs...) } -func (es *ExtensionSingleRel) ChangeMapping(mapping []int32) error { +func (es *ExtensionSingleRel) ChangeMapping(mapping ...int32) error { newMapping, err := ChangeMapping(es, mapping) es.mapping = newMapping return err @@ -1427,7 +1430,7 @@ func (el *ExtensionLeafRel) CopyWithExpressionRewrite(_ RewriteFunc, _ ...Rel) ( return el, nil } -func (el *ExtensionLeafRel) ChangeMapping(mapping []int32) error { +func (el *ExtensionLeafRel) ChangeMapping(mapping ...int32) error { newMapping, err := ChangeMapping(el, mapping) el.mapping = newMapping return err @@ -1489,7 +1492,7 @@ func (em *ExtensionMultiRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs return em.Copy(newInputs...) } -func (em *ExtensionMultiRel) ChangeMapping(mapping []int32) error { +func (em *ExtensionMultiRel) ChangeMapping(mapping ...int32) error { newMapping, err := ChangeMapping(em, mapping) em.mapping = newMapping return err @@ -1613,7 +1616,7 @@ func (hr *HashJoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInp return &join, nil } -func (hr *HashJoinRel) ChangeMapping(mapping []int32) error { +func (hr *HashJoinRel) ChangeMapping(mapping ...int32) error { newMapping, err := ChangeMapping(hr, mapping) hr.mapping = newMapping return err @@ -1723,7 +1726,7 @@ func (mr *MergeJoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newIn return &merge, nil } -func (mr *MergeJoinRel) ChangeMapping(mapping []int32) error { +func (mr *MergeJoinRel) ChangeMapping(mapping ...int32) error { newMapping, err := ChangeMapping(mr, mapping) mr.mapping = newMapping return err @@ -1840,7 +1843,7 @@ func (wr *NamedTableWriteRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs return wr.Copy(newInputs...) } -func (wr *NamedTableWriteRel) ChangeMapping(mapping []int32) error { +func (wr *NamedTableWriteRel) ChangeMapping(mapping ...int32) error { newMapping, err := ChangeMapping(wr, mapping) wr.mapping = newMapping return err diff --git a/plan/relations_test.go b/plan/relations_test.go index fea10c0..e032b75 100644 --- a/plan/relations_test.go +++ b/plan/relations_test.go @@ -382,7 +382,7 @@ func (f *fakeRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs . panic("unused") } -func (f *fakeRel) ChangeMapping(mapping []int32) error { +func (f *fakeRel) ChangeMapping(mapping ...int32) error { panic("unused") } @@ -396,7 +396,7 @@ func TestProjectRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.ChangeMapping([]int32{0}) + err := rel.ChangeMapping(0) assert.NoError(t, err) expected = *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}}) result = rel.RecordType() @@ -413,7 +413,7 @@ func TestExtensionSingleRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.ChangeMapping([]int32{0}) + err := rel.ChangeMapping(0) assert.NoError(t, err) expected = *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}}) result = rel.RecordType() @@ -428,7 +428,7 @@ func TestExtensionLeafRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.ChangeMapping([]int32{0}) + err := rel.ChangeMapping(0) assert.ErrorContains(t, err, "output mapping index out of range") } @@ -440,7 +440,7 @@ func TestExtensionMultiRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.ChangeMapping([]int32{0}) + err := rel.ChangeMapping(0) assert.ErrorContains(t, err, "output mapping index out of range") } @@ -457,7 +457,7 @@ func TestHashJoinRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.ChangeMapping([]int32{0}) + err := rel.ChangeMapping(0) assert.NoError(t, err) expected = *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}}) result = rel.RecordType() @@ -477,7 +477,7 @@ func TestMergeJoinRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.ChangeMapping([]int32{0}) + err := rel.ChangeMapping(0) assert.NoError(t, err) expected = *types.NewRecordTypeFromTypes( []types.Type{&types.Int64Type{}}) @@ -496,6 +496,6 @@ func TestNamedTableWriteRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.ChangeMapping([]int32{0}) + err := rel.ChangeMapping(0) assert.ErrorContains(t, err, "output mapping index out of range") } From ec8c65418661cc55cd194c38f82df2b05ac46621 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 2 Dec 2024 17:26:47 -0800 Subject: [PATCH 08/17] Remove ClearMapping. Note that ChangeMapping() is not a replacement. --- plan/common.go | 4 ---- plan/plan.go | 3 +-- plan/relations_test.go | 7 ------- 3 files changed, 1 insertion(+), 13 deletions(-) diff --git a/plan/common.go b/plan/common.go index 8226483..7cee1a9 100644 --- a/plan/common.go +++ b/plan/common.go @@ -54,10 +54,6 @@ func (rc *RelCommon) OutputMapping() []int32 { return mapCopy } -func (rc *RelCommon) ClearMapping() { - rc.mapping = nil -} - func (rc *RelCommon) GetAdvancedExtension() *extensions.AdvancedExtension { return rc.advExtension } diff --git a/plan/plan.go b/plan/plan.go index f207df7..5c7fd2a 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -271,8 +271,7 @@ type Rel interface { // result should be 3 columns consisting of the 5th, 2nd and 1st // output columns from the underlying relation. OutputMapping() []int32 - // ClearMapping resets the mapping for this relation. - ClearMapping() + // ChangeMapping modifies the current relation by applying the provided // mapping to the current relation. Typically used to remove any unneeded // columns or provide them in a different order. If there already is a diff --git a/plan/relations_test.go b/plan/relations_test.go index e032b75..2b1538a 100644 --- a/plan/relations_test.go +++ b/plan/relations_test.go @@ -391,7 +391,6 @@ func TestProjectRecordType(t *testing.T) { rel.input = &fakeRel{outputType: *types.NewRecordTypeFromTypes( []types.Type{&types.Int64Type{}, &types.Int64Type{}})} - rel.ClearMapping() expected := *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}, &types.Int64Type{}}) result := rel.RecordType() assert.Equal(t, expected, result) @@ -408,7 +407,6 @@ func TestExtensionSingleRecordType(t *testing.T) { rel.input = &fakeRel{outputType: *types.NewRecordTypeFromTypes( []types.Type{&types.Int64Type{}, &types.Int64Type{}})} - rel.ClearMapping() expected := *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}, &types.Int64Type{}}) result := rel.RecordType() assert.Equal(t, expected, result) @@ -423,7 +421,6 @@ func TestExtensionSingleRecordType(t *testing.T) { func TestExtensionLeafRecordType(t *testing.T) { var rel ExtensionLeafRel - rel.ClearMapping() expected := *types.NewRecordTypeFromTypes(nil) result := rel.RecordType() assert.Equal(t, expected, result) @@ -435,7 +432,6 @@ func TestExtensionLeafRecordType(t *testing.T) { func TestExtensionMultiRecordType(t *testing.T) { var rel ExtensionMultiRel - rel.ClearMapping() expected := *types.NewRecordTypeFromTypes(nil) result := rel.RecordType() assert.Equal(t, expected, result) @@ -451,7 +447,6 @@ func TestHashJoinRecordType(t *testing.T) { rel.right = &fakeRel{outputType: *types.NewRecordTypeFromTypes( []types.Type{&types.StringType{}, &types.StringType{}})} - rel.ClearMapping() expected := *types.NewRecordTypeFromTypes( []types.Type{&types.Int64Type{}, &types.Int64Type{}, &types.StringType{}, &types.StringType{}}) result := rel.RecordType() @@ -471,7 +466,6 @@ func TestMergeJoinRecordType(t *testing.T) { rel.right = &fakeRel{outputType: *types.NewRecordTypeFromTypes( []types.Type{&types.StringType{}, &types.StringType{}})} - rel.ClearMapping() expected := *types.NewRecordTypeFromTypes( []types.Type{&types.Int64Type{}, &types.Int64Type{}, &types.StringType{}, &types.StringType{}}) result := rel.RecordType() @@ -491,7 +485,6 @@ func TestNamedTableWriteRecordType(t *testing.T) { []types.Type{&types.Int64Type{}, &types.StringType{}})} rel.outputMode = proto.WriteRel_OUTPUT_MODE_MODIFIED_RECORDS - rel.ClearMapping() expected := *types.NewRecordTypeFromTypes(nil) result := rel.RecordType() assert.Equal(t, expected, result) From 116eac2d94b29ea2a5e83af6794ab11d20682109 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 2 Dec 2024 17:37:47 -0800 Subject: [PATCH 09/17] Added NamedWrite to replace NamedWriteRemap (although NamedInsert and NamedDelete also exist). --- plan/builders.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/plan/builders.go b/plan/builders.go index c2dbc47..2586bd0 100644 --- a/plan/builders.go +++ b/plan/builders.go @@ -110,7 +110,7 @@ type Builder interface { // Deprecated: Use NamedScan(...).ChangeMapping() instead. NamedScanRemap(tableName []string, schema types.NamedStruct, remap []int32) (*NamedTableReadRel, error) NamedScan(tableName []string, schema types.NamedStruct) *NamedTableReadRel - // Deprecated: Use NamedWriteMap(...).ChangeMapping() instead. + // Deprecated: Use NamedWrite(...).ChangeMapping() instead. NamedWriteRemap(input Rel, op WriteOp, tableName []string, schema types.NamedStruct, remap []int32) (*NamedTableWriteRel, error) // NamedInsert inserts data from the input relation into a named table. NamedInsert(input Rel, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error) @@ -504,12 +504,16 @@ func (b *builder) NamedWriteRemap(input Rel, op WriteOp, tableName []string, sch }, nil } +func (b *builder) NamedWrite(input Rel, op WriteOp, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error) { + return b.NamedWriteRemap(input, op, tableName, schema, nil) +} + func (b *builder) NamedInsert(input Rel, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error) { - return b.NamedWriteRemap(input, WriteOpInsert, tableName, schema, nil) + return b.NamedWrite(input, WriteOpInsert, tableName, schema) } func (b *builder) NamedDelete(input Rel, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error) { - return b.NamedWriteRemap(input, WriteOpDelete, tableName, schema, nil) + return b.NamedWrite(input, WriteOpDelete, tableName, schema) } func (b *builder) NamedScanRemap(tableName []string, schema types.NamedStruct, remap []int32) (*NamedTableReadRel, error) { From 2d9ffeb07cf7532fc7d27668314c82f20b75e8f0 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 2 Dec 2024 17:53:47 -0800 Subject: [PATCH 10/17] Renamed ChangeMapping to Remap (which is no longer used). --- plan/builders.go | 32 ++++++++++---------- plan/plan.go | 15 ++++----- plan/plan_builder_test.go | 58 +++++++++++++++++------------------ plan/relations.go | 64 +++++++++++++++++++-------------------- plan/relations_test.go | 16 +++++----- 5 files changed, 93 insertions(+), 92 deletions(-) diff --git a/plan/builders.go b/plan/builders.go index 2586bd0..0e3131c 100644 --- a/plan/builders.go +++ b/plan/builders.go @@ -71,18 +71,18 @@ type Builder interface { // from the output. Project(input Rel, exprs ...expr.Expression) (*ProjectRel, error) - // Deprecated: Use Project(...).ChangeMapping() instead. + // Deprecated: Use Project(...).Remap() instead. ProjectRemap(input Rel, remap []int32, exprs ...expr.Expression) (*ProjectRel, error) - // Deprecated: Use AggregateColumns(...).ChangeMapping() instead. + // Deprecated: Use AggregateColumns(...).Remap() instead. AggregateColumnsRemap(input Rel, remap []int32, measures []AggRelMeasure, groupByCols ...int32) (*AggregateRel, error) AggregateColumns(input Rel, measures []AggRelMeasure, groupByCols ...int32) (*AggregateRel, error) - // Deprecated: Use AggregateExprs(...).ChangeMapping() instead. + // Deprecated: Use AggregateExprs(...).Remap() instead. AggregateExprsRemap(input Rel, remap []int32, measures []AggRelMeasure, groups ...[]expr.Expression) (*AggregateRel, error) AggregateExprs(input Rel, measures []AggRelMeasure, groups ...[]expr.Expression) (*AggregateRel, error) - // Deprecated: Use CreateTableAsSelect(...).ChangeMapping() instead. + // Deprecated: Use CreateTableAsSelect(...).Remap() instead. CreateTableAsSelectRemap(input Rel, remap []int32, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error) CreateTableAsSelect(input Rel, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error) - // Deprecated: Use Cross(...).ChangeMapping() instead. + // Deprecated: Use Cross(...).Remap() instead. CrossRemap(left, right Rel, remap []int32) (*CrossRel, error) Cross(left, right Rel) (*CrossRel, error) // FetchRemap constructs a fetch relation providing an offset (skipping some @@ -91,26 +91,26 @@ type Builder interface { // returned. Similar to Fetch but allows for reordering and restricting the // returned columns. // - // Deprecated: Use Fetch(...).ChangeMapping() instead. + // Deprecated: Use Fetch(...).Remap() instead. FetchRemap(input Rel, offset, count int64, remap []int32) (*FetchRel, error) // Fetch constructs a fetch relation providing an offset (skipping some number of // rows) and a count (restricting output to a maximum number of rows). If count // is FETCH_COUNT_ALL_RECORDS (-1) all records will be returned. Fetch(input Rel, offset, count int64) (*FetchRel, error) - // Deprecated: Use Filter(...).ChangeMapping() instead. + // Deprecated: Use Filter(...).Remap() instead. FilterRemap(input Rel, condition expr.Expression, remap []int32) (*FilterRel, error) Filter(input Rel, condition expr.Expression) (*FilterRel, error) - // Deprecated: Use JoinAndFilter(...).ChangeMapping() instead. + // Deprecated: Use JoinAndFilter(...).Remap() instead. JoinAndFilterRemap(left, right Rel, condition, postJoinFilter expr.Expression, joinType JoinType, remap []int32) (*JoinRel, error) - // Deprecated: Use Fetch(...).ChangeMapping() instead. + // Deprecated: Use Fetch(...).Remap() instead. JoinAndFilter(left, right Rel, condition, postJoinFilter expr.Expression, joinType JoinType) (*JoinRel, error) - // Deprecated: Use Join(...).ChangeMapping() instead. + // Deprecated: Use Join(...).Remap() instead. JoinRemap(left, right Rel, condition expr.Expression, joinType JoinType, remap []int32) (*JoinRel, error) Join(left, right Rel, condition expr.Expression, joinType JoinType) (*JoinRel, error) - // Deprecated: Use NamedScan(...).ChangeMapping() instead. + // Deprecated: Use NamedScan(...).Remap() instead. NamedScanRemap(tableName []string, schema types.NamedStruct, remap []int32) (*NamedTableReadRel, error) NamedScan(tableName []string, schema types.NamedStruct) *NamedTableReadRel - // Deprecated: Use NamedWrite(...).ChangeMapping() instead. + // Deprecated: Use NamedWrite(...).Remap() instead. NamedWriteRemap(input Rel, op WriteOp, tableName []string, schema types.NamedStruct, remap []int32) (*NamedTableWriteRel, error) // NamedInsert inserts data from the input relation into a named table. NamedInsert(input Rel, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error) @@ -118,16 +118,16 @@ type Builder interface { // provided input relation, which typically includes conditions that filter // the rows to delete. NamedDelete(input Rel, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error) - // Deprecated: Use VirtualTable(...).ChangeMapping() instead. + // Deprecated: Use VirtualTable(...).Remap() instead. VirtualTableRemap(fields []string, remap []int32, values ...expr.StructLiteralValue) (*VirtualTableReadRel, error) VirtualTable(fields []string, values ...expr.StructLiteralValue) (*VirtualTableReadRel, error) - // Deprecated: Use VirtualTableFromExpr(...).ChangeMapping() instead. + // Deprecated: Use VirtualTableFromExpr(...).Remap() instead. VirtualTableFromExprRemap(fieldNames []string, remap []int32, values ...expr.VirtualTableExpressionValue) (*VirtualTableReadRel, error) VirtualTableFromExpr(fieldNames []string, values ...expr.VirtualTableExpressionValue) (*VirtualTableReadRel, error) - // Deprecated: Use Sort(...).ChangeMapping() instead. + // Deprecated: Use Sort(...).Remap() instead. SortRemap(input Rel, remap []int32, sorts ...expr.SortField) (*SortRel, error) Sort(input Rel, sorts ...expr.SortField) (*SortRel, error) - // Deprecated: Use Set(...).ChangeMapping() instead. + // Deprecated: Use Set(...).Remap() instead. SetRemap(op SetOp, remap []int32, inputs ...Rel) (*SetRel, error) Set(op SetOp, inputs ...Rel) (*SetRel, error) diff --git a/plan/plan.go b/plan/plan.go index 5c7fd2a..ac4fdf3 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -272,14 +272,15 @@ type Rel interface { // output columns from the underlying relation. OutputMapping() []int32 - // ChangeMapping modifies the current relation by applying the provided - // mapping to the current relation. Typically used to remove any unneeded - // columns or provide them in a different order. If there already is a - // mapping on this relation, this provides mapping over the current mapping. + // Remap modifies the current relation by applying the provided + // mapping to the current relation. Typically used to remove any + // unneeded columns or provide them in a different order. If there + // already is a mapping on this relation, this provides mapping over + // the current mapping. // - // If any column numbers specified are outside the currently available input - // range an error is returned and the mapping is left unchanged. - ChangeMapping(mapping ...int32) error + // If any column numbers specified are outside the currently available + // input range an error is returned and the mapping is left unchanged. + Remap(mapping ...int32) error // directOutputSchema returns the output record type of the underlying // relation as a struct type. Mapping is not applied. directOutputSchema() types.RecordType diff --git a/plan/plan_builder_test.go b/plan/plan_builder_test.go index f3f9ff3..50fed40 100644 --- a/plan/plan_builder_test.go +++ b/plan/plan_builder_test.go @@ -82,14 +82,14 @@ func TestEmitEmptyPlan(t *testing.T) { b = plan.NewBuilderDefault() root = b.NamedScan([]string{"test"}, baseSchema) - err = root.ChangeMapping() + err = root.Remap() require.NoError(t, err) _, err = b.Plan(root, []string{}) require.NoError(t, err) b = plan.NewBuilderDefault() root = b.NamedScan([]string{"test"}, baseSchema) - err = root.ChangeMapping(1, 0) + err = root.Remap(1, 0) require.NoError(t, err) p, err := b.Plan(root, []string{"a", "b"}) require.NoError(t, err) @@ -97,7 +97,7 @@ func TestEmitEmptyPlan(t *testing.T) { assert.Equal(t, "NSTRUCT", p.GetRoots()[0].RecordType().String()) // Verify the mapping remains the same after receiving an error. - err = root.ChangeMapping(-1) + err = root.Remap(-1) require.Error(t, err) assert.Equal(t, "NSTRUCT", p.GetRoots()[0].RecordType().String()) @@ -119,7 +119,7 @@ func TestBuildEmitOutOfRangePlan(t *testing.T) { b = plan.NewBuilderDefault() root := b.NamedScan([]string{"test"}, baseSchema) - err = root.ChangeMapping(2) + err = root.Remap(2) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") } @@ -127,10 +127,10 @@ func TestBuildEmitOutOfRangePlan(t *testing.T) { func TestMappingOfMapping(t *testing.T) { b := plan.NewBuilderDefault() ns := b.NamedScan([]string{"test"}, baseSchema) - err := ns.ChangeMapping(1, 0) + err := ns.Remap(1, 0) assert.NoError(t, err) assert.Equal(t, "struct", ns.RecordType().String()) - err = ns.ChangeMapping(1) + err = ns.Remap(1) assert.NoError(t, err) assert.Equal(t, "struct", ns.RecordType().String()) } @@ -298,7 +298,7 @@ func TestAggregateRelErrors(t *testing.T) { acr, err := b.AggregateColumns(scan, nil, 0) assert.NoError(t, err) - err = acr.ChangeMapping(-1, 5) + err = acr.Remap(-1, 5) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -310,7 +310,7 @@ func TestAggregateRelErrors(t *testing.T) { ref, _ = b.RootFieldRef(scan, 0) ae, err := b.AggregateExprs(scan, nil, []expr.Expression{ref}) assert.NoError(t, err) - err = ae.ChangeMapping(5, -1) + err = ae.Remap(5, -1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -320,7 +320,7 @@ func TestAggregateRelErrors(t *testing.T) { ae, err = b.AggregateExprs(scan, nil, []expr.Expression{ref}) assert.NoError(t, err) - err = ae.ChangeMapping(1) + err = ae.Remap(1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -328,14 +328,14 @@ func TestAggregateRelErrors(t *testing.T) { assert.NoError(t, err) ae, err = b.AggregateExprs(scan, nil, []expr.Expression{ref}) assert.NoError(t, err) - err = ae.ChangeMapping(0) + err = ae.Remap(0) assert.NoError(t, err) _, err = b.AggregateColumnsRemap(scan, []int32{0}, nil, 0) assert.NoError(t, err) ae, err = b.AggregateColumns(scan, nil, 0) assert.NoError(t, err) - err = ae.ChangeMapping(0) + err = ae.Remap(0) assert.NoError(t, err) } @@ -429,7 +429,7 @@ func TestCrossRelErrors(t *testing.T) { c, err := b.Cross(left, right) assert.NoError(t, err) - err = c.ChangeMapping(-1) + err = c.Remap(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -439,7 +439,7 @@ func TestCrossRelErrors(t *testing.T) { c, err = b.Cross(left, right) assert.NoError(t, err) - err = c.ChangeMapping(5) + err = c.Remap(5) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -450,7 +450,7 @@ func TestCrossRelErrors(t *testing.T) { // Output is length 2 + 2 c, err = b.Cross(left, right) assert.NoError(t, err) - err = c.ChangeMapping(2, 3) + err = c.Remap(2, 3) assert.NoError(t, err) } @@ -512,7 +512,7 @@ func TestFetchRel(t *testing.T) { checkRoundTrip(t, expectedJSON, p) - err = fetch.ChangeMapping(0) + err = fetch.Remap(0) assert.NoError(t, err) } @@ -538,7 +538,7 @@ func TestFetchRelErrors(t *testing.T) { f, err := b.Fetch(scan, 0, 0) assert.NoError(t, err) - err = f.ChangeMapping(-1) + err = f.Remap(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -548,7 +548,7 @@ func TestFetchRelErrors(t *testing.T) { f, err = b.Fetch(scan, 0, 0) assert.NoError(t, err) - err = f.ChangeMapping(2) + err = f.Remap(2) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -610,7 +610,7 @@ func TestFilterRelation(t *testing.T) { checkRoundTrip(t, expectedJSON, p) - err = filter.ChangeMapping(0) + err = filter.Remap(0) assert.NoError(t, err) } @@ -648,7 +648,7 @@ func TestFilterRelationErrors(t *testing.T) { f, err := b.Filter(scan, refBool) assert.NoError(t, err) - err = f.ChangeMapping(-1) + err = f.Remap(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -658,7 +658,7 @@ func TestFilterRelationErrors(t *testing.T) { f, err = b.Filter(scan, refBool) assert.NoError(t, err) - err = f.ChangeMapping(3) + err = f.Remap(3) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") } @@ -877,7 +877,7 @@ func TestJoinRelationError(t *testing.T) { j, err := b.Join(left, right, goodcond, plan.JoinTypeInner) assert.NoError(t, err) - err = j.ChangeMapping(-1) + err = j.Remap(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -887,7 +887,7 @@ func TestJoinRelationError(t *testing.T) { j, err = b.Join(left, right, goodcond, plan.JoinTypeLeftAnti) assert.NoError(t, err) - err = j.ChangeMapping(2) + err = j.Remap(2) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -1118,7 +1118,7 @@ func TestSortRelationErrors(t *testing.T) { fields, _ = b.SortFields(scan, 1, 0) s, err := b.Sort(scan, fields...) assert.NoError(t, err) - err = s.ChangeMapping(-1) + err = s.Remap(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -1136,7 +1136,7 @@ func TestSortRelationErrors(t *testing.T) { sortRel, err := b.Sort(scan, fields...) assert.NoError(t, err) - err = sortRel.ChangeMapping(3) + err = sortRel.Remap(3) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") } @@ -1438,7 +1438,7 @@ func TestProjectErrors(t *testing.T) { p, err := b.Project(scan, ref) assert.NoError(t, err) - err = p.ChangeMapping(-1) + err = p.Remap(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -1448,7 +1448,7 @@ func TestProjectErrors(t *testing.T) { p, err = b.Project(scan, ref) assert.NoError(t, err) - err = p.ChangeMapping(3) + err = p.Remap(3) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -1457,7 +1457,7 @@ func TestProjectErrors(t *testing.T) { p, err = b.Project(scan, ref) assert.NoError(t, err) - err = p.ChangeMapping(2) + err = p.Remap(2) assert.NoError(t, err, "Expected expression mapping to be in-bounds") } @@ -1669,7 +1669,7 @@ func TestSetRelErrors(t *testing.T) { s, err := b.Set(plan.SetOpMinusMultiset, scan1, scan2) assert.NoError(t, err) - err = s.ChangeMapping(-1) + err = s.Remap(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -1679,7 +1679,7 @@ func TestSetRelErrors(t *testing.T) { s, err = b.Set(plan.SetOpMinusMultiset, scan1, scan2) assert.NoError(t, err) - err = s.ChangeMapping(3) + err = s.Remap(3) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") } diff --git a/plan/relations.go b/plan/relations.go index 2726276..ec8bdab 100644 --- a/plan/relations.go +++ b/plan/relations.go @@ -157,9 +157,9 @@ type MappableRel interface { OutputMapping() []int32 } -// ChangeMapping implements the core functionality of ChangeMapping for relations. +// RemapHelper implements the core functionality of RemapHelper for relations. // It returns the relation's existing mapping on an error to ease being called. -func ChangeMapping(r MappableRel, mapping []int32) ([]int32, error) { +func RemapHelper(r MappableRel, mapping []int32) ([]int32, error) { if len(mapping) == 0 { return []int32{}, nil } @@ -180,8 +180,8 @@ func ChangeMapping(r MappableRel, mapping []int32) ([]int32, error) { return mapping, nil } -func (b *baseReadRel) ChangeMapping(mapping ...int32) error { - newMapping, err := ChangeMapping(b, mapping) +func (b *baseReadRel) Remap(mapping ...int32) error { + newMapping, err := RemapHelper(b, mapping) b.mapping = newMapping return err } @@ -623,8 +623,8 @@ func (p *ProjectRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInput return &proj, nil } -func (p *ProjectRel) ChangeMapping(mapping ...int32) error { - newMapping, err := ChangeMapping(p, mapping) +func (p *ProjectRel) Remap(mapping ...int32) error { + newMapping, err := RemapHelper(p, mapping) p.mapping = newMapping return err } @@ -781,8 +781,8 @@ func (j *JoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs . return &join, nil } -func (j *JoinRel) ChangeMapping(mapping ...int32) error { - newMapping, err := ChangeMapping(j, mapping) +func (j *JoinRel) Remap(mapping ...int32) error { + newMapping, err := RemapHelper(j, mapping) j.mapping = newMapping return err } @@ -848,8 +848,8 @@ func (c *CrossRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (R return c.Copy(newInputs...) } -func (c *CrossRel) ChangeMapping(mapping ...int32) error { - newMapping, err := ChangeMapping(c, mapping) +func (c *CrossRel) Remap(mapping ...int32) error { + newMapping, err := RemapHelper(c, mapping) c.mapping = newMapping return err } @@ -920,8 +920,8 @@ func (f *FetchRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (R return f.Copy(newInputs...) } -func (f *FetchRel) ChangeMapping(mapping ...int32) error { - newMapping, err := ChangeMapping(f, mapping) +func (f *FetchRel) Remap(mapping ...int32) error { + newMapping, err := RemapHelper(f, mapping) f.mapping = newMapping return err } @@ -1074,8 +1074,8 @@ func (ar *AggregateRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newIn return &aggregate, nil } -func (ar *AggregateRel) ChangeMapping(mapping ...int32) error { - newMapping, err := ChangeMapping(ar, mapping) +func (ar *AggregateRel) Remap(mapping ...int32) error { + newMapping, err := RemapHelper(ar, mapping) ar.mapping = newMapping return err } @@ -1161,8 +1161,8 @@ func (sr *SortRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs return &sort, nil } -func (sr *SortRel) ChangeMapping(mapping ...int32) error { - newMapping, err := ChangeMapping(sr, mapping) +func (sr *SortRel) Remap(mapping ...int32) error { + newMapping, err := RemapHelper(sr, mapping) sr.mapping = newMapping return err } @@ -1238,8 +1238,8 @@ func (fr *FilterRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInput return &filter, nil } -func (fr *FilterRel) ChangeMapping(mapping ...int32) error { - newMapping, err := ChangeMapping(fr, mapping) +func (fr *FilterRel) Remap(mapping ...int32) error { + newMapping, err := RemapHelper(fr, mapping) fr.mapping = newMapping return err } @@ -1317,8 +1317,8 @@ func (s *SetRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (Rel return s.Copy(newInputs...) } -func (s *SetRel) ChangeMapping(mapping ...int32) error { - newMapping, err := ChangeMapping(s, mapping) +func (s *SetRel) Remap(mapping ...int32) error { + newMapping, err := RemapHelper(s, mapping) s.mapping = newMapping return err } @@ -1380,8 +1380,8 @@ func (es *ExtensionSingleRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs return es.Copy(newInputs...) } -func (es *ExtensionSingleRel) ChangeMapping(mapping ...int32) error { - newMapping, err := ChangeMapping(es, mapping) +func (es *ExtensionSingleRel) Remap(mapping ...int32) error { + newMapping, err := RemapHelper(es, mapping) es.mapping = newMapping return err } @@ -1430,8 +1430,8 @@ func (el *ExtensionLeafRel) CopyWithExpressionRewrite(_ RewriteFunc, _ ...Rel) ( return el, nil } -func (el *ExtensionLeafRel) ChangeMapping(mapping ...int32) error { - newMapping, err := ChangeMapping(el, mapping) +func (el *ExtensionLeafRel) Remap(mapping ...int32) error { + newMapping, err := RemapHelper(el, mapping) el.mapping = newMapping return err } @@ -1492,8 +1492,8 @@ func (em *ExtensionMultiRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs return em.Copy(newInputs...) } -func (em *ExtensionMultiRel) ChangeMapping(mapping ...int32) error { - newMapping, err := ChangeMapping(em, mapping) +func (em *ExtensionMultiRel) Remap(mapping ...int32) error { + newMapping, err := RemapHelper(em, mapping) em.mapping = newMapping return err } @@ -1616,8 +1616,8 @@ func (hr *HashJoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInp return &join, nil } -func (hr *HashJoinRel) ChangeMapping(mapping ...int32) error { - newMapping, err := ChangeMapping(hr, mapping) +func (hr *HashJoinRel) Remap(mapping ...int32) error { + newMapping, err := RemapHelper(hr, mapping) hr.mapping = newMapping return err } @@ -1726,8 +1726,8 @@ func (mr *MergeJoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newIn return &merge, nil } -func (mr *MergeJoinRel) ChangeMapping(mapping ...int32) error { - newMapping, err := ChangeMapping(mr, mapping) +func (mr *MergeJoinRel) Remap(mapping ...int32) error { + newMapping, err := RemapHelper(mr, mapping) mr.mapping = newMapping return err } @@ -1843,8 +1843,8 @@ func (wr *NamedTableWriteRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs return wr.Copy(newInputs...) } -func (wr *NamedTableWriteRel) ChangeMapping(mapping ...int32) error { - newMapping, err := ChangeMapping(wr, mapping) +func (wr *NamedTableWriteRel) Remap(mapping ...int32) error { + newMapping, err := RemapHelper(wr, mapping) wr.mapping = newMapping return err } diff --git a/plan/relations_test.go b/plan/relations_test.go index 2b1538a..13c538c 100644 --- a/plan/relations_test.go +++ b/plan/relations_test.go @@ -382,7 +382,7 @@ func (f *fakeRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs . panic("unused") } -func (f *fakeRel) ChangeMapping(mapping ...int32) error { +func (f *fakeRel) Remap(mapping ...int32) error { panic("unused") } @@ -395,7 +395,7 @@ func TestProjectRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.ChangeMapping(0) + err := rel.Remap(0) assert.NoError(t, err) expected = *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}}) result = rel.RecordType() @@ -411,7 +411,7 @@ func TestExtensionSingleRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.ChangeMapping(0) + err := rel.Remap(0) assert.NoError(t, err) expected = *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}}) result = rel.RecordType() @@ -425,7 +425,7 @@ func TestExtensionLeafRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.ChangeMapping(0) + err := rel.Remap(0) assert.ErrorContains(t, err, "output mapping index out of range") } @@ -436,7 +436,7 @@ func TestExtensionMultiRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.ChangeMapping(0) + err := rel.Remap(0) assert.ErrorContains(t, err, "output mapping index out of range") } @@ -452,7 +452,7 @@ func TestHashJoinRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.ChangeMapping(0) + err := rel.Remap(0) assert.NoError(t, err) expected = *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}}) result = rel.RecordType() @@ -471,7 +471,7 @@ func TestMergeJoinRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.ChangeMapping(0) + err := rel.Remap(0) assert.NoError(t, err) expected = *types.NewRecordTypeFromTypes( []types.Type{&types.Int64Type{}}) @@ -489,6 +489,6 @@ func TestNamedTableWriteRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.ChangeMapping(0) + err := rel.Remap(0) assert.ErrorContains(t, err, "output mapping index out of range") } From 4ba4eb9664ecc76c897ca3e3fbce8e5c95cc35a4 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 2 Dec 2024 18:22:32 -0800 Subject: [PATCH 11/17] added missing NamedWrite interface --- plan/builders.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/plan/builders.go b/plan/builders.go index 0e3131c..e227023 100644 --- a/plan/builders.go +++ b/plan/builders.go @@ -112,6 +112,8 @@ type Builder interface { NamedScan(tableName []string, schema types.NamedStruct) *NamedTableReadRel // Deprecated: Use NamedWrite(...).Remap() instead. NamedWriteRemap(input Rel, op WriteOp, tableName []string, schema types.NamedStruct, remap []int32) (*NamedTableWriteRel, error) + // NamedWrite performs the given write operation from the input relation over a named table. + NamedWrite(input Rel, op WriteOp, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error) // NamedInsert inserts data from the input relation into a named table. NamedInsert(input Rel, tableName []string, schema types.NamedStruct) (*NamedTableWriteRel, error) // NamedDelete deletes rows from a specified named table based on the From b01f3b352a3759db71dd7c08f5089f04eb53435c Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 2 Dec 2024 23:20:14 -0800 Subject: [PATCH 12/17] Remap now returns a copy of its input. --- plan/plan.go | 5 +- plan/plan_builder_test.go | 77 ++++++++++++---------- plan/relations.go | 132 ++++++++++++++++++++++++-------------- plan/relations_test.go | 24 +++---- 4 files changed, 144 insertions(+), 94 deletions(-) diff --git a/plan/plan.go b/plan/plan.go index ac4fdf3..3d5973b 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -255,6 +255,9 @@ type RewriteFunc func(expr.Expression) (expr.Expression, error) // It contains the common functionality between the different relations // and should be type switched to determine which relation type it actually // is for evaluation. +// +// All methods in this interface should be considered constant (i.e. +// immutable). type Rel interface { // Hint returns a set of changes to the operation which can influence // efficiency and performance but should not impact correctness. @@ -280,7 +283,7 @@ type Rel interface { // // If any column numbers specified are outside the currently available // input range an error is returned and the mapping is left unchanged. - Remap(mapping ...int32) error + Remap(mapping ...int32) (Rel, error) // directOutputSchema returns the output record type of the underlying // relation as a struct type. Mapping is not applied. directOutputSchema() types.RecordType diff --git a/plan/plan_builder_test.go b/plan/plan_builder_test.go index 50fed40..d841e70 100644 --- a/plan/plan_builder_test.go +++ b/plan/plan_builder_test.go @@ -82,22 +82,22 @@ func TestEmitEmptyPlan(t *testing.T) { b = plan.NewBuilderDefault() root = b.NamedScan([]string{"test"}, baseSchema) - err = root.Remap() + newRoot, err := root.Remap() require.NoError(t, err) - _, err = b.Plan(root, []string{}) + _, err = b.Plan(newRoot, []string{}) require.NoError(t, err) b = plan.NewBuilderDefault() root = b.NamedScan([]string{"test"}, baseSchema) - err = root.Remap(1, 0) + newRoot, err = root.Remap(1, 0) require.NoError(t, err) - p, err := b.Plan(root, []string{"a", "b"}) + p, err := b.Plan(newRoot, []string{"a", "b"}) require.NoError(t, err) assert.Equal(t, "NSTRUCT", p.GetRoots()[0].RecordType().String()) // Verify the mapping remains the same after receiving an error. - err = root.Remap(-1) + newRoot, err = root.Remap(-1) require.Error(t, err) assert.Equal(t, "NSTRUCT", p.GetRoots()[0].RecordType().String()) @@ -119,7 +119,7 @@ func TestBuildEmitOutOfRangePlan(t *testing.T) { b = plan.NewBuilderDefault() root := b.NamedScan([]string{"test"}, baseSchema) - err = root.Remap(2) + _, err = root.Remap(2) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") } @@ -127,12 +127,23 @@ func TestBuildEmitOutOfRangePlan(t *testing.T) { func TestMappingOfMapping(t *testing.T) { b := plan.NewBuilderDefault() ns := b.NamedScan([]string{"test"}, baseSchema) - err := ns.Remap(1, 0) + newRel, err := ns.Remap(1, 0) assert.NoError(t, err) - assert.Equal(t, "struct", ns.RecordType().String()) - err = ns.Remap(1) + assert.Equal(t, "struct", newRel.RecordType().String()) + newRel2, err := newRel.Remap(1) + assert.NoError(t, err) + assert.Equal(t, "struct", newRel2.RecordType().String()) +} + +func TestFailedMappingOfMapping(t *testing.T) { + b := plan.NewBuilderDefault() + ns := b.NamedScan([]string{"test"}, baseSchema) + newRel, err := ns.Remap(1, 0) assert.NoError(t, err) - assert.Equal(t, "struct", ns.RecordType().String()) + assert.Equal(t, "struct", newRel.RecordType().String()) + newRel2, err := newRel.Remap(-1) + assert.ErrorContains(t, err, "output mapping index out of range") + assert.Equal(t, "struct", newRel2.RecordType().String()) } func checkRoundTrip(t *testing.T, expectedJSON string, p *plan.Plan) { @@ -298,7 +309,7 @@ func TestAggregateRelErrors(t *testing.T) { acr, err := b.AggregateColumns(scan, nil, 0) assert.NoError(t, err) - err = acr.Remap(-1, 5) + _, err = acr.Remap(-1, 5) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -310,7 +321,7 @@ func TestAggregateRelErrors(t *testing.T) { ref, _ = b.RootFieldRef(scan, 0) ae, err := b.AggregateExprs(scan, nil, []expr.Expression{ref}) assert.NoError(t, err) - err = ae.Remap(5, -1) + _, err = ae.Remap(5, -1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -320,7 +331,7 @@ func TestAggregateRelErrors(t *testing.T) { ae, err = b.AggregateExprs(scan, nil, []expr.Expression{ref}) assert.NoError(t, err) - err = ae.Remap(1) + _, err = ae.Remap(1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -328,14 +339,14 @@ func TestAggregateRelErrors(t *testing.T) { assert.NoError(t, err) ae, err = b.AggregateExprs(scan, nil, []expr.Expression{ref}) assert.NoError(t, err) - err = ae.Remap(0) + _, err = ae.Remap(0) assert.NoError(t, err) _, err = b.AggregateColumnsRemap(scan, []int32{0}, nil, 0) assert.NoError(t, err) ae, err = b.AggregateColumns(scan, nil, 0) assert.NoError(t, err) - err = ae.Remap(0) + _, err = ae.Remap(0) assert.NoError(t, err) } @@ -429,7 +440,7 @@ func TestCrossRelErrors(t *testing.T) { c, err := b.Cross(left, right) assert.NoError(t, err) - err = c.Remap(-1) + _, err = c.Remap(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -439,7 +450,7 @@ func TestCrossRelErrors(t *testing.T) { c, err = b.Cross(left, right) assert.NoError(t, err) - err = c.Remap(5) + _, err = c.Remap(5) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -450,7 +461,7 @@ func TestCrossRelErrors(t *testing.T) { // Output is length 2 + 2 c, err = b.Cross(left, right) assert.NoError(t, err) - err = c.Remap(2, 3) + _, err = c.Remap(2, 3) assert.NoError(t, err) } @@ -512,7 +523,7 @@ func TestFetchRel(t *testing.T) { checkRoundTrip(t, expectedJSON, p) - err = fetch.Remap(0) + _, err = fetch.Remap(0) assert.NoError(t, err) } @@ -538,7 +549,7 @@ func TestFetchRelErrors(t *testing.T) { f, err := b.Fetch(scan, 0, 0) assert.NoError(t, err) - err = f.Remap(-1) + _, err = f.Remap(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -548,7 +559,7 @@ func TestFetchRelErrors(t *testing.T) { f, err = b.Fetch(scan, 0, 0) assert.NoError(t, err) - err = f.Remap(2) + _, err = f.Remap(2) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -610,7 +621,7 @@ func TestFilterRelation(t *testing.T) { checkRoundTrip(t, expectedJSON, p) - err = filter.Remap(0) + _, err = filter.Remap(0) assert.NoError(t, err) } @@ -648,7 +659,7 @@ func TestFilterRelationErrors(t *testing.T) { f, err := b.Filter(scan, refBool) assert.NoError(t, err) - err = f.Remap(-1) + _, err = f.Remap(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -658,7 +669,7 @@ func TestFilterRelationErrors(t *testing.T) { f, err = b.Filter(scan, refBool) assert.NoError(t, err) - err = f.Remap(3) + _, err = f.Remap(3) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") } @@ -877,7 +888,7 @@ func TestJoinRelationError(t *testing.T) { j, err := b.Join(left, right, goodcond, plan.JoinTypeInner) assert.NoError(t, err) - err = j.Remap(-1) + _, err = j.Remap(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -887,7 +898,7 @@ func TestJoinRelationError(t *testing.T) { j, err = b.Join(left, right, goodcond, plan.JoinTypeLeftAnti) assert.NoError(t, err) - err = j.Remap(2) + _, err = j.Remap(2) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -1118,7 +1129,7 @@ func TestSortRelationErrors(t *testing.T) { fields, _ = b.SortFields(scan, 1, 0) s, err := b.Sort(scan, fields...) assert.NoError(t, err) - err = s.Remap(-1) + _, err = s.Remap(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -1136,7 +1147,7 @@ func TestSortRelationErrors(t *testing.T) { sortRel, err := b.Sort(scan, fields...) assert.NoError(t, err) - err = sortRel.Remap(3) + _, err = sortRel.Remap(3) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") } @@ -1438,7 +1449,7 @@ func TestProjectErrors(t *testing.T) { p, err := b.Project(scan, ref) assert.NoError(t, err) - err = p.Remap(-1) + _, err = p.Remap(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -1448,7 +1459,7 @@ func TestProjectErrors(t *testing.T) { p, err = b.Project(scan, ref) assert.NoError(t, err) - err = p.Remap(3) + _, err = p.Remap(3) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -1457,7 +1468,7 @@ func TestProjectErrors(t *testing.T) { p, err = b.Project(scan, ref) assert.NoError(t, err) - err = p.Remap(2) + _, err = p.Remap(2) assert.NoError(t, err, "Expected expression mapping to be in-bounds") } @@ -1669,7 +1680,7 @@ func TestSetRelErrors(t *testing.T) { s, err := b.Set(plan.SetOpMinusMultiset, scan1, scan2) assert.NoError(t, err) - err = s.Remap(-1) + _, err = s.Remap(-1) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") @@ -1679,7 +1690,7 @@ func TestSetRelErrors(t *testing.T) { s, err = b.Set(plan.SetOpMinusMultiset, scan1, scan2) assert.NoError(t, err) - err = s.Remap(3) + _, err = s.Remap(3) assert.ErrorIs(t, err, substraitgo.ErrInvalidRel) assert.ErrorContains(t, err, "output mapping index out of range") } diff --git a/plan/relations.go b/plan/relations.go index ec8bdab..41fc3a3 100644 --- a/plan/relations.go +++ b/plan/relations.go @@ -180,12 +180,6 @@ func RemapHelper(r MappableRel, mapping []int32) ([]int32, error) { return mapping, nil } -func (b *baseReadRel) Remap(mapping ...int32) error { - newMapping, err := RemapHelper(b, mapping) - b.mapping = newMapping - return err -} - // NamedTableReadRel is a named scan of a base table. The list of strings // that make up the names are to represent namespacing (e.g. mydb.mytable). // This assumes a shared catalog between systems exchanging a message. @@ -246,6 +240,13 @@ func (n *NamedTableReadRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, _ return &nt, nil } +func (n *NamedTableReadRel) Remap(mapping ...int32) (Rel, error) { + newMapping, err := RemapHelper(n, mapping) + newRel := n + newRel.mapping = newMapping + return newRel, err +} + // VirtualTableReadRel represents a table composed of literals. type VirtualTableReadRel struct { baseReadRel @@ -301,6 +302,13 @@ func (v *VirtualTableReadRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, return &vtr, nil } +func (v *VirtualTableReadRel) Remap(mapping ...int32) (Rel, error) { + newMapping, err := RemapHelper(v, mapping) + newRel := v + newRel.mapping = newMapping + return newRel, err +} + // ExtensionTableReadRel is a stub type that can be used to extend // and introduce new table types outside the specification by utilizing // protobuf Any type. @@ -351,6 +359,13 @@ func (e *ExtensionTableReadRel) CopyWithExpressionRewrite(rewriteFunc RewriteFun return &etr, nil } +func (e *ExtensionTableReadRel) Remap(mapping ...int32) (Rel, error) { + newMapping, err := RemapHelper(e, mapping) + newRel := e + newRel.mapping = newMapping + return newRel, err +} + // PathType is the type of a LocalFileReadRel's uris. type PathType int8 @@ -534,6 +549,13 @@ func (lf *LocalFileReadRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, _ return &lfr, nil } +func (lf *LocalFileReadRel) Remap(mapping ...int32) (Rel, error) { + newMapping, err := RemapHelper(lf, mapping) + newRel := lf + newRel.mapping = newMapping + return newRel, err +} + // ProjectRel represents calculated expressions of fields (e.g. a+b), // the OutputMapping will be used to represent classical relational // projections. @@ -623,10 +645,11 @@ func (p *ProjectRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInput return &proj, nil } -func (p *ProjectRel) Remap(mapping ...int32) error { +func (p *ProjectRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(p, mapping) - p.mapping = newMapping - return err + newRel := p + newRel.mapping = newMapping + return newRel, err } var defFilter = expr.NewPrimitiveLiteral(true, false) @@ -781,10 +804,11 @@ func (j *JoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs . return &join, nil } -func (j *JoinRel) Remap(mapping ...int32) error { +func (j *JoinRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(j, mapping) - j.mapping = newMapping - return err + newRel := j + newRel.mapping = newMapping + return newRel, err } // CrossRel is a cartesian product relational operator of two tables. @@ -848,10 +872,11 @@ func (c *CrossRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (R return c.Copy(newInputs...) } -func (c *CrossRel) Remap(mapping ...int32) error { +func (c *CrossRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(c, mapping) - c.mapping = newMapping - return err + newRel := c + newRel.mapping = newMapping + return newRel, err } // FetchRel is a relational operator representing LIMIT/OFFSET or @@ -920,10 +945,11 @@ func (f *FetchRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (R return f.Copy(newInputs...) } -func (f *FetchRel) Remap(mapping ...int32) error { +func (f *FetchRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(f, mapping) - f.mapping = newMapping - return err + newRel := f + newRel.mapping = newMapping + return newRel, err } type AggRelMeasure struct { @@ -1074,10 +1100,11 @@ func (ar *AggregateRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newIn return &aggregate, nil } -func (ar *AggregateRel) Remap(mapping ...int32) error { +func (ar *AggregateRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(ar, mapping) - ar.mapping = newMapping - return err + newRel := ar + newRel.mapping = newMapping + return newRel, err } // SortRel is an ORDER BY relational operator, describing a base relation, @@ -1161,10 +1188,11 @@ func (sr *SortRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs return &sort, nil } -func (sr *SortRel) Remap(mapping ...int32) error { +func (sr *SortRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(sr, mapping) - sr.mapping = newMapping - return err + newRel := sr + newRel.mapping = newMapping + return newRel, err } // FilterRel is a relational operator capturing simple filters ( @@ -1238,10 +1266,11 @@ func (fr *FilterRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInput return &filter, nil } -func (fr *FilterRel) Remap(mapping ...int32) error { +func (fr *FilterRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(fr, mapping) - fr.mapping = newMapping - return err + newRel := fr + newRel.mapping = newMapping + return newRel, err } type SetOp = proto.SetRel_SetOp @@ -1317,10 +1346,11 @@ func (s *SetRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (Rel return s.Copy(newInputs...) } -func (s *SetRel) Remap(mapping ...int32) error { +func (s *SetRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(s, mapping) - s.mapping = newMapping - return err + newRel := s + newRel.mapping = newMapping + return newRel, err } // ExtensionSingleRel is a stub to support extensions with a single input. @@ -1380,10 +1410,11 @@ func (es *ExtensionSingleRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs return es.Copy(newInputs...) } -func (es *ExtensionSingleRel) Remap(mapping ...int32) error { +func (es *ExtensionSingleRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(es, mapping) - es.mapping = newMapping - return err + newRel := es + newRel.mapping = newMapping + return newRel, err } // ExtensionLeafRel is a stub to support extensions with zero inputs. @@ -1430,10 +1461,11 @@ func (el *ExtensionLeafRel) CopyWithExpressionRewrite(_ RewriteFunc, _ ...Rel) ( return el, nil } -func (el *ExtensionLeafRel) Remap(mapping ...int32) error { +func (el *ExtensionLeafRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(el, mapping) - el.mapping = newMapping - return err + newRel := el + newRel.mapping = newMapping + return newRel, err } // ExtensionMultiRel is a stub to support extensions with multiple inputs. @@ -1492,10 +1524,11 @@ func (em *ExtensionMultiRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs return em.Copy(newInputs...) } -func (em *ExtensionMultiRel) Remap(mapping ...int32) error { +func (em *ExtensionMultiRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(em, mapping) - em.mapping = newMapping - return err + newRel := em + newRel.mapping = newMapping + return newRel, err } type HashMergeJoinType int8 @@ -1616,10 +1649,11 @@ func (hr *HashJoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInp return &join, nil } -func (hr *HashJoinRel) Remap(mapping ...int32) error { +func (hr *HashJoinRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(hr, mapping) - hr.mapping = newMapping - return err + newRel := hr + newRel.mapping = newMapping + return newRel, err } // MergeJoinRel represents a join done by taking advantage of two sets @@ -1726,10 +1760,11 @@ func (mr *MergeJoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newIn return &merge, nil } -func (mr *MergeJoinRel) Remap(mapping ...int32) error { +func (mr *MergeJoinRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(mr, mapping) - mr.mapping = newMapping - return err + newRel := mr + newRel.mapping = newMapping + return newRel, err } type WriteOp = proto.WriteRel_WriteOp @@ -1843,10 +1878,11 @@ func (wr *NamedTableWriteRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs return wr.Copy(newInputs...) } -func (wr *NamedTableWriteRel) Remap(mapping ...int32) error { +func (wr *NamedTableWriteRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(wr, mapping) - wr.mapping = newMapping - return err + newRel := wr + newRel.mapping = newMapping + return newRel, err } var ( diff --git a/plan/relations_test.go b/plan/relations_test.go index 13c538c..c844389 100644 --- a/plan/relations_test.go +++ b/plan/relations_test.go @@ -382,7 +382,7 @@ func (f *fakeRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs . panic("unused") } -func (f *fakeRel) Remap(mapping ...int32) error { +func (f *fakeRel) Remap(mapping ...int32) (Rel, error) { panic("unused") } @@ -395,10 +395,10 @@ func TestProjectRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.Remap(0) + newRel, err := rel.Remap(0) assert.NoError(t, err) expected = *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}}) - result = rel.RecordType() + result = newRel.RecordType() assert.Equal(t, expected, result) } @@ -411,10 +411,10 @@ func TestExtensionSingleRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.Remap(0) + newRel, err := rel.Remap(0) assert.NoError(t, err) expected = *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}}) - result = rel.RecordType() + result = newRel.RecordType() assert.Equal(t, expected, result) } @@ -425,7 +425,7 @@ func TestExtensionLeafRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.Remap(0) + _, err := rel.Remap(0) assert.ErrorContains(t, err, "output mapping index out of range") } @@ -436,7 +436,7 @@ func TestExtensionMultiRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.Remap(0) + _, err := rel.Remap(0) assert.ErrorContains(t, err, "output mapping index out of range") } @@ -452,10 +452,10 @@ func TestHashJoinRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.Remap(0) + newRel, err := rel.Remap(0) assert.NoError(t, err) expected = *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}}) - result = rel.RecordType() + result = newRel.RecordType() assert.Equal(t, expected, result) } @@ -471,11 +471,11 @@ func TestMergeJoinRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.Remap(0) + newRel, err := rel.Remap(0) assert.NoError(t, err) expected = *types.NewRecordTypeFromTypes( []types.Type{&types.Int64Type{}}) - result = rel.RecordType() + result = newRel.RecordType() assert.Equal(t, expected, result) } @@ -489,6 +489,6 @@ func TestNamedTableWriteRecordType(t *testing.T) { result := rel.RecordType() assert.Equal(t, expected, result) - err := rel.Remap(0) + _, err := rel.Remap(0) assert.ErrorContains(t, err, "output mapping index out of range") } From 76a1bea7f6e76c8512557e263509196aa643715f Mon Sep 17 00:00:00 2001 From: David Sisson Date: Mon, 2 Dec 2024 23:27:44 -0800 Subject: [PATCH 13/17] make lint happy again --- plan/plan_builder_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plan/plan_builder_test.go b/plan/plan_builder_test.go index d841e70..75ea120 100644 --- a/plan/plan_builder_test.go +++ b/plan/plan_builder_test.go @@ -97,7 +97,7 @@ func TestEmitEmptyPlan(t *testing.T) { assert.Equal(t, "NSTRUCT", p.GetRoots()[0].RecordType().String()) // Verify the mapping remains the same after receiving an error. - newRoot, err = root.Remap(-1) + _, err = root.Remap(-1) require.Error(t, err) assert.Equal(t, "NSTRUCT", p.GetRoots()[0].RecordType().String()) From 3725c62c9d37c41efc9245f38129af6303314829 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Tue, 3 Dec 2024 09:30:32 -0800 Subject: [PATCH 14/17] add more coverage --- plan/relations_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/plan/relations_test.go b/plan/relations_test.go index c844389..73e7c76 100644 --- a/plan/relations_test.go +++ b/plan/relations_test.go @@ -386,6 +386,52 @@ func (f *fakeRel) Remap(mapping ...int32) (Rel, error) { panic("unused") } +func TestVirtualTableReadRelRecordType(t *testing.T) { + b := NewBuilderDefault() + rel, err := b.VirtualTable([]string{"a", "b"}, + expr.StructLiteralValue{ + &expr.PrimitiveLiteral[int64]{Value: 11, Type: &types.Int64Type{}}, + &expr.PrimitiveLiteral[string]{Value: "12", Type: &types.StringType{}}}, + expr.StructLiteralValue{ + &expr.PrimitiveLiteral[int64]{Value: 21, Type: &types.Int64Type{}}, + &expr.PrimitiveLiteral[string]{Value: "22", Type: &types.StringType{}}}) + assert.NoError(t, err) + + expected := *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}, &types.StringType{}}) + result := rel.RecordType() + assert.Equal(t, expected, result) + + newRel, err := rel.Remap(0) + assert.NoError(t, err) + expected = *types.NewRecordTypeFromTypes([]types.Type{&types.Int64Type{}}) + result = newRel.RecordType() + assert.Equal(t, expected, result) +} + +func TestExtensionTableReadRelRecordType(t *testing.T) { + // We don't have a way of setting the base schema yet so test with an empty schema. + rel := &ExtensionTableReadRel{} + + expected := *types.NewRecordTypeFromTypes(nil) + result := rel.RecordType() + assert.Equal(t, expected, result) + + _, err := rel.Remap(0) + assert.ErrorContains(t, err, "output mapping index out of range") +} + +func TestLocalFileReadRelRecordType(t *testing.T) { + // We don't have a way of setting the base schema yet so test with an empty schema. + rel := &LocalFileReadRel{} + + expected := *types.NewRecordTypeFromTypes(nil) + result := rel.RecordType() + assert.Equal(t, expected, result) + + _, err := rel.Remap(0) + assert.ErrorContains(t, err, "output mapping index out of range") +} + func TestProjectRecordType(t *testing.T) { var rel ProjectRel rel.input = &fakeRel{outputType: *types.NewRecordTypeFromTypes( From 9981ab2ed5f56b288afd110ee31dba90a748f662 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 4 Dec 2024 16:47:00 -0800 Subject: [PATCH 15/17] Fixed Remap's copy. --- plan/relations.go | 74 +++++++++++++++++++++++------------------------ 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/plan/relations.go b/plan/relations.go index 41fc3a3..a22205b 100644 --- a/plan/relations.go +++ b/plan/relations.go @@ -158,7 +158,7 @@ type MappableRel interface { } // RemapHelper implements the core functionality of RemapHelper for relations. -// It returns the relation's existing mapping on an error to ease being called. +// It returns the relation's existing mapping or an error to ease being called. func RemapHelper(r MappableRel, mapping []int32) ([]int32, error) { if len(mapping) == 0 { return []int32{}, nil @@ -242,9 +242,9 @@ func (n *NamedTableReadRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, _ func (n *NamedTableReadRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(n, mapping) - newRel := n + newRel := *n newRel.mapping = newMapping - return newRel, err + return &newRel, err } // VirtualTableReadRel represents a table composed of literals. @@ -304,9 +304,9 @@ func (v *VirtualTableReadRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, func (v *VirtualTableReadRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(v, mapping) - newRel := v + newRel := *v newRel.mapping = newMapping - return newRel, err + return &newRel, err } // ExtensionTableReadRel is a stub type that can be used to extend @@ -361,9 +361,9 @@ func (e *ExtensionTableReadRel) CopyWithExpressionRewrite(rewriteFunc RewriteFun func (e *ExtensionTableReadRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(e, mapping) - newRel := e + newRel := *e newRel.mapping = newMapping - return newRel, err + return &newRel, err } // PathType is the type of a LocalFileReadRel's uris. @@ -551,9 +551,9 @@ func (lf *LocalFileReadRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, _ func (lf *LocalFileReadRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(lf, mapping) - newRel := lf + newRel := *lf newRel.mapping = newMapping - return newRel, err + return &newRel, err } // ProjectRel represents calculated expressions of fields (e.g. a+b), @@ -647,9 +647,9 @@ func (p *ProjectRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInput func (p *ProjectRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(p, mapping) - newRel := p + newRel := *p newRel.mapping = newMapping - return newRel, err + return &newRel, err } var defFilter = expr.NewPrimitiveLiteral(true, false) @@ -806,9 +806,9 @@ func (j *JoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs . func (j *JoinRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(j, mapping) - newRel := j + newRel := *j newRel.mapping = newMapping - return newRel, err + return &newRel, err } // CrossRel is a cartesian product relational operator of two tables. @@ -874,9 +874,9 @@ func (c *CrossRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (R func (c *CrossRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(c, mapping) - newRel := c + newRel := *c newRel.mapping = newMapping - return newRel, err + return &newRel, err } // FetchRel is a relational operator representing LIMIT/OFFSET or @@ -947,9 +947,9 @@ func (f *FetchRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (R func (f *FetchRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(f, mapping) - newRel := f + newRel := *f newRel.mapping = newMapping - return newRel, err + return &newRel, err } type AggRelMeasure struct { @@ -1102,9 +1102,9 @@ func (ar *AggregateRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newIn func (ar *AggregateRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(ar, mapping) - newRel := ar + newRel := *ar newRel.mapping = newMapping - return newRel, err + return &newRel, err } // SortRel is an ORDER BY relational operator, describing a base relation, @@ -1190,9 +1190,9 @@ func (sr *SortRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs func (sr *SortRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(sr, mapping) - newRel := sr + newRel := *sr newRel.mapping = newMapping - return newRel, err + return &newRel, err } // FilterRel is a relational operator capturing simple filters ( @@ -1268,9 +1268,9 @@ func (fr *FilterRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInput func (fr *FilterRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(fr, mapping) - newRel := fr + newRel := *fr newRel.mapping = newMapping - return newRel, err + return &newRel, err } type SetOp = proto.SetRel_SetOp @@ -1348,9 +1348,9 @@ func (s *SetRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (Rel func (s *SetRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(s, mapping) - newRel := s + newRel := *s newRel.mapping = newMapping - return newRel, err + return &newRel, err } // ExtensionSingleRel is a stub to support extensions with a single input. @@ -1412,9 +1412,9 @@ func (es *ExtensionSingleRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs func (es *ExtensionSingleRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(es, mapping) - newRel := es + newRel := *es newRel.mapping = newMapping - return newRel, err + return &newRel, err } // ExtensionLeafRel is a stub to support extensions with zero inputs. @@ -1463,9 +1463,9 @@ func (el *ExtensionLeafRel) CopyWithExpressionRewrite(_ RewriteFunc, _ ...Rel) ( func (el *ExtensionLeafRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(el, mapping) - newRel := el + newRel := *el newRel.mapping = newMapping - return newRel, err + return &newRel, err } // ExtensionMultiRel is a stub to support extensions with multiple inputs. @@ -1526,9 +1526,9 @@ func (em *ExtensionMultiRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs func (em *ExtensionMultiRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(em, mapping) - newRel := em + newRel := *em newRel.mapping = newMapping - return newRel, err + return &newRel, err } type HashMergeJoinType int8 @@ -1651,9 +1651,9 @@ func (hr *HashJoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInp func (hr *HashJoinRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(hr, mapping) - newRel := hr + newRel := *hr newRel.mapping = newMapping - return newRel, err + return &newRel, err } // MergeJoinRel represents a join done by taking advantage of two sets @@ -1762,9 +1762,9 @@ func (mr *MergeJoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newIn func (mr *MergeJoinRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(mr, mapping) - newRel := mr + newRel := *mr newRel.mapping = newMapping - return newRel, err + return &newRel, err } type WriteOp = proto.WriteRel_WriteOp @@ -1880,9 +1880,9 @@ func (wr *NamedTableWriteRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs func (wr *NamedTableWriteRel) Remap(mapping ...int32) (Rel, error) { newMapping, err := RemapHelper(wr, mapping) - newRel := wr + newRel := *wr newRel.mapping = newMapping - return newRel, err + return &newRel, err } var ( From f6c1ec1b2eaf9349f5bce7297f99130394287074 Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 4 Dec 2024 17:18:05 -0800 Subject: [PATCH 16/17] Added an unexported setMapping function to utilize Copy inside RemapHelper. --- plan/common.go | 4 ++ plan/plan.go | 5 ++ plan/plan_builder_test.go | 3 +- plan/relations.go | 115 ++++++++++---------------------------- 4 files changed, 40 insertions(+), 87 deletions(-) diff --git a/plan/common.go b/plan/common.go index 7cee1a9..e937a1f 100644 --- a/plan/common.go +++ b/plan/common.go @@ -54,6 +54,10 @@ func (rc *RelCommon) OutputMapping() []int32 { return mapCopy } +func (rc *RelCommon) setMapping(mapping []int32) { + rc.mapping = mapping +} + func (rc *RelCommon) GetAdvancedExtension() *extensions.AdvancedExtension { return rc.advExtension } diff --git a/plan/plan.go b/plan/plan.go index 3d5973b..1fd30b5 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -284,6 +284,11 @@ type Rel interface { // If any column numbers specified are outside the currently available // input range an error is returned and the mapping is left unchanged. Remap(mapping ...int32) (Rel, error) + + // setMapping sets the current mapping and is for internal use. + // It performs no checks. End users should call Remap() instead. + setMapping(mapping []int32) + // directOutputSchema returns the output record type of the underlying // relation as a struct type. Mapping is not applied. directOutputSchema() types.RecordType diff --git a/plan/plan_builder_test.go b/plan/plan_builder_test.go index 75ea120..fa70cfa 100644 --- a/plan/plan_builder_test.go +++ b/plan/plan_builder_test.go @@ -141,9 +141,8 @@ func TestFailedMappingOfMapping(t *testing.T) { newRel, err := ns.Remap(1, 0) assert.NoError(t, err) assert.Equal(t, "struct", newRel.RecordType().String()) - newRel2, err := newRel.Remap(-1) + _, err = newRel.Remap(-1) assert.ErrorContains(t, err, "output mapping index out of range") - assert.Equal(t, "struct", newRel2.RecordType().String()) } func checkRoundTrip(t *testing.T, expectedJSON string, p *plan.Plan) { diff --git a/plan/relations.go b/plan/relations.go index a22205b..07dfdef 100644 --- a/plan/relations.go +++ b/plan/relations.go @@ -152,32 +152,31 @@ func (b *baseReadRel) updateFilters(filters []expr.Expression) { b.filter, b.bestEffortFilter = filters[0], filters[1] } -type MappableRel interface { - RecordType() types.RecordType - OutputMapping() []int32 -} - // RemapHelper implements the core functionality of RemapHelper for relations. -// It returns the relation's existing mapping or an error to ease being called. -func RemapHelper(r MappableRel, mapping []int32) ([]int32, error) { +func RemapHelper(r Rel, mapping []int32) (Rel, error) { + newRel, err := r.Copy(r.GetInputs()...) + if err != nil { + return nil, err + } if len(mapping) == 0 { - return []int32{}, nil + newRel.setMapping([]int32{}) + return newRel, nil } nOutput := r.RecordType().FieldCount() oldMapping := r.OutputMapping() newMapping := make([]int32, 0, len(mapping)) for _, idx := range mapping { if idx < 0 || idx >= nOutput { - return r.OutputMapping(), errOutputMappingOutOfRange + return nil, errOutputMappingOutOfRange } if len(oldMapping) > 0 { newMapping = append(newMapping, oldMapping[idx]) + } else { + newMapping = append(newMapping, idx) } } - if len(oldMapping) > 0 { - return newMapping, nil - } - return mapping, nil + newRel.setMapping(newMapping) + return newRel, nil } // NamedTableReadRel is a named scan of a base table. The list of strings @@ -241,10 +240,7 @@ func (n *NamedTableReadRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, _ } func (n *NamedTableReadRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(n, mapping) - newRel := *n - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(n, mapping) } // VirtualTableReadRel represents a table composed of literals. @@ -303,10 +299,7 @@ func (v *VirtualTableReadRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, } func (v *VirtualTableReadRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(v, mapping) - newRel := *v - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(v, mapping) } // ExtensionTableReadRel is a stub type that can be used to extend @@ -360,10 +353,7 @@ func (e *ExtensionTableReadRel) CopyWithExpressionRewrite(rewriteFunc RewriteFun } func (e *ExtensionTableReadRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(e, mapping) - newRel := *e - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(e, mapping) } // PathType is the type of a LocalFileReadRel's uris. @@ -550,10 +540,7 @@ func (lf *LocalFileReadRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, _ } func (lf *LocalFileReadRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(lf, mapping) - newRel := *lf - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(lf, mapping) } // ProjectRel represents calculated expressions of fields (e.g. a+b), @@ -646,10 +633,7 @@ func (p *ProjectRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInput } func (p *ProjectRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(p, mapping) - newRel := *p - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(p, mapping) } var defFilter = expr.NewPrimitiveLiteral(true, false) @@ -805,10 +789,7 @@ func (j *JoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs . } func (j *JoinRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(j, mapping) - newRel := *j - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(j, mapping) } // CrossRel is a cartesian product relational operator of two tables. @@ -873,10 +854,7 @@ func (c *CrossRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (R } func (c *CrossRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(c, mapping) - newRel := *c - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(c, mapping) } // FetchRel is a relational operator representing LIMIT/OFFSET or @@ -946,10 +924,7 @@ func (f *FetchRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (R } func (f *FetchRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(f, mapping) - newRel := *f - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(f, mapping) } type AggRelMeasure struct { @@ -1101,10 +1076,7 @@ func (ar *AggregateRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newIn } func (ar *AggregateRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(ar, mapping) - newRel := *ar - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(ar, mapping) } // SortRel is an ORDER BY relational operator, describing a base relation, @@ -1189,10 +1161,7 @@ func (sr *SortRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInputs } func (sr *SortRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(sr, mapping) - newRel := *sr - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(sr, mapping) } // FilterRel is a relational operator capturing simple filters ( @@ -1267,10 +1236,7 @@ func (fr *FilterRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInput } func (fr *FilterRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(fr, mapping) - newRel := *fr - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(fr, mapping) } type SetOp = proto.SetRel_SetOp @@ -1347,10 +1313,7 @@ func (s *SetRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs ...Rel) (Rel } func (s *SetRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(s, mapping) - newRel := *s - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(s, mapping) } // ExtensionSingleRel is a stub to support extensions with a single input. @@ -1411,10 +1374,7 @@ func (es *ExtensionSingleRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs } func (es *ExtensionSingleRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(es, mapping) - newRel := *es - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(es, mapping) } // ExtensionLeafRel is a stub to support extensions with zero inputs. @@ -1462,10 +1422,7 @@ func (el *ExtensionLeafRel) CopyWithExpressionRewrite(_ RewriteFunc, _ ...Rel) ( } func (el *ExtensionLeafRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(el, mapping) - newRel := *el - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(el, mapping) } // ExtensionMultiRel is a stub to support extensions with multiple inputs. @@ -1525,10 +1482,7 @@ func (em *ExtensionMultiRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs } func (em *ExtensionMultiRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(em, mapping) - newRel := *em - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(em, mapping) } type HashMergeJoinType int8 @@ -1650,10 +1604,7 @@ func (hr *HashJoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newInp } func (hr *HashJoinRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(hr, mapping) - newRel := *hr - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(hr, mapping) } // MergeJoinRel represents a join done by taking advantage of two sets @@ -1761,10 +1712,7 @@ func (mr *MergeJoinRel) CopyWithExpressionRewrite(rewriteFunc RewriteFunc, newIn } func (mr *MergeJoinRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(mr, mapping) - newRel := *mr - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(mr, mapping) } type WriteOp = proto.WriteRel_WriteOp @@ -1879,10 +1827,7 @@ func (wr *NamedTableWriteRel) CopyWithExpressionRewrite(_ RewriteFunc, newInputs } func (wr *NamedTableWriteRel) Remap(mapping ...int32) (Rel, error) { - newMapping, err := RemapHelper(wr, mapping) - newRel := *wr - newRel.mapping = newMapping - return &newRel, err + return RemapHelper(wr, mapping) } var ( From 0178a4f29a4d1bb2bf58bcbc3b14ba9c81b788da Mon Sep 17 00:00:00 2001 From: David Sisson Date: Wed, 4 Dec 2024 17:22:45 -0800 Subject: [PATCH 17/17] cleaned up a comment --- plan/plan.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/plan/plan.go b/plan/plan.go index 1fd30b5..2dbc5ff 100644 --- a/plan/plan.go +++ b/plan/plan.go @@ -256,8 +256,7 @@ type RewriteFunc func(expr.Expression) (expr.Expression, error) // and should be type switched to determine which relation type it actually // is for evaluation. // -// All methods in this interface should be considered constant (i.e. -// immutable). +// All the exported methods in this interface should be considered constant. type Rel interface { // Hint returns a set of changes to the operation which can influence // efficiency and performance but should not impact correctness.