Skip to content

Commit

Permalink
feat: add func type in testcase (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
scgkiran authored Dec 16, 2024
1 parent 5924d58 commit 9e1c860
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 13 deletions.
17 changes: 16 additions & 1 deletion testcases/parser/nodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ import (
"github.com/substrait-io/substrait-go/types"
)

type TestFuncType string

const (
ScalarFuncType TestFuncType = "scalar"
AggregateFuncType TestFuncType = "aggregate"
WindowFuncType TestFuncType = "window"
)

type CaseLiteral struct {
Type types.Type
ValueText string
Expand All @@ -17,6 +25,7 @@ type CaseLiteral struct {

type TestFileHeader struct {
Version string
FuncType TestFuncType
IncludedURI string
}

Expand All @@ -31,10 +40,16 @@ type TestCase struct {
Columns [][]expr.Literal
TableName string
ColumnTypes []types.Type
FuncType TestFuncType
}

type TestGroup struct {
Description string
TestCases []*TestCase
}

type TestFile struct {
Header TestFileHeader
Header *TestFileHeader
TestCases []*TestCase
}

Expand Down
12 changes: 9 additions & 3 deletions testcases/parser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ lt('2016-12-31T13:30:15'::ts, '2017-12-31T13:30:15'::ts) = true::bool
timestampType := &types.TimestampType{Nullability: types.NullabilityUnspecified}
assert.Equal(t, timestampType, testFile.TestCases[0].Args[0].Type)
assert.Equal(t, timestampType, testFile.TestCases[0].Args[1].Type)
assert.Equal(t, ScalarFuncType, testFile.TestCases[0].FuncType)
}

func TestParseDecimalExample(t *testing.T) {
Expand Down Expand Up @@ -240,12 +241,14 @@ sum((9223372036854775806, 1, 1, 1, 1, 10000000000)::i64) [overflow:ERROR] = <!ER
assert.Equal(t, listType, testFile.TestCases[0].AggregateArgs[0].Argument.Value.GetType())
assert.Equal(t, "fp64", testFile.TestCases[0].Result.Type.String())
assert.Equal(t, literal.NewFloat64(2), testFile.TestCases[0].Result.Value)
assert.Equal(t, AggregateFuncType, testFile.TestCases[0].FuncType)

assert.Equal(t, "sum", testFile.TestCases[1].FuncName)
assert.Contains(t, testFile.TestCases[1].GroupDesc, "basic")
assert.Equal(t, testFile.TestCases[1].BaseURI, "extensions/functions_arithmetic.yaml")
assert.Len(t, testFile.TestCases[1].Args, 0)
assert.Len(t, testFile.TestCases[1].AggregateArgs, 1)
assert.Equal(t, AggregateFuncType, testFile.TestCases[1].FuncType)
assert.Equal(t, "i64", testFile.TestCases[1].AggregateArgs[0].ColumnType.String())
assert.Equal(t, newInt64List(9223372036854775806, 1, 1, 1, 1, 10000000000), testFile.TestCases[1].AggregateArgs[0].Argument.Value)
assert.Equal(t, "ERROR", testFile.TestCases[1].Options["overflow"])
Expand Down Expand Up @@ -399,12 +402,16 @@ func TestParseTestWithBadScalarTests(t *testing.T) {
{"add(123::fp32, 2.5E::fp32) = 123::fp32", 18, "no viable alternative at input '2.5E'"},
{"add(123::fp32, 1.4E+::fp32) = 123::fp32", 18, "no viable alternative at input '1.4E'"},
{"add(123::fp32, 3.E.5::fp32) = 123::fp32", 17, "no viable alternative at input '3.E'"},
{"f1((1, 2, 3, 4)::i64) = 10::fp64", 0, "expected scalar testcase based on test file header, but got aggregate function testcase"},
}
for _, test := range tests {
t.Run(test.testCaseStr, func(t *testing.T) {
_, err := ParseTestCasesFromString(header + test.testCaseStr)
require.Error(t, err)
expectedErrorMsg := fmt.Sprintf("Syntax error at line 5:%d: %s", test.position, test.errorMsg)
expectedErrorMsg := test.errorMsg
if test.position > 0 {
expectedErrorMsg = fmt.Sprintf("Syntax error at line 5:%d: %s", test.position, test.errorMsg)
}
assert.Contains(t, err.Error(), expectedErrorMsg)
})
}
Expand All @@ -425,6 +432,7 @@ corr(t1.col0, t2.col1) = 1::fp64`,
},
{"((20, 20), (-3, -3), (1, 1), (10,10), (5,5)) corr(my_col::fp32, col0::fp32) = 1::fp64", "mismatched input 'my_col'"},
{"((20, 20), (-3, -3), (1, 1), (10,10), (5,5)) corr(col0::fp32, column1::fp32) = 1::fp64", "mismatched input 'column1'"},
{"f8('13:01:01.234'::time) = 123::i32", "expected aggregate testcase based on test file header, but got scalar function testcase"},
}
for _, test := range tests {
t.Run(test.testCaseStr, func(t *testing.T) {
Expand All @@ -443,7 +451,6 @@ func TestParseAggregateTestWithVariousTypes(t *testing.T) {
{"f1((1, 2, 3, 4)::i64) = 10::fp64"},
{"f1((1, 2, 3, 4)::i16) = 10.0::fp32"},
{"f1((1, 2, 3, 4)::i32) = 10::i64"},
{"f2(1.0::fp32, 2.0::fp64) = -7.0::fp32"},
{"f3(('a', 'b')::string) = 'c'::str"},
{"f4((false, true)::boolean) = false::bool"},
{"f5((1.1, 2.2)::fp32) = 3.3::fp32"},
Expand All @@ -454,7 +461,6 @@ func TestParseAggregateTestWithVariousTypes(t *testing.T) {
{"f6((1.1, 2.2, null)::dec?<38,10>) = 3.3::dec<38,10>"},
{"f8(('1991-01-01', '1991-02-02')::date) = '2001-01-01'::date"},
{"f8(('13:01:01.2345678', '14:01:01.333')::time) = 123456::i64"},
{"f8('13:01:01.234'::time) = 123::i32"},
{"f8(('1991-01-01T01:02:03.456', '1991-01-01T00:00:00')::timestamp) = '1991-01-01T22:33:44'::ts"},
{"f8(('1991-01-01T01:02:03.456+05:30', '1991-01-01T00:00:00+15:30')::tstz) = 23::i32"},
{"f10(('P10Y5M', 'P11Y5M')::interval_year) = 'P21Y10M'::interval_year"},
Expand Down
35 changes: 26 additions & 9 deletions testcases/parser/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type TestCaseVisitor struct {
baseparser.FuncTestCaseParserVisitor
ErrorListener util.VisitErrorListener
literalTypeInContext types.Type
testFuncType TestFuncType
}

func (v *TestCaseVisitor) getLiteralTypeInContext() types.Type {
Expand All @@ -41,7 +42,7 @@ func (v *TestCaseVisitor) Visit(tree antlr.ParseTree) interface{} {
}

func (v *TestCaseVisitor) VisitDoc(ctx *baseparser.DocContext) interface{} {
header := v.Visit(ctx.Header()).(TestFileHeader)
header := v.Visit(ctx.Header()).(*TestFileHeader)
testcases := make([]*TestCase, 0, len(ctx.AllTestGroup()))
for _, testGroup := range ctx.AllTestGroup() {
groupTestCases := v.Visit(testGroup).([]*TestCase)
Expand All @@ -57,27 +58,38 @@ func (v *TestCaseVisitor) VisitDoc(ctx *baseparser.DocContext) interface{} {
}

func (v *TestCaseVisitor) VisitHeader(ctx *baseparser.HeaderContext) interface{} {
return TestFileHeader{
Version: ctx.Version().GetText(),
IncludedURI: v.Visit(ctx.Include()).(string),
header := v.Visit(ctx.Version()).(*TestFileHeader)
header.IncludedURI = v.Visit(ctx.Include()).(string)
return header
}

func (v *TestCaseVisitor) VisitVersion(ctx *baseparser.VersionContext) interface{} {
testFuncType := ScalarFuncType
if ctx.SubstraitAggregateTest() != nil {
testFuncType = AggregateFuncType
}
v.testFuncType = testFuncType
return &TestFileHeader{
Version: ctx.FormatVersion().GetText(),
FuncType: testFuncType,
}
}

func (v *TestCaseVisitor) VisitInclude(ctx *baseparser.IncludeContext) interface{} {
return getRawStringFromStringLiteral(ctx.StringLiteral(0).GetText())
}

type TestGroup struct {
Description string
TestCases []*TestCase
}

func (v *TestCaseVisitor) VisitScalarFuncTestGroup(ctx *baseparser.ScalarFuncTestGroupContext) interface{} {
groupDesc := v.Visit(ctx.TestGroupDescription()).(string)
groupTestCases := make([]*TestCase, 0, len(ctx.AllTestCase()))
if v.testFuncType != ScalarFuncType {
v.ErrorListener.ReportVisitError(fmt.Errorf("expected %v testcase based on test file header, but got scalar function testcase", v.testFuncType))
return groupTestCases
}
for _, tc := range ctx.AllTestCase() {
testcase := v.Visit(tc).(*TestCase)
testcase.GroupDesc = groupDesc
testcase.FuncType = ScalarFuncType
groupTestCases = append(groupTestCases, testcase)
}
return groupTestCases
Expand All @@ -86,9 +98,14 @@ func (v *TestCaseVisitor) VisitScalarFuncTestGroup(ctx *baseparser.ScalarFuncTes
func (v *TestCaseVisitor) VisitAggregateFuncTestGroup(ctx *baseparser.AggregateFuncTestGroupContext) interface{} {
groupDesc := v.Visit(ctx.TestGroupDescription()).(string)
groupTestCases := make([]*TestCase, 0, len(ctx.AllAggFuncTestCase()))
if v.testFuncType != AggregateFuncType {
v.ErrorListener.ReportVisitError(fmt.Errorf("expected %v testcase based on test file header, but got aggregate function testcase", v.testFuncType))
return groupTestCases
}
for _, tc := range ctx.AllAggFuncTestCase() {
testcase := v.Visit(tc).(*TestCase)
testcase.GroupDesc = groupDesc
testcase.FuncType = AggregateFuncType
groupTestCases = append(groupTestCases, testcase)
}
return groupTestCases
Expand Down

0 comments on commit 9e1c860

Please sign in to comment.