diff --git a/expr/functions.go b/expr/functions.go index ce21000..4f6f3f5 100644 --- a/expr/functions.go +++ b/expr/functions.go @@ -140,6 +140,13 @@ func BoundFromProto(b *proto.Expression_WindowFunction_Bound) Bound { return nil } +type FunctionInvocation interface { + CompoundName() string + ID() extensions.ID + GetOptions() []*types.FunctionOption + GetArgTypes() []types.Type +} + type ScalarFunction struct { funcRef uint32 declaration *extensions.ScalarFunctionVariant @@ -324,6 +331,25 @@ func (s *ScalarFunction) GetOption(name string) []string { return nil } +func (s *ScalarFunction) GetOptions() []*types.FunctionOption { return s.options } + +func getArgTypes(args []types.FuncArg) []types.Type { + argTypes := make([]types.Type, len(args)) + for i, arg := range args { + switch a := arg.(type) { + case Expression: + argTypes[i] = a.GetType() + case types.Type: + argTypes[i] = a + } + } + return argTypes +} + +func (s *ScalarFunction) GetArgTypes() []types.Type { + return getArgTypes(s.args) +} + func (s *ScalarFunction) GetType() types.Type { return s.outputType } func (s *ScalarFunction) ToProtoFuncArg() *proto.FunctionArgument { return &proto.FunctionArgument{ @@ -540,6 +566,12 @@ func (w *WindowFunction) String() string { return b.String() } +func (w *WindowFunction) GetOptions() []*types.FunctionOption { return w.options } + +func (w *WindowFunction) GetArgTypes() []types.Type { + return getArgTypes(w.args) +} + func (w *WindowFunction) GetType() types.Type { return w.outputType } func (w *WindowFunction) Equals(other Expression) bool { rhs, ok := other.(*WindowFunction) @@ -807,6 +839,12 @@ func (a *AggregateFunction) GetOption(name string) []string { return nil } +func (a *AggregateFunction) GetOptions() []*types.FunctionOption { return a.options } + +func (a *AggregateFunction) GetArgTypes() []types.Type { + return getArgTypes(a.args) +} + func (a *AggregateFunction) GetType() types.Type { return a.outputType } func (a *AggregateFunction) ToProto() *proto.AggregateFunction { diff --git a/extensions/variants.go b/extensions/variants.go index 06f1800..fa056cb 100644 --- a/extensions/variants.go +++ b/extensions/variants.go @@ -223,7 +223,7 @@ func validateVariadicBehaviorForMatch(variadicBehavior *VariadicBehavior, actual // all concrete types must be equal for all variable arguments firstVariadicArgIdx := max(variadicBehavior.Min-1, 0) for i := firstVariadicArgIdx; i < len(actualTypes)-1; i++ { - if !actualTypes[i].Equals(actualTypes[i+1]) { + if !actualTypes[i].Equals(actualTypes[i+1].WithNullability(actualTypes[i].GetNullability())) { return false } } diff --git a/extensions/variants_test.go b/extensions/variants_test.go index 738ea1c..868e33d 100644 --- a/extensions/variants_test.go +++ b/extensions/variants_test.go @@ -8,8 +8,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/substrait-io/substrait" - "github.com/substrait-io/substrait-go/v3/expr" "github.com/substrait-io/substrait-go/v3/extensions" + "github.com/substrait-io/substrait-go/v3/functions" parser2 "github.com/substrait-io/substrait-go/v3/testcases/parser" "github.com/substrait-io/substrait-go/v3/types" "github.com/substrait-io/substrait-go/v3/types/integer_parameters" @@ -229,16 +229,16 @@ func TestMatchWithSyncParams(t *testing.T) { require.NotNil(t, testFile) assert.Len(t, testFile.TestCases, testFileInfo.numTests) - reg := expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection) for _, tc := range testFile.TestCases { t.Run(tc.FuncName, func(t *testing.T) { switch tc.FuncType { case parser2.ScalarFuncType: - invocation, err := tc.GetScalarFunctionInvocation(®) + invocation, err := tc.GetScalarFunctionInvocation(®, funcRegistry) require.NoError(t, err) require.Equal(t, tc.ID(), invocation.ID()) case parser2.AggregateFuncType: - invocation, err := tc.GetAggregateFunctionInvocation(®) + invocation, err := tc.GetAggregateFunctionInvocation(®, funcRegistry) require.NoError(t, err) require.Equal(t, tc.ID(), invocation.ID()) } diff --git a/functions/dialect.go b/functions/dialect.go index f85b465..49b652e 100644 --- a/functions/dialect.go +++ b/functions/dialect.go @@ -38,6 +38,8 @@ type dialectImpl struct { localScalarFunctions map[extensions.ID]*dialectFunctionInfo localAggregateFunctions map[extensions.ID]*dialectFunctionInfo localWindowFunctions map[extensions.ID]*dialectFunctionInfo + + localTypeRegistry LocalTypeRegistry } func (d *dialectImpl) Name() string { @@ -52,6 +54,10 @@ func appendVariants[T extensions.FunctionVariant](variants []extensions.Function } func (d *dialectImpl) LocalizeFunctionRegistry(registry FunctionRegistry) (LocalFunctionRegistry, error) { + localTypeRegistry, err := d.GetLocalTypeRegistry() + if err != nil { + return nil, err + } scalarFunctions, err := makeLocalFunctionVariantMapAndSlice(d.localScalarFunctions, registry.GetScalarFunctionsByName, newLocalScalarFunctionVariant) if err != nil { return nil, err @@ -71,11 +77,13 @@ func (d *dialectImpl) LocalizeFunctionRegistry(registry FunctionRegistry) (Local allVariants = appendVariants(allVariants, windowFunctions.variantsSlice) return &localFunctionRegistryImpl{ - dialect: d, - scalarFunctions: scalarFunctions.variantsMap, - aggregateFunctions: aggregateFunctions.variantsMap, - windowFunctions: windowFunctions.variantsMap, - allFunctions: allVariants, + dialect: d, + scalarFunctions: scalarFunctions.variantsMap, + aggregateFunctions: aggregateFunctions.variantsMap, + windowFunctions: windowFunctions.variantsMap, + allFunctions: allVariants, + idToLocalFunctionMap: makeLocalFunctionVariantsMap(allVariants), + localTypeRegistry: localTypeRegistry, }, nil } @@ -151,6 +159,9 @@ func addToSliceMapWithLocalKey[V localFunctionVariant](m map[FunctionName][]V, v } func (d *dialectImpl) LocalizeTypeRegistry(TypeRegistry) (LocalTypeRegistry, error) { + if d.localTypeRegistry != nil { + return d.localTypeRegistry, nil + } typeInfos := make([]typeInfo, 0, len(d.toLocalTypeMap)) for name, info := range d.toLocalTypeMap { // TODO use registry.GetTypeClasses @@ -158,9 +169,17 @@ func (d *dialectImpl) LocalizeTypeRegistry(TypeRegistry) (LocalTypeRegistry, err if err != nil { return nil, fmt.Errorf("%w: unknown type %v", substraitgo.ErrInvalidDialect, name) } - typeInfos = append(typeInfos, typeInfo{typ: typ, shortName: name, localName: info.SqlTypeName, supportedAsColumn: info.SupportedAsColumn}) + typeInfos = append(typeInfos, typeInfo{typ: typ, shortName: typ.ShortString(), localName: info.SqlTypeName, supportedAsColumn: info.SupportedAsColumn}) + } + d.localTypeRegistry = NewLocalTypeRegistry(typeInfos) + return d.localTypeRegistry, nil +} + +func (d *dialectImpl) GetLocalTypeRegistry() (LocalTypeRegistry, error) { + if d.localTypeRegistry != nil { + return d.localTypeRegistry, nil } - return NewLocalTypeRegistry(typeInfos), nil + return d.LocalizeTypeRegistry(NewTypeRegistry()) } func newDialect(name string, reader io.Reader) (Dialect, error) { diff --git a/functions/dialect_test.go b/functions/dialect_test.go index 7e78033..63e0d7f 100644 --- a/functions/dialect_test.go +++ b/functions/dialect_test.go @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 -package functions +package functions_test import ( "strings" @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/substrait-io/substrait-go/v3/extensions" + . "github.com/substrait-io/substrait-go/v3/functions" "github.com/substrait-io/substrait-go/v3/types" ) diff --git a/functions/local_functions.go b/functions/local_functions.go index e423889..5592160 100644 --- a/functions/local_functions.go +++ b/functions/local_functions.go @@ -1,6 +1,11 @@ package functions -import "github.com/substrait-io/substrait-go/v3/extensions" +import ( + "fmt" + + "github.com/substrait-io/substrait-go/v3/expr" + "github.com/substrait-io/substrait-go/v3/extensions" +) type localFunctionRegistryImpl struct { dialect Dialect @@ -11,6 +16,24 @@ type localFunctionRegistryImpl struct { windowFunctions map[FunctionName][]*LocalWindowFunctionVariant allFunctions []extensions.FunctionVariant + + idToLocalFunctionMap map[extensions.ID]localFunctionVariant + localTypeRegistry LocalTypeRegistry +} + +func makeLocalFunctionVariantsMap(functions []extensions.FunctionVariant) map[extensions.ID]localFunctionVariant { + localFunctionVariants := make(map[extensions.ID]localFunctionVariant) + for _, f := range functions { + switch variant := f.(type) { + case *LocalScalarFunctionVariant: + localFunctionVariants[variant.ID()] = variant + case *LocalAggregateFunctionVariant: + localFunctionVariants[variant.ID()] = variant + case *LocalWindowFunctionVariant: + localFunctionVariants[variant.ID()] = variant + } + } + return localFunctionVariants } func (l *localFunctionRegistryImpl) GetAllFunctions() []extensions.FunctionVariant { @@ -33,4 +56,39 @@ func (l *localFunctionRegistryImpl) GetWindowFunctions(name FunctionName, numArg return getFunctionVariantsByCount(getOrEmpty(name, l.windowFunctions), numArgs) } +func (l *localFunctionRegistryImpl) GetScalarFunctionByInvocation(scalarFuncInvocation *expr.ScalarFunction) (*LocalScalarFunctionVariant, error) { + return getFunctionVariantByInvocation[*LocalScalarFunctionVariant](scalarFuncInvocation, l) +} + +func (l *localFunctionRegistryImpl) GetAggregateFunctionByInvocation(aggregateFuncInvocation *expr.AggregateFunction) (*LocalAggregateFunctionVariant, error) { + return getFunctionVariantByInvocation[*LocalAggregateFunctionVariant](aggregateFuncInvocation, l) +} + +func (l *localFunctionRegistryImpl) GetWindowFunctionByInvocation(windowFuncInvocation *expr.WindowFunction) (*LocalWindowFunctionVariant, error) { + return getFunctionVariantByInvocation[*LocalWindowFunctionVariant](windowFuncInvocation, l) +} + +func getFunctionVariantByInvocation[V localFunctionVariant](invocation expr.FunctionInvocation, registry *localFunctionRegistryImpl) (V, error) { + var zeroV V + f, ok := registry.idToLocalFunctionMap[invocation.ID()] + if !ok { + return zeroV, fmt.Errorf("function variant not found for function: %s", invocation.ID()) + } + argTypes := invocation.GetArgTypes() + for i, argType := range argTypes { + _, err := registry.localTypeRegistry.GetLocalTypeFromSubstraitType(argType) + if err != nil { + return zeroV, fmt.Errorf("unsupported substrait type: %v as argument %d in %s", argType, i, invocation.CompoundName()) + } + } + for _, option := range invocation.GetOptions() { + for _, value := range option.Preference { + if !f.IsOptionSupported(option.Name, value) { + return zeroV, fmt.Errorf("unsupported option [%s:%s] in function %s", option.Name, value, invocation.CompoundName()) + } + } + } + return f.(V), nil +} + var _ LocalFunctionRegistry = &localFunctionRegistryImpl{} diff --git a/functions/local_functions_test.go b/functions/local_functions_test.go new file mode 100644 index 0000000..05f2821 --- /dev/null +++ b/functions/local_functions_test.go @@ -0,0 +1,241 @@ +package functions_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait-go/v3/extensions" + "github.com/substrait-io/substrait-go/v3/functions" + parser2 "github.com/substrait-io/substrait-go/v3/testcases/parser" +) + +func makeHeader(version, include string) string { + return fmt.Sprintf("### SUBSTRAIT_SCALAR_TEST: %s\n### SUBSTRAIT_INCLUDE: '%s'\n\n", version, include) +} + +func makeAggregateTestHeader(version, include string) string { + return fmt.Sprintf("### SUBSTRAIT_AGGREGATE_TEST: %s\n### SUBSTRAIT_INCLUDE: '%s'\n\n", version, include) +} + +var scalarFunctionDialectYaml = ` +name: test +type: sql +dependencies: + arithmetic: + https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml +supported_types: + i8: + sql_type_name: INTEGER + supported_as_column: true + i16: + sql_type_name: INTEGER + supported_as_column: true + i32: + sql_type_name: INTEGER + supported_as_column: true + i64: + sql_type_name: BIGINT + supported_as_column: true + fp64: + sql_type_name: DOUBLE + supported_as_column: true + bool: + sql_type_name: BOOLEAN + supported_as_column: true +scalar_functions: +- name: arithmetic.add + local_name: + + infix: true + required_options: + overflow: ERROR + rounding: TIE_TO_EVEN + supported_kernels: + - i8_i8 + - i16_i16 + - i32_i32 + - i64_i64 + - fp64_fp64 +` + +func TestGetLocalScalarFunctionByInvocation(t *testing.T) { + header := makeHeader("v1.0", "/extensions/functions_arithmetic.yaml") + testsStr := `# 'Basic examples without any special cases' +add(120::i8, 5::i8) = 125::i8 +add(126::i16, 5::i16) = 125::i16 +add(3.4e+38::fp32, 3.4e+38::fp32) = inf::fp32 + +# Overflow examples demonstrating overflow behavior +add(2000000000::i32, 2000000000::i32) [overflow:ERROR] = +add(9223372036854775807::i64, 1::i64) [overflow:ERROR] = +add(120::i8, 10::i8) [overflow:SATURATE] = 127::i8 +add(120::i8, 10::i8) [overflow:SILENT] = + +` + + testResults := []struct { + name string + expectedError string + }{ + {"add:i8_i8", ""}, + {"add:i16_i16", ""}, + {"add:fp32_f32", "function variant not found"}, + {"add:i32_i32 [overflow:ERROR]", ""}, + {"add:i64_i64 [overflow:ERROR]", ""}, + {"add:i8_i8 [overflow:SATURATE]", "unsupported option [overflow:SATURATE]"}, + {"add:i8_i8 [overflow:SILENT]", "unsupported option [overflow:SILENT]"}, + } + localRegistry := getLocalFunctionRegistry(t, scalarFunctionDialectYaml, gFunctionRegistry) + + testFile, err := parser2.ParseTestCasesFromString(header + testsStr) + require.NoError(t, err) + require.NotNil(t, testFile) + assert.Len(t, testFile.TestCases, len(testResults)) + require.GreaterOrEqual(t, len(testFile.TestCases), len(testResults)) + + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection) + for i, result := range testResults { + tc := testFile.TestCases[i] + t.Run(result.name, func(t *testing.T) { + require.Equal(t, tc.FuncType, parser2.ScalarFuncType) + invocation, err := tc.GetScalarFunctionInvocation(®, funcRegistry) + require.NoError(t, err) + require.Equal(t, tc.ID(), invocation.ID()) + localVariant, err := localRegistry.GetScalarFunctionByInvocation(invocation) + if result.expectedError == "" { + require.NoError(t, err) + require.NotNil(t, localVariant) + } else { + require.Error(t, err) + assert.ErrorContains(t, err, result.expectedError) + } + }) + } +} + +var aggregateFunctionDialectYaml = ` +name: test +type: sql +dependencies: + arithmetic: + https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml + arithmetic_decimal: + https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic_decimal.yaml + comparison: + https://github.com/substrait-io/substrait/blob/main/extensions/functions_comparison.yaml + string: + https://github.com/substrait-io/substrait/blob/main/extensions/functions_string.yaml +supported_types: + i8: + sql_type_name: INTEGER + supported_as_column: true + i16: + sql_type_name: INTEGER + supported_as_column: true + i32: + sql_type_name: INTEGER + supported_as_column: true + i64: + sql_type_name: BIGINT + supported_as_column: true + fp64: + sql_type_name: DOUBLE + supported_as_column: true + decimal: + sql_type_name: NUMBER + supported_as_column: true +aggregate_functions: + - name: arithmetic.sum + aggregate: true + required_options: + overflow: ERROR + rounding: TIE_TO_EVEN + supported_kernels: + - fp64 + - i64 + - name: arithmetic.min + aggregate: true + supported_kernels: + - i16 + - i32 + - name: arithmetic.avg + aggregate: true + supported_kernels: + - fp64 + - name: arithmetic_decimal.min + aggregate: true + supported_kernels: + - dec + - name: arithmetic_decimal.sum + aggregate: true + supported_kernels: + - dec +` + +func TestGetLocalAggregateFunctionByInvocation(t *testing.T) { + header := makeAggregateTestHeader("v1.0", "/extensions/functions_arithmetic.yaml") + testsStr := `# basic +avg((1,2,3)::fp32) = 2::fp64 +avg((1,2,3)::fp64) = 2::fp64 +min((20, -3, 1, -10, 0, 5)::i8) = -10::i8 +min((-32768, 32767, 20000, -30000)::i16) = -32768::i16 +min((-214748648, 214748647, 21470048, 4000000)::i32) = -214748648::i32 +sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = +sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:SILENT] = -9223372036854775806::i64 +sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:SATURATE] = -9223372036854775806::i64 +sum((2.5000007152557373046875, 7.0000007152557373046875, 0, 7.0000007152557373046875)::fp64) = 16.500002145767212::fp64 +` + decHeader := makeAggregateTestHeader("v1.0", "/extensions/functions_arithmetic_decimal.yaml") + decTestsStr := `# basic +min((20, -3, 1, -10, 0, 5)::dec<2, 0>) = -10::dec<2, 0> +min((-32768, 32767, 20000, -30000)::dec<5, 0>) = -32768::dec<5, 0> +sum((2.5000007152557373046875, 7.0000007152557373046875, 0, 7.0000007152557373046875)::dec<23, 22>) = 16.5000021457672119140625::dec<38, 22> +` + + testResults := []struct { + name string + expectedError string + }{ + {"avg:fp32", "function variant not found"}, + {"avg:fp64", ""}, + {"min:i8", "function variant not found"}, + {"min:i16", ""}, + {"min:i32", ""}, + {"sum:i64", ""}, + {"sum:i64 [overflow:SILENT]", "unsupported option [overflow:SILENT]"}, + {"sum:i64 [overflow:SATURATE]", "unsupported option [overflow:SATURATE]"}, + {"sum:fp64", ""}, + {"min:dec", ""}, + {"min:dec", ""}, + {"sum:dec", ""}, + } + localRegistry := getLocalFunctionRegistry(t, aggregateFunctionDialectYaml, gFunctionRegistry) + testFile, err := parser2.ParseTestCasesFromString(header + testsStr) + require.NoError(t, err) + require.NotNil(t, testFile) + testFile1, err := parser2.ParseTestCasesFromString(decHeader + decTestsStr) + require.NoError(t, err) + require.NotNil(t, testFile) + testCases := append(testFile.TestCases, testFile1.TestCases...) + require.GreaterOrEqual(t, len(testCases), len(testResults)) + + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection) + for i, result := range testResults { + tc := testCases[i] + t.Run(result.name, func(t *testing.T) { + require.Equal(t, tc.FuncType, parser2.AggregateFuncType) + invocation, err := tc.GetAggregateFunctionInvocation(®, funcRegistry) + require.NoError(t, err) + require.Equal(t, tc.ID(), invocation.ID()) + localVariant, err := localRegistry.GetAggregateFunctionByInvocation(invocation) + if result.expectedError == "" { + require.NoError(t, err) + require.NotNil(t, localVariant) + } else { + require.Error(t, err) + assert.ErrorContains(t, err, result.expectedError) + } + }) + } +} diff --git a/functions/registries.go b/functions/registries.go index f1f891f..cc9aaa0 100644 --- a/functions/registries.go +++ b/functions/registries.go @@ -3,6 +3,7 @@ package functions import ( "strings" + "github.com/substrait-io/substrait-go/v3/expr" "github.com/substrait-io/substrait-go/v3/extensions" "github.com/substrait-io/substrait-go/v3/types" ) @@ -85,6 +86,9 @@ type FunctionRegistry interface { type LocalFunctionRegistry interface { functionRegistryBase[FunctionName, *LocalScalarFunctionVariant, *LocalAggregateFunctionVariant, *LocalWindowFunctionVariant] GetDialect() Dialect + GetScalarFunctionByInvocation(scalarFuncInvocation *expr.ScalarFunction) (*LocalScalarFunctionVariant, error) + GetAggregateFunctionByInvocation(aggregateFuncInvocation *expr.AggregateFunction) (*LocalAggregateFunctionVariant, error) + GetWindowFunctionByInvocation(windowFuncInvocation *expr.WindowFunction) (*LocalWindowFunctionVariant, error) } type FunctionNotation int @@ -184,3 +188,7 @@ func newLocalWindowFunctionVariant(wf *extensions.WindowFunctionVariant, dfi *di }, } } + +func NewExtensionAndFunctionRegistries(c *extensions.Collection) (expr.ExtensionRegistry, FunctionRegistry) { + return expr.NewEmptyExtensionRegistry(c), NewFunctionRegistry(c) +} diff --git a/functions/types_test.go b/functions/types_test.go index cc2a965..abc6cbd 100644 --- a/functions/types_test.go +++ b/functions/types_test.go @@ -1,4 +1,4 @@ -package functions +package functions_test import ( "strings" @@ -6,6 +6,7 @@ import ( "github.com/stretchr/testify/assert" substraitgo "github.com/substrait-io/substrait-go/v3" + . "github.com/substrait-io/substrait-go/v3/functions" "github.com/substrait-io/substrait-go/v3/types" ) diff --git a/testcases/parser/nodes.go b/testcases/parser/nodes.go index 7159fcd..b9fbedb 100644 --- a/testcases/parser/nodes.go +++ b/testcases/parser/nodes.go @@ -5,8 +5,10 @@ import ( "strconv" "strings" + substraitgo "github.com/substrait-io/substrait-go/v3" "github.com/substrait-io/substrait-go/v3/expr" "github.com/substrait-io/substrait-go/v3/extensions" + "github.com/substrait-io/substrait-go/v3/functions" "github.com/substrait-io/substrait-go/v3/types" ) @@ -148,7 +150,7 @@ func (tc *TestCase) ID() extensions.ID { } } -func (tc *TestCase) GetScalarFunctionInvocation(reg *expr.ExtensionRegistry) (*expr.ScalarFunction, error) { +func (tc *TestCase) GetScalarFunctionInvocation(reg *expr.ExtensionRegistry, funcRegistry functions.FunctionRegistry) (*expr.ScalarFunction, error) { if tc.FuncType != ScalarFuncType { return nil, fmt.Errorf("not a scalar function testcase") } @@ -158,10 +160,26 @@ func (tc *TestCase) GetScalarFunctionInvocation(reg *expr.ExtensionRegistry) (*e args[i] = arg.Value } - return expr.NewScalarFunc(*reg, id, tc.GetFunctionOptions(), args...) + invocation, err := expr.NewScalarFunc(*reg, id, tc.GetFunctionOptions(), args...) + if err == nil { + return invocation, nil + } + + // exact match not found, try to find a function that matches with function parameter type "any" + funcVariants := funcRegistry.GetScalarFunctions(tc.FuncName, len(args)) + for _, function := range funcVariants { + isMatch, err1 := function.Match(tc.GetArgTypes()) + if err1 == nil && isMatch && function.ID().URI == id.URI { + return expr.NewScalarFunc(*reg, function.ID(), tc.GetFunctionOptions(), args...) + } + } + return nil, fmt.Errorf("%w: no matching function found or %s", substraitgo.ErrNotFound, id) } -func (tc *TestCase) GetAggregateFunctionInvocation(reg *expr.ExtensionRegistry) (*expr.AggregateFunction, error) { +func (tc *TestCase) GetAggregateFunctionInvocation(reg *expr.ExtensionRegistry, funcRegistry functions.FunctionRegistry) (*expr.AggregateFunction, error) { + if tc.FuncType != AggregateFuncType { + return nil, fmt.Errorf("not an aggregate function testcase") + } id := tc.ID() args := make([]types.FuncArg, len(tc.AggregateArgs)) baseSchema := types.NewRecordTypeFromTypes(tc.getAggregateFuncTableSchema()) @@ -178,8 +196,21 @@ func (tc *TestCase) GetAggregateFunctionInvocation(reg *expr.ExtensionRegistry) args[i] = fieldRef } - return expr.NewAggregateFunc(*reg, id, tc.GetFunctionOptions(), + invocation, err := expr.NewAggregateFunc(*reg, id, tc.GetFunctionOptions(), types.AggInvocationAll, types.AggPhaseInitialToResult, nil, args...) + if err == nil { + return invocation, nil + } + + funcVariants := funcRegistry.GetAggregateFunctions(tc.FuncName, len(args)) + for _, function := range funcVariants { + isMatch, err := function.Match(tc.GetArgTypes()) + if err == nil && isMatch && function.ID().URI == id.URI { + return expr.NewAggregateFunc(*reg, function.ID(), tc.GetFunctionOptions(), + types.AggInvocationAll, types.AggPhaseInitialToResult, nil, args...) + } + } + return nil, fmt.Errorf("%w: no matching function found or %s", substraitgo.ErrNotFound, id) } type TestGroup struct { diff --git a/testcases/parser/parse_test.go b/testcases/parser/parse_test.go index c89095d..67bfb72 100644 --- a/testcases/parser/parse_test.go +++ b/testcases/parser/parse_test.go @@ -1,7 +1,9 @@ package parser import ( + "embed" "fmt" + "io/fs" "testing" "github.com/stretchr/testify/assert" @@ -9,6 +11,7 @@ import ( "github.com/substrait-io/substrait" "github.com/substrait-io/substrait-go/v3/expr" "github.com/substrait-io/substrait-go/v3/extensions" + "github.com/substrait-io/substrait-go/v3/functions" "github.com/substrait-io/substrait-go/v3/literal" "github.com/substrait-io/substrait-go/v3/types" ) @@ -42,16 +45,17 @@ add(120::i8, 10::i8) [overflow:ERROR] = {&types.Int16Type{}, &types.Int16Type{}}, {&types.Int8Type{}, &types.Int8Type{}}, } - reg := expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection) for i, tc := range testFile.TestCases { assert.Equal(t, extensions.ID{URI: arithURI, Name: ids[i]}, tc.ID()) - scalarFunc, err1 := tc.GetScalarFunctionInvocation(®) + scalarFunc, err1 := tc.GetScalarFunctionInvocation(®, funcRegistry) require.NoError(t, err1) assert.Equal(t, tc.FuncName, scalarFunc.Name()) require.Equal(t, 2, scalarFunc.NArgs()) assert.Equal(t, tc.Args[0].Value, scalarFunc.Arg(0)) assert.Equal(t, tc.Args[1].Value, scalarFunc.Arg(1)) assert.Equal(t, argTypes[i], tc.GetArgTypes()) + assert.Equal(t, ids[i], tc.CompoundFunctionName()) } } @@ -270,7 +274,7 @@ func TestParseAggregateFunc(t *testing.T) { avg((1,2,3)::fp32) = 2::fp64 sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = ` - reg := expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection) + reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection) arithUri := "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" testFile, err := ParseTestCasesFromString(header + tests) require.NoError(t, err) @@ -293,10 +297,11 @@ sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] =