From d0f96f5770e745bb5260418366021598a119a63e Mon Sep 17 00:00:00 2001 From: Chandra Sanapala Date: Tue, 17 Dec 2024 21:00:06 +0530 Subject: [PATCH] fix: function variant match to deal with sync type parameters in the function parameters --- extensions/variants.go | 6 +++--- extensions/variants_test.go | 27 +++++++++++++++++++++++++++ functions/dialect_test.go | 24 ++++++++++++++++-------- types/type_derivation.go | 20 ++++++++++++++++++-- types/type_derivation_test.go | 35 +++++++++++++++++++++++++++++++++++ 5 files changed, 99 insertions(+), 13 deletions(-) diff --git a/extensions/variants.go b/extensions/variants.go index 5c2b277..06f1800 100644 --- a/extensions/variants.go +++ b/extensions/variants.go @@ -170,6 +170,9 @@ func matchArguments(nullability NullabilityHandling, paramTypeList FuncParameter return false, nil } } + if HasSyncParams(funcDefArgList) { + return types.AreSyncTypeParametersMatching(funcDefArgList, actualTypes), nil + } return true, nil } @@ -193,9 +196,6 @@ func matchArgumentAtCommon(actualType types.Type, argPos int, nullability Nullab return false, nil } - if HasSyncParams(funcDefArgList) { - return false, fmt.Errorf("%w: function has sync params", substraitgo.ErrNotImplemented) - } // if argPos is >= len(funcDefArgList) than last funcDefArg type should be considered for type match // already checked for parameter in range above (considering variadic) so no need to check again for variadic var funcDefArg types.FuncDefArgType diff --git a/extensions/variants_test.go b/extensions/variants_test.go index 1c5e2a3..89093f2 100644 --- a/extensions/variants_test.go +++ b/extensions/variants_test.go @@ -7,7 +7,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/substrait-io/substrait" + "github.com/substrait-io/substrait-go/v3/expr" "github.com/substrait-io/substrait-go/v3/extensions" + parser2 "github.com/substrait-io/substrait-go/v3/testcases/parser" "github.com/substrait-io/substrait-go/v3/types" "github.com/substrait-io/substrait-go/v3/types/integer_parameters" "github.com/substrait-io/substrait-go/v3/types/parser" @@ -206,3 +209,27 @@ func TestHasSyncParams(t *testing.T) { }) } } + +func TestMatchWithSyncParams(t *testing.T) { + testFiles := []string{ + "tests/cases/arithmetic_decimal/bitwise_or.test", + "tests/cases/arithmetic_decimal/bitwise_xor.test", + "tests/cases/arithmetic_decimal/bitwise_and.test", + } + for _, testFile := range testFiles { + fs := substrait.GetSubstraitTestsFS() + testFile, err := parser2.ParseTestCaseFileFromFS(fs, testFile) + require.NoError(t, err) + require.NotNil(t, testFile) + assert.Len(t, testFile.TestCases, 14) + + reg := expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection) + for _, tc := range testFile.TestCases { + t.Run(tc.FuncName, func(t *testing.T) { + invocation, err := tc.GetScalarFunctionInvocation(®) + require.NoError(t, err) + require.Equal(t, tc.ID(), invocation.ID()) + }) + } + } +} diff --git a/functions/dialect_test.go b/functions/dialect_test.go index c5e0c7a..7e78033 100644 --- a/functions/dialect_test.go +++ b/functions/dialect_test.go @@ -592,7 +592,7 @@ dependencies: http://localhost/sample.yaml supported_types: dec: - sql_type_name: INTEGER + sql_type_name: NUMBER scalar_functions: - name: arithmetic.func_testsync supported_kernels: @@ -604,21 +604,29 @@ scalar_functions: funcRegistry := NewFunctionRegistry(&c) localRegistry := getLocalFunctionRegistry(t, dialectYaml, funcRegistry) - int32Nullable := &types.Int32Type{Nullability: types.NullabilityNullable} + argTypes := []types.Type{ + &types.DecimalType{Nullability: types.NullabilityNullable, Precision: 10, Scale: 2}, + &types.DecimalType{Nullability: types.NullabilityNullable, Precision: 10, Scale: 2}, + } + argTypesWithMismatchedParams := []types.Type{ + &types.DecimalType{Nullability: types.NullabilityNullable, Precision: 20, Scale: 2}, + &types.DecimalType{Nullability: types.NullabilityNullable, Precision: 10, Scale: 2}, + } fv := localRegistry.GetScalarFunctions(LocalFunctionName("func_testsync"), 2) - argTypes := []types.Type{int32Nullable, int32Nullable} require.Len(t, fv, 1) - _, err := fv[0].Match(argTypes) - require.Error(t, err) - require.ErrorContains(t, err, "function has sync param") + isMatch, err := fv[0].Match(argTypes) + require.NoError(t, err) + require.True(t, isMatch) + isMatch, err = fv[0].Match(argTypesWithMismatchedParams) + require.NoError(t, err) + require.False(t, isMatch) // test MatchAt for pos, typ := range argTypes { _, err = fv[0].MatchAt(typ, pos) - require.Error(t, err) - require.ErrorContains(t, err, "function has sync param") + require.NoError(t, err) } } diff --git a/types/type_derivation.go b/types/type_derivation.go index d5788a0..bbfa11c 100644 --- a/types/type_derivation.go +++ b/types/type_derivation.go @@ -307,8 +307,7 @@ type SymbolInfo struct { Value any } -func (m *OutputDerivation) ReturnType(funcParameters []FuncDefArgType, argumentTypes []Type) (Type, error) { - // Add parameterized parameters of arguments to symbol table +func buildTypeParametersNameValueMap(funcParameters []FuncDefArgType, argumentTypes []Type) (map[string]any, error) { symbolTable := make(map[string]any) for i, p := range funcParameters { paramNames := p.GetParameterizedParams() @@ -320,12 +319,24 @@ func (m *OutputDerivation) ReturnType(funcParameters []FuncDefArgType, argumentT for j, param := range paramNames { if intParam, ok := param.(*integer_parameters.VariableIntParam); ok { name := string(*intParam) + if existingValue, ok := symbolTable[name]; ok && existingValue != paramValues[j] { + return nil, fmt.Errorf("sync parameters %s has conflicting values: %v and %v", name, existingValue, paramValues[j]) + } symbolTable[name] = paramValues[j] continue } } } } + return symbolTable, nil +} + +func (m *OutputDerivation) ReturnType(funcParameters []FuncDefArgType, argumentTypes []Type) (Type, error) { + // Build a symbol table of parameterized parameters of arguments + symbolTable, err := buildTypeParametersNameValueMap(funcParameters, argumentTypes) + if err != nil { + return nil, err + } // Evaluate assignments for _, a := range m.Assignments { @@ -359,3 +370,8 @@ func (m *OutputDerivation) ReturnType(funcParameters []FuncDefArgType, argumentT func (m *OutputDerivation) WithParameters([]interface{}) (Type, error) { panic("WithParameters not to be called") } + +func AreSyncTypeParametersMatching(funcParameters []FuncDefArgType, argumentTypes []Type) bool { + _, err := buildTypeParametersNameValueMap(funcParameters, argumentTypes) + return err == nil +} diff --git a/types/type_derivation_test.go b/types/type_derivation_test.go index 0d45e37..a9b156d 100644 --- a/types/type_derivation_test.go +++ b/types/type_derivation_test.go @@ -232,3 +232,38 @@ func Test_getBinaryOpType(t *testing.T) { }) } } + +func TestAreSyncTypeParametersMatching(t *testing.T) { + tests := []struct { + name string + parameters []string + arguments []string + inSync bool + }{ + {"singleVarchar", []string{"varchar"}, []string{"varchar<8>"}, true}, + {"singleDec", []string{"decimal"}, []string{"decimal<20,5>"}, true}, + {"SingleDecSingleParam", []string{"decimal"}, []string{"decimal<20,5>"}, true}, + {"SingleDecScaleParam", []string{"decimal<25, S1>"}, []string{"decimal<25,9>"}, true}, + {"NonSyncVarcharParams", []string{"varchar", "varchar"}, []string{"varchar<9>", "varchar<8>"}, true}, + {"SyncVarcharParams", []string{"varchar", "varchar"}, []string{"varchar<9>", "varchar<9>"}, true}, + {"SyncVarcharParamsNeg", []string{"varchar", "varchar"}, []string{"varchar<9>", "varchar<8>"}, false}, + {"NonSyncVarcharParams", []string{"decimal", "decimal"}, []string{"decimal<18,0>", "decimal<27,0>"}, true}, + {"SyncScaleParam", []string{"decimal<38, S1>", "decimal<38, S1>"}, []string{"decimal<38,10>", "decimal<38,10>"}, true}, + {"SyncScalParamNeg", []string{"decimal<38, S1>", "decimal<38, S1>"}, []string{"decimal<38,10>", "decimal<38,12>"}, false}, + {"SyncPrecisionParam", []string{"decimal", "decimal"}, []string{"decimal<18,0>", "decimal<18,0>"}, true}, + {"SyncPrecisionParamNeg", []string{"decimal", "decimal"}, []string{"decimal<18,0>", "decimal<20,0>"}, false}, + {"NonSyncDecParams", []string{"decimal", "decimal"}, []string{"decimal<18,5>", "decimal<38,9>"}, true}, + {"SyncDecParams", []string{"decimal", "decimal"}, []string{"decimal<18,5>", "decimal<18,9>"}, true}, + {"SyncDecParamsBadScale", []string{"decimal", "decimal"}, []string{"decimal<18,5>", "decimal<18,9>"}, false}, + {"SyncDecParamsBadPrecision", []string{"decimal", "decimal"}, []string{"decimal<19,5>", "decimal<18,5>"}, false}, + {"SyncDecParamsNeg", []string{"decimal", "decimal"}, []string{"decimal<19,5>", "decimal<18,9>"}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + funcParameters := parseFuncParameters(t, tt.parameters) + funcArguments := parseFuncArguments(t, tt.arguments) + assert.Equalf(t, tt.inSync, types.AreSyncTypeParametersMatching(funcParameters, funcArguments), "AreSyncTypeParametersMatching(%v)", tt.name) + }) + } +}