Skip to content

Commit

Permalink
feat: add api to get local function variant given function invocation (
Browse files Browse the repository at this point in the history
  • Loading branch information
scgkiran authored Dec 20, 2024
1 parent 8e12e1e commit e60cefb
Show file tree
Hide file tree
Showing 11 changed files with 484 additions and 27 deletions.
38 changes: 38 additions & 0 deletions expr/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,13 @@ func BoundFromProto(b *proto.Expression_WindowFunction_Bound) Bound {
return nil
}

// FunctionInvocation is the common, read-only view of a function call
// expression. In this package it is satisfied by the scalar, window and
// aggregate function expression types, which lets callers inspect an
// invocation without knowing which of those kinds it is.
type FunctionInvocation interface {
	// ID returns the extension identifier of the invoked function.
	ID() extensions.ID
	// CompoundName returns the compound name of the invoked function.
	CompoundName() string

	// GetArgTypes returns one entry per argument of the invocation,
	// holding that argument's type.
	GetArgTypes() []types.Type
	// GetOptions returns the function options attached to the invocation.
	GetOptions() []*types.FunctionOption
}

type ScalarFunction struct {
funcRef uint32
declaration *extensions.ScalarFunctionVariant
Expand Down Expand Up @@ -324,6 +331,25 @@ func (s *ScalarFunction) GetOption(name string) []string {
return nil
}

// GetOptions returns the function options stored on this scalar function
// invocation. The internal slice is returned as-is (not copied).
func (s *ScalarFunction) GetOptions() []*types.FunctionOption {
	return s.options
}

// getArgTypes returns a slice parallel to args holding the type of each
// function argument: an Expression argument contributes its GetType()
// result, and an argument that is itself a types.Type is used directly.
//
// Arguments of any other kind leave a nil entry at their position, so the
// result always has len(args) elements and indices line up with args.
func getArgTypes(args []types.FuncArg) []types.Type {
	resolved := make([]types.Type, len(args))
	for idx, fnArg := range args {
		// Expression is tested first, matching the precedence of the cases.
		if e, ok := fnArg.(Expression); ok {
			resolved[idx] = e.GetType()
			continue
		}
		if t, ok := fnArg.(types.Type); ok {
			resolved[idx] = t
		}
	}
	return resolved
}

// GetArgTypes returns the types of this scalar function's arguments, one
// entry per argument, as computed by getArgTypes.
func (s *ScalarFunction) GetArgTypes() []types.Type {
	argTypes := getArgTypes(s.args)
	return argTypes
}

// GetType returns the output type recorded for this scalar function.
func (s *ScalarFunction) GetType() types.Type {
	return s.outputType
}
func (s *ScalarFunction) ToProtoFuncArg() *proto.FunctionArgument {
return &proto.FunctionArgument{
Expand Down Expand Up @@ -540,6 +566,12 @@ func (w *WindowFunction) String() string {
return b.String()
}

// GetOptions returns the function options stored on this window function
// invocation. The internal slice is returned as-is (not copied).
func (w *WindowFunction) GetOptions() []*types.FunctionOption {
	return w.options
}

// GetArgTypes returns the types of this window function's arguments, one
// entry per argument, as computed by getArgTypes.
func (w *WindowFunction) GetArgTypes() []types.Type {
	argTypes := getArgTypes(w.args)
	return argTypes
}

// GetType returns the output type recorded for this window function.
func (w *WindowFunction) GetType() types.Type {
	return w.outputType
}
func (w *WindowFunction) Equals(other Expression) bool {
rhs, ok := other.(*WindowFunction)
Expand Down Expand Up @@ -807,6 +839,12 @@ func (a *AggregateFunction) GetOption(name string) []string {
return nil
}

// GetOptions returns the function options stored on this aggregate function
// invocation. The internal slice is returned as-is (not copied).
func (a *AggregateFunction) GetOptions() []*types.FunctionOption {
	return a.options
}

// GetArgTypes returns the types of this aggregate function's arguments, one
// entry per argument, as computed by getArgTypes.
func (a *AggregateFunction) GetArgTypes() []types.Type {
	argTypes := getArgTypes(a.args)
	return argTypes
}

// GetType returns the output type recorded for this aggregate function.
func (a *AggregateFunction) GetType() types.Type {
	return a.outputType
}

func (a *AggregateFunction) ToProto() *proto.AggregateFunction {
Expand Down
2 changes: 1 addition & 1 deletion extensions/variants.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func validateVariadicBehaviorForMatch(variadicBehavior *VariadicBehavior, actual
// all concrete types must be equal for all variable arguments
firstVariadicArgIdx := max(variadicBehavior.Min-1, 0)
for i := firstVariadicArgIdx; i < len(actualTypes)-1; i++ {
if !actualTypes[i].Equals(actualTypes[i+1]) {
if !actualTypes[i].Equals(actualTypes[i+1].WithNullability(actualTypes[i].GetNullability())) {
return false
}
}
Expand Down
8 changes: 4 additions & 4 deletions extensions/variants_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait"
"github.com/substrait-io/substrait-go/v3/expr"
"github.com/substrait-io/substrait-go/v3/extensions"
"github.com/substrait-io/substrait-go/v3/functions"
parser2 "github.com/substrait-io/substrait-go/v3/testcases/parser"
"github.com/substrait-io/substrait-go/v3/types"
"github.com/substrait-io/substrait-go/v3/types/integer_parameters"
Expand Down Expand Up @@ -229,16 +229,16 @@ func TestMatchWithSyncParams(t *testing.T) {
require.NotNil(t, testFile)
assert.Len(t, testFile.TestCases, testFileInfo.numTests)

reg := expr.NewEmptyExtensionRegistry(&extensions.DefaultCollection)
reg, funcRegistry := functions.NewExtensionAndFunctionRegistries(&extensions.DefaultCollection)
for _, tc := range testFile.TestCases {
t.Run(tc.FuncName, func(t *testing.T) {
switch tc.FuncType {
case parser2.ScalarFuncType:
invocation, err := tc.GetScalarFunctionInvocation(&reg)
invocation, err := tc.GetScalarFunctionInvocation(&reg, funcRegistry)
require.NoError(t, err)
require.Equal(t, tc.ID(), invocation.ID())
case parser2.AggregateFuncType:
invocation, err := tc.GetAggregateFunctionInvocation(&reg)
invocation, err := tc.GetAggregateFunctionInvocation(&reg, funcRegistry)
require.NoError(t, err)
require.Equal(t, tc.ID(), invocation.ID())
}
Expand Down
33 changes: 26 additions & 7 deletions functions/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ type dialectImpl struct {
localScalarFunctions map[extensions.ID]*dialectFunctionInfo
localAggregateFunctions map[extensions.ID]*dialectFunctionInfo
localWindowFunctions map[extensions.ID]*dialectFunctionInfo

localTypeRegistry LocalTypeRegistry
}

func (d *dialectImpl) Name() string {
Expand All @@ -52,6 +54,10 @@ func appendVariants[T extensions.FunctionVariant](variants []extensions.Function
}

func (d *dialectImpl) LocalizeFunctionRegistry(registry FunctionRegistry) (LocalFunctionRegistry, error) {
localTypeRegistry, err := d.GetLocalTypeRegistry()
if err != nil {
return nil, err
}
scalarFunctions, err := makeLocalFunctionVariantMapAndSlice(d.localScalarFunctions, registry.GetScalarFunctionsByName, newLocalScalarFunctionVariant)
if err != nil {
return nil, err
Expand All @@ -71,11 +77,13 @@ func (d *dialectImpl) LocalizeFunctionRegistry(registry FunctionRegistry) (Local
allVariants = appendVariants(allVariants, windowFunctions.variantsSlice)

return &localFunctionRegistryImpl{
dialect: d,
scalarFunctions: scalarFunctions.variantsMap,
aggregateFunctions: aggregateFunctions.variantsMap,
windowFunctions: windowFunctions.variantsMap,
allFunctions: allVariants,
dialect: d,
scalarFunctions: scalarFunctions.variantsMap,
aggregateFunctions: aggregateFunctions.variantsMap,
windowFunctions: windowFunctions.variantsMap,
allFunctions: allVariants,
idToLocalFunctionMap: makeLocalFunctionVariantsMap(allVariants),
localTypeRegistry: localTypeRegistry,
}, nil
}

Expand Down Expand Up @@ -151,16 +159,27 @@ func addToSliceMapWithLocalKey[V localFunctionVariant](m map[FunctionName][]V, v
}

func (d *dialectImpl) LocalizeTypeRegistry(TypeRegistry) (LocalTypeRegistry, error) {
if d.localTypeRegistry != nil {
return d.localTypeRegistry, nil
}
typeInfos := make([]typeInfo, 0, len(d.toLocalTypeMap))
for name, info := range d.toLocalTypeMap {
// TODO use registry.GetTypeClasses
typ, err := getTypeFromBaseTypeName(name)
if err != nil {
return nil, fmt.Errorf("%w: unknown type %v", substraitgo.ErrInvalidDialect, name)
}
typeInfos = append(typeInfos, typeInfo{typ: typ, shortName: name, localName: info.SqlTypeName, supportedAsColumn: info.SupportedAsColumn})
typeInfos = append(typeInfos, typeInfo{typ: typ, shortName: typ.ShortString(), localName: info.SqlTypeName, supportedAsColumn: info.SupportedAsColumn})
}
d.localTypeRegistry = NewLocalTypeRegistry(typeInfos)
return d.localTypeRegistry, nil
}

func (d *dialectImpl) GetLocalTypeRegistry() (LocalTypeRegistry, error) {
if d.localTypeRegistry != nil {
return d.localTypeRegistry, nil
}
return NewLocalTypeRegistry(typeInfos), nil
return d.LocalizeTypeRegistry(NewTypeRegistry())
}

func newDialect(name string, reader io.Reader) (Dialect, error) {
Expand Down
3 changes: 2 additions & 1 deletion functions/dialect_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// SPDX-License-Identifier: Apache-2.0

package functions
package functions_test

import (
"strings"
Expand All @@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/substrait-io/substrait-go/v3/extensions"
. "github.com/substrait-io/substrait-go/v3/functions"
"github.com/substrait-io/substrait-go/v3/types"
)

Expand Down
60 changes: 59 additions & 1 deletion functions/local_functions.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
package functions

import "github.com/substrait-io/substrait-go/v3/extensions"
import (
"fmt"

"github.com/substrait-io/substrait-go/v3/expr"
"github.com/substrait-io/substrait-go/v3/extensions"
)

type localFunctionRegistryImpl struct {
dialect Dialect
Expand All @@ -11,6 +16,24 @@ type localFunctionRegistryImpl struct {
windowFunctions map[FunctionName][]*LocalWindowFunctionVariant

allFunctions []extensions.FunctionVariant

idToLocalFunctionMap map[extensions.ID]localFunctionVariant
localTypeRegistry LocalTypeRegistry
}

// makeLocalFunctionVariantsMap indexes the given variants by their ID.
//
// Only local scalar, aggregate and window variants are indexed; entries of
// any other concrete type are ignored. When several variants share an ID,
// the one appearing last in functions is the one kept.
func makeLocalFunctionVariantsMap(functions []extensions.FunctionVariant) map[extensions.ID]localFunctionVariant {
	byID := make(map[extensions.ID]localFunctionVariant, len(functions))
	for _, fn := range functions {
		if scalar, ok := fn.(*LocalScalarFunctionVariant); ok {
			byID[scalar.ID()] = scalar
			continue
		}
		if aggregate, ok := fn.(*LocalAggregateFunctionVariant); ok {
			byID[aggregate.ID()] = aggregate
			continue
		}
		if window, ok := fn.(*LocalWindowFunctionVariant); ok {
			byID[window.ID()] = window
		}
	}
	return byID
}

func (l *localFunctionRegistryImpl) GetAllFunctions() []extensions.FunctionVariant {
Expand All @@ -33,4 +56,39 @@ func (l *localFunctionRegistryImpl) GetWindowFunctions(name FunctionName, numArg
return getFunctionVariantsByCount(getOrEmpty(name, l.windowFunctions), numArgs)
}

// GetScalarFunctionByInvocation returns the local scalar function variant
// registered under the invocation's ID, or an error when
// getFunctionVariantByInvocation rejects the invocation.
func (l *localFunctionRegistryImpl) GetScalarFunctionByInvocation(scalarFuncInvocation *expr.ScalarFunction) (*LocalScalarFunctionVariant, error) {
	variant, err := getFunctionVariantByInvocation[*LocalScalarFunctionVariant](scalarFuncInvocation, l)
	return variant, err
}

// GetAggregateFunctionByInvocation returns the local aggregate function
// variant registered under the invocation's ID, or an error when
// getFunctionVariantByInvocation rejects the invocation.
func (l *localFunctionRegistryImpl) GetAggregateFunctionByInvocation(aggregateFuncInvocation *expr.AggregateFunction) (*LocalAggregateFunctionVariant, error) {
	variant, err := getFunctionVariantByInvocation[*LocalAggregateFunctionVariant](aggregateFuncInvocation, l)
	return variant, err
}

// GetWindowFunctionByInvocation returns the local window function variant
// registered under the invocation's ID, or an error when
// getFunctionVariantByInvocation rejects the invocation.
func (l *localFunctionRegistryImpl) GetWindowFunctionByInvocation(windowFuncInvocation *expr.WindowFunction) (*LocalWindowFunctionVariant, error) {
	variant, err := getFunctionVariantByInvocation[*LocalWindowFunctionVariant](windowFuncInvocation, l)
	return variant, err
}

// getFunctionVariantByInvocation looks up the local function variant that
// corresponds to invocation and verifies that the invocation can be
// expressed in the local dialect.
//
// It returns an error (and the zero value of V) when:
//   - no variant is registered under the invocation's ID,
//   - the registered variant is not of the requested kind V,
//   - any argument type has no local type in the registry's type registry, or
//   - any preference value of any option is not supported by the variant.
//
// NOTE(review): invocation.GetArgTypes() may contain nil entries for
// arguments that are neither expressions nor types; they are passed to the
// type registry unchanged — confirm it handles nil.
func getFunctionVariantByInvocation[V localFunctionVariant](invocation expr.FunctionInvocation, registry *localFunctionRegistryImpl) (V, error) {
	var zeroV V
	f, ok := registry.idToLocalFunctionMap[invocation.ID()]
	if !ok {
		return zeroV, fmt.Errorf("function variant not found for function: %s", invocation.ID())
	}
	// The ID map holds scalar, aggregate and window variants together, so a
	// hit may be of a different kind than the caller asked for. Report that
	// as an error instead of panicking on an unchecked type assertion.
	variant, ok := f.(V)
	if !ok {
		return zeroV, fmt.Errorf("function variant for %s is a %T, not a %T", invocation.ID(), f, zeroV)
	}
	for i, argType := range invocation.GetArgTypes() {
		if _, err := registry.localTypeRegistry.GetLocalTypeFromSubstraitType(argType); err != nil {
			// Keep the underlying cause available to errors.Is / errors.As.
			return zeroV, fmt.Errorf("unsupported substrait type: %v as argument %d in %s: %w", argType, i, invocation.CompoundName(), err)
		}
	}
	for _, option := range invocation.GetOptions() {
		for _, value := range option.Preference {
			if !f.IsOptionSupported(option.Name, value) {
				return zeroV, fmt.Errorf("unsupported option [%s:%s] in function %s", option.Name, value, invocation.CompoundName())
			}
		}
	}
	return variant, nil
}

// Compile-time check that *localFunctionRegistryImpl implements
// LocalFunctionRegistry.
var _ LocalFunctionRegistry = (*localFunctionRegistryImpl)(nil)
Loading

0 comments on commit e60cefb

Please sign in to comment.