Skip to content

Commit

Permalink
fix: function variant match to deal with sync type parameters in the …
Browse files Browse the repository at this point in the history
…function parameters
  • Loading branch information
scgkiran committed Dec 17, 2024
1 parent 36dd6de commit d0f96f5
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 13 deletions.
6 changes: 3 additions & 3 deletions extensions/variants.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ func matchArguments(nullability NullabilityHandling, paramTypeList FuncParameter
return false, nil
}
}
if HasSyncParams(funcDefArgList) {
return types.AreSyncTypeParametersMatching(funcDefArgList, actualTypes), nil
}
return true, nil
}

Expand All @@ -193,9 +196,6 @@ func matchArgumentAtCommon(actualType types.Type, argPos int, nullability Nullab
return false, nil
}

if HasSyncParams(funcDefArgList) {
return false, fmt.Errorf("%w: function has sync params", substraitgo.ErrNotImplemented)
}
// if argPos is >= len(funcDefArgList) than last funcDefArg type should be considered for type match
// already checked for parameter in range above (considering variadic) so no need to check again for variadic
var funcDefArg types.FuncDefArgType
Expand Down
27 changes: 27 additions & 0 deletions extensions/variants_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait"
"github.com/substrait-io/substrait-go/v3/expr"
"github.com/substrait-io/substrait-go/v3/extensions"
parser2 "github.com/substrait-io/substrait-go/v3/testcases/parser"
"github.com/substrait-io/substrait-go/v3/types"
"github.com/substrait-io/substrait-go/v3/types/integer_parameters"
"github.com/substrait-io/substrait-go/v3/types/parser"
Expand Down Expand Up @@ -206,3 +209,27 @@ func TestHasSyncParams(t *testing.T) {
})
}
}

func TestMatchWithSyncParams(t *testing.T) {
testFiles := []string{
"tests/cases/arithmetic_decimal/bitwise_or.test",
"tests/cases/arithmetic_decimal/bitwise_xor.test",
"tests/cases/arithmetic_decimal/bitwise_and.test",
}
for _, testFile := range testFiles {
fs := substrait.GetSubstraitTestsFS()
testFile, err := parser2.ParseTestCaseFileFromFS(fs, testFile)
require.NoError(t, err)
require.NotNil(t, testFile)
assert.Len(t, testFile.TestCases, 14)

reg := expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection)
for _, tc := range testFile.TestCases {
t.Run(tc.FuncName, func(t *testing.T) {
invocation, err := tc.GetScalarFunctionInvocation(&reg)
require.NoError(t, err)
require.Equal(t, tc.ID(), invocation.ID())
})
}
}
}
24 changes: 16 additions & 8 deletions functions/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ dependencies:
http://localhost/sample.yaml
supported_types:
dec:
sql_type_name: INTEGER
sql_type_name: NUMBER
scalar_functions:
- name: arithmetic.func_testsync
supported_kernels:
Expand All @@ -604,21 +604,29 @@ scalar_functions:
funcRegistry := NewFunctionRegistry(&c)
localRegistry := getLocalFunctionRegistry(t, dialectYaml, funcRegistry)

int32Nullable := &types.Int32Type{Nullability: types.NullabilityNullable}
argTypes := []types.Type{
&types.DecimalType{Nullability: types.NullabilityNullable, Precision: 10, Scale: 2},
&types.DecimalType{Nullability: types.NullabilityNullable, Precision: 10, Scale: 2},
}
argTypesWithMismatchedParams := []types.Type{
&types.DecimalType{Nullability: types.NullabilityNullable, Precision: 20, Scale: 2},
&types.DecimalType{Nullability: types.NullabilityNullable, Precision: 10, Scale: 2},
}

fv := localRegistry.GetScalarFunctions(LocalFunctionName("func_testsync"), 2)

argTypes := []types.Type{int32Nullable, int32Nullable}
require.Len(t, fv, 1)
_, err := fv[0].Match(argTypes)
require.Error(t, err)
require.ErrorContains(t, err, "function has sync param")
isMatch, err := fv[0].Match(argTypes)
require.NoError(t, err)
require.True(t, isMatch)
isMatch, err = fv[0].Match(argTypesWithMismatchedParams)
require.NoError(t, err)
require.False(t, isMatch)

// test MatchAt
for pos, typ := range argTypes {
_, err = fv[0].MatchAt(typ, pos)
require.Error(t, err)
require.ErrorContains(t, err, "function has sync param")
require.NoError(t, err)
}
}

Expand Down
20 changes: 18 additions & 2 deletions types/type_derivation.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,8 +307,7 @@ type SymbolInfo struct {
Value any
}

func (m *OutputDerivation) ReturnType(funcParameters []FuncDefArgType, argumentTypes []Type) (Type, error) {
// Add parameterized parameters of arguments to symbol table
func buildTypeParametersNameValueMap(funcParameters []FuncDefArgType, argumentTypes []Type) (map[string]any, error) {
symbolTable := make(map[string]any)
for i, p := range funcParameters {
paramNames := p.GetParameterizedParams()
Expand All @@ -320,12 +319,24 @@ func (m *OutputDerivation) ReturnType(funcParameters []FuncDefArgType, argumentT
for j, param := range paramNames {
if intParam, ok := param.(*integer_parameters.VariableIntParam); ok {
name := string(*intParam)
if existingValue, ok := symbolTable[name]; ok && existingValue != paramValues[j] {
return nil, fmt.Errorf("sync parameters %s has conflicting values: %v and %v", name, existingValue, paramValues[j])
}
symbolTable[name] = paramValues[j]
continue
}
}
}
}
return symbolTable, nil
}

func (m *OutputDerivation) ReturnType(funcParameters []FuncDefArgType, argumentTypes []Type) (Type, error) {
// Build a symbol table of parameterized parameters of arguments
symbolTable, err := buildTypeParametersNameValueMap(funcParameters, argumentTypes)
if err != nil {
return nil, err
}

// Evaluate assignments
for _, a := range m.Assignments {
Expand Down Expand Up @@ -359,3 +370,8 @@ func (m *OutputDerivation) ReturnType(funcParameters []FuncDefArgType, argumentT
func (m *OutputDerivation) WithParameters([]interface{}) (Type, error) {
panic("WithParameters not to be called")
}

func AreSyncTypeParametersMatching(funcParameters []FuncDefArgType, argumentTypes []Type) bool {
_, err := buildTypeParametersNameValueMap(funcParameters, argumentTypes)
return err == nil
}
35 changes: 35 additions & 0 deletions types/type_derivation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,38 @@ func Test_getBinaryOpType(t *testing.T) {
})
}
}

func TestAreSyncTypeParametersMatching(t *testing.T) {
tests := []struct {
name string
parameters []string
arguments []string
inSync bool
}{
{"singleVarchar", []string{"varchar<L1>"}, []string{"varchar<8>"}, true},
{"singleDec", []string{"decimal<P1, S1>"}, []string{"decimal<20,5>"}, true},
{"SingleDecSingleParam", []string{"decimal<P1, 5>"}, []string{"decimal<20,5>"}, true},
{"SingleDecScaleParam", []string{"decimal<25, S1>"}, []string{"decimal<25,9>"}, true},
{"NonSyncVarcharParams", []string{"varchar<L1>", "varchar<L2>"}, []string{"varchar<9>", "varchar<8>"}, true},
{"SyncVarcharParams", []string{"varchar<L1>", "varchar<L1>"}, []string{"varchar<9>", "varchar<9>"}, true},
{"SyncVarcharParamsNeg", []string{"varchar<L1>", "varchar<L1>"}, []string{"varchar<9>", "varchar<8>"}, false},
{"NonSyncVarcharParams", []string{"decimal<P1, 0>", "decimal<P2, 0>"}, []string{"decimal<18,0>", "decimal<27,0>"}, true},
{"SyncScaleParam", []string{"decimal<38, S1>", "decimal<38, S1>"}, []string{"decimal<38,10>", "decimal<38,10>"}, true},
{"SyncScalParamNeg", []string{"decimal<38, S1>", "decimal<38, S1>"}, []string{"decimal<38,10>", "decimal<38,12>"}, false},
{"SyncPrecisionParam", []string{"decimal<P1, 0>", "decimal<P1, 0>"}, []string{"decimal<18,0>", "decimal<18,0>"}, true},
{"SyncPrecisionParamNeg", []string{"decimal<P1, 0>", "decimal<P1, 0>"}, []string{"decimal<18,0>", "decimal<20,0>"}, false},
{"NonSyncDecParams", []string{"decimal<P1, S1>", "decimal<P2, S2>"}, []string{"decimal<18,5>", "decimal<38,9>"}, true},
{"SyncDecParams", []string{"decimal<P1, S1>", "decimal<P1, S2>"}, []string{"decimal<18,5>", "decimal<18,9>"}, true},
{"SyncDecParamsBadScale", []string{"decimal<P1, S1>", "decimal<P1, S1>"}, []string{"decimal<18,5>", "decimal<18,9>"}, false},
{"SyncDecParamsBadPrecision", []string{"decimal<P1, S1>", "decimal<P1, S2>"}, []string{"decimal<19,5>", "decimal<18,5>"}, false},
{"SyncDecParamsNeg", []string{"decimal<P1, S1>", "decimal<P1, S2>"}, []string{"decimal<19,5>", "decimal<18,9>"}, false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
funcParameters := parseFuncParameters(t, tt.parameters)
funcArguments := parseFuncArguments(t, tt.arguments)
assert.Equalf(t, tt.inSync, types.AreSyncTypeParametersMatching(funcParameters, funcArguments), "AreSyncTypeParametersMatching(%v)", tt.name)
})
}
}

0 comments on commit d0f96f5

Please sign in to comment.