diff --git a/go.mod b/go.mod index 217a0bd83..d580636c6 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295 github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1 github.com/fatih/color v1.13.0 - github.com/flyteorg/flyteidl v1.5.13 + github.com/flyteorg/flyteidl v1.5.16 github.com/flyteorg/flyteplugins v1.1.30 github.com/flyteorg/flytestdlib v1.0.24 github.com/ghodss/yaml v1.0.0 diff --git a/go.sum b/go.sum index e5f974662..b058b325e 100644 --- a/go.sum +++ b/go.sum @@ -242,8 +242,8 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/flyteorg/flyteidl v1.5.13 h1:IQ2Cw+u36ew3BPyRDAcHdzc/GyNEOXOxhKy9jbS4hbo= -github.com/flyteorg/flyteidl v1.5.13/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og= +github.com/flyteorg/flyteidl v1.5.16 h1:S70wD7K99nKHZxmo8U16Jjhy1kZwoBh5ZQhZf3/6MPU= +github.com/flyteorg/flyteidl v1.5.16/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og= github.com/flyteorg/flyteplugins v1.1.30 h1:AVqS6Eb9Nr9Z3Mb3CtP04ffAVS9LMx5Q1Z7AyFFk/e0= github.com/flyteorg/flyteplugins v1.1.30/go.mod h1:FujFQdL/f9r1HvFR81JCiNYusDy9F0lExhyoyMHXXbg= github.com/flyteorg/flytestdlib v1.0.24 h1:jDvymcjlsTRCwOtxPapro0WZBe3isTz+T3Tiq+mZUuk= diff --git a/pkg/compiler/validators/condition.go b/pkg/compiler/validators/condition.go index a70c5dcb2..e4f4d6753 100644 --- a/pkg/compiler/validators/condition.go +++ b/pkg/compiler/validators/condition.go @@ -15,6 +15,10 @@ func validateOperand(node c.NodeBuilder, paramName string, operand *flyte.Operan } else if operand.GetPrimitive() != nil { // no validation literalType = literalTypeForPrimitive(operand.GetPrimitive()) + } else if operand.GetScalar().GetPrimitive() != nil { + literalType = literalTypeForPrimitive(operand.GetPrimitive()) + } else if operand.GetScalar().GetNoneType() != nil { + literalType = &flyte.LiteralType{Type: &flyte.LiteralType_Simple{Simple: flyte.SimpleType_NONE}} } else if len(operand.GetVar()) > 0 { if node.GetInterface() != nil { if param, paramOk := validateInputVar(node, operand.GetVar(), requireParamType, errs.NewScope()); paramOk { @@ -41,7 +45,10 @@ func ValidateBooleanExpression(w c.WorkflowBuilder, node c.NodeBuilder, expr *fl expr.GetComparison().GetRightValue(), requireParamType, errs.NewScope()) op2Type, op2Valid := validateOperand(node, "LeftValue", expr.GetComparison().GetLeftValue(), requireParamType, errs.NewScope()) - if op1Valid && op2Valid && op1Type != nil && op2Type != nil { + // Valid expression + // 1. Both operands are primitive types and have the same types. + // 2. One of the operands is the None type. + if op1Valid && op2Valid && op1Type != nil && op2Type != nil && op1Type.GetSimple() != flyte.SimpleType_NONE && op2Type.GetSimple() != flyte.SimpleType_NONE { if op1Type.String() != op2Type.String() { errs.Collect(errors.NewMismatchingTypesErr(node.GetId(), "RightValue", op1Type.String(), op2Type.String())) diff --git a/pkg/controller/nodes/branch/comparator.go b/pkg/controller/nodes/branch/comparator.go index 4fc4f2224..5a05b2001 100644 --- a/pkg/controller/nodes/branch/comparator.go +++ b/pkg/controller/nodes/branch/comparator.go @@ -71,9 +71,21 @@ var perTypeComparators = map[string]comparators{ }, } -func Evaluate(lValue *core.Primitive, rValue *core.Primitive, op core.ComparisonExpression_Operator) (bool, error) { - lValueType := reflect.TypeOf(lValue.Value) - rValueType := reflect.TypeOf(rValue.Value) +func Evaluate(lValue *core.Scalar, rValue *core.Scalar, op core.ComparisonExpression_Operator) (bool, error) { + if lValue.GetNoneType() != nil || rValue.GetNoneType() != nil { + lIsNone := lValue.GetNoneType() != nil + rIsNone := rValue.GetNoneType() != nil + switch op { + case core.ComparisonExpression_EQ: + return lIsNone == rIsNone, nil + case core.ComparisonExpression_NEQ: + return lIsNone != rIsNone, nil + default: + return false, errors.Errorf(ErrorCodeMalformedBranch, "Comparison between nil and non-nil values with operator [%v] is not supported. lVal[%v]:rVal[%v]", op, lValue, rValue) + } + } + lValueType := reflect.TypeOf(lValue.GetPrimitive().Value) + rValueType := reflect.TypeOf(rValue.GetPrimitive().Value) if lValueType != rValueType { return false, errors.Errorf(ErrorCodeMalformedBranch, "Comparison between different primitives types. lVal[%v]:rVal[%v]", lValueType, rValueType) } @@ -90,50 +102,50 @@ func Evaluate(lValue *core.Primitive, rValue *core.Primitive, op core.Comparison if isBoolean { return false, errors.Errorf(ErrorCodeMalformedBranch, "[GT] not defined for boolean operands.") } - return comps.gt(lValue, rValue), nil + return comps.gt(lValue.GetPrimitive(), rValue.GetPrimitive()), nil case core.ComparisonExpression_GTE: if isBoolean { return false, errors.Errorf(ErrorCodeMalformedBranch, "[GTE] not defined for boolean operands.") } - return comps.eq(lValue, rValue) || comps.gt(lValue, rValue), nil + return comps.eq(lValue.GetPrimitive(), rValue.GetPrimitive()) || comps.gt(lValue.GetPrimitive(), rValue.GetPrimitive()), nil case core.ComparisonExpression_LT: if isBoolean { return false, errors.Errorf(ErrorCodeMalformedBranch, "[LT] not defined for boolean operands.") } - return !(comps.gt(lValue, rValue) || comps.eq(lValue, rValue)), nil + return !(comps.gt(lValue.GetPrimitive(), rValue.GetPrimitive()) || comps.eq(lValue.GetPrimitive(), rValue.GetPrimitive())), nil case core.ComparisonExpression_LTE: if isBoolean { return false, errors.Errorf(ErrorCodeMalformedBranch, "[LTE] not defined for boolean operands.") } - return !comps.gt(lValue, rValue), nil + return !comps.gt(lValue.GetPrimitive(), rValue.GetPrimitive()), nil case core.ComparisonExpression_EQ: - return comps.eq(lValue, rValue), nil + return comps.eq(lValue.GetPrimitive(), rValue.GetPrimitive()), nil case core.ComparisonExpression_NEQ: - return !comps.eq(lValue, rValue), nil + return !comps.eq(lValue.GetPrimitive(), rValue.GetPrimitive()), nil } return false, errors.Errorf(ErrorCodeMalformedBranch, "Unsupported operator type in Propeller. System error.") } -func Evaluate1(lValue *core.Primitive, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) { - if rValue.GetScalar() == nil || rValue.GetScalar().GetPrimitive() == nil { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive.") +func Evaluate1(lValue *core.Scalar, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) { + if rValue.GetScalar() == nil || (rValue.GetScalar().GetPrimitive() == nil && rValue.GetScalar().GetNoneType() == nil) { + return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable [%v] is non primitive", rValue) } - return Evaluate(lValue, rValue.GetScalar().GetPrimitive(), op) + return Evaluate(lValue, rValue.GetScalar(), op) } -func Evaluate2(lValue *core.Literal, rValue *core.Primitive, op core.ComparisonExpression_Operator) (bool, error) { - if lValue.GetScalar() == nil || lValue.GetScalar().GetPrimitive() == nil { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable is non primitive.") +func Evaluate2(lValue *core.Literal, rValue *core.Scalar, op core.ComparisonExpression_Operator) (bool, error) { + if lValue.GetScalar() == nil || (lValue.GetScalar().GetPrimitive() == nil && lValue.GetScalar().GetNoneType() == nil) { + return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable [%v] is non primitive.", lValue) } - return Evaluate(lValue.GetScalar().GetPrimitive(), rValue, op) + return Evaluate(lValue.GetScalar(), rValue, op) } func EvaluateLiterals(lValue *core.Literal, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) { - if lValue.GetScalar() == nil || lValue.GetScalar().GetPrimitive() == nil { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable is non primitive.") + if lValue.GetScalar() == nil || (lValue.GetScalar().GetPrimitive() == nil && lValue.GetScalar().GetNoneType() == nil) { + return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable [%v] is non primitive.", lValue) } - if rValue.GetScalar() == nil || rValue.GetScalar().GetPrimitive() == nil { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive") + if rValue.GetScalar() == nil || (rValue.GetScalar().GetPrimitive() == nil && rValue.GetScalar().GetNoneType() == nil) { + return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable [%v] is non primitive", rValue) } - return Evaluate(lValue.GetScalar().GetPrimitive(), rValue.GetScalar().GetPrimitive(), op) + return Evaluate(lValue.GetScalar(), rValue.GetScalar(), op) } diff --git a/pkg/controller/nodes/branch/comparator_test.go b/pkg/controller/nodes/branch/comparator_test.go index d1c120ef4..21f60272a 100644 --- a/pkg/controller/nodes/branch/comparator_test.go +++ b/pkg/controller/nodes/branch/comparator_test.go @@ -11,8 +11,8 @@ import ( ) func TestEvaluate_int(t *testing.T) { - p1 := coreutils.MustMakePrimitive(1) - p2 := coreutils.MustMakePrimitive(2) + p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(1)}} + p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(2)}} { // p1 > p2 = false b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) @@ -82,8 +82,8 @@ func TestEvaluate_int(t *testing.T) { } func TestEvaluate_float(t *testing.T) { - p1 := coreutils.MustMakePrimitive(1.0) - p2 := coreutils.MustMakePrimitive(2.0) + p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(1)}} + p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(2)}} { // p1 > p2 = false b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) @@ -153,8 +153,8 @@ func TestEvaluate_float(t *testing.T) { } func TestEvaluate_string(t *testing.T) { - p1 := coreutils.MustMakePrimitive("a") - p2 := coreutils.MustMakePrimitive("b") + p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive("a")}} + p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive("b")}} { // p1 > p2 = false b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) @@ -224,8 +224,8 @@ func TestEvaluate_string(t *testing.T) { } func TestEvaluate_datetime(t *testing.T) { - p1 := coreutils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 00, 00, time.UTC)) - p2 := coreutils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 01, 00, time.UTC)) + p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 00, 00, time.UTC))}} + p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 01, 00, time.UTC))}} { // p1 > p2 = false b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) @@ -295,8 +295,8 @@ func TestEvaluate_datetime(t *testing.T) { } func TestEvaluate_duration(t *testing.T) { - p1 := coreutils.MustMakePrimitive(10 * time.Second) - p2 := coreutils.MustMakePrimitive(11 * time.Second) + p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(10 * time.Second)}} + p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(11 * time.Second)}} { // p1 > p2 = false b, err := Evaluate(p1, p2, core.ComparisonExpression_GT) @@ -366,8 +366,8 @@ func TestEvaluate_duration(t *testing.T) { } func TestEvaluate_boolean(t *testing.T) { - p1 := coreutils.MustMakePrimitive(true) - p2 := coreutils.MustMakePrimitive(false) + p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(true)}} + p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(false)}} f := func(op core.ComparisonExpression_Operator) { // GT/LT = false msg := fmt.Sprintf("Evaluating: [%s]", op.String()) diff --git a/pkg/controller/nodes/branch/evaluator.go b/pkg/controller/nodes/branch/evaluator.go index fe6d7edac..e81a127fb 100644 --- a/pkg/controller/nodes/branch/evaluator.go +++ b/pkg/controller/nodes/branch/evaluator.go @@ -20,31 +20,53 @@ const ErrorCodeFailedFetchOutputs = "FailedFetchOutputs" func EvaluateComparison(expr *core.ComparisonExpression, nodeInputs *core.LiteralMap) (bool, error) { var lValue *core.Literal var rValue *core.Literal - var lPrim *core.Primitive - var rPrim *core.Primitive + var lPrim *core.Scalar + var rPrim *core.Scalar if expr.GetLeftValue().GetPrimitive() == nil { - if nodeInputs == nil { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) + if expr.GetLeftValue().GetScalar().GetNoneType() != nil { + lValue = &core.Literal{Value: &core.Literal_Scalar{Scalar: expr.GetLeftValue().GetScalar()}} + } else if expr.GetLeftValue().GetScalar().GetUnion() != nil { + lValue = expr.GetLeftValue().GetScalar().GetUnion().GetValue() + } else { + if nodeInputs == nil { + return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) + } + input := nodeInputs.Literals[expr.GetLeftValue().GetVar()] + if input.GetScalar().GetUnion().GetValue() != nil { + lValue = input.GetScalar().GetUnion().GetValue() + } else { + lValue = input + } } - lValue = nodeInputs.Literals[expr.GetLeftValue().GetVar()] if lValue == nil { return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) } } else { - lPrim = expr.GetLeftValue().GetPrimitive() + lPrim = &core.Scalar{Value: &core.Scalar_Primitive{Primitive: expr.GetLeftValue().GetPrimitive()}} } if expr.GetRightValue().GetPrimitive() == nil { - if nodeInputs == nil { - return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) + if expr.GetRightValue().GetScalar().GetNoneType() != nil { + rValue = &core.Literal{Value: &core.Literal_Scalar{Scalar: expr.GetRightValue().GetScalar()}} + } else if expr.GetRightValue().GetScalar().GetUnion() != nil { + rValue = expr.GetRightValue().GetScalar().GetUnion().GetValue() + } else { + if nodeInputs == nil { + return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar()) + } + input := nodeInputs.Literals[expr.GetRightValue().GetVar()] + if input.GetScalar().GetUnion().GetValue() != nil { + rValue = input.GetScalar().GetUnion().GetValue() + } else { + rValue = input + } } - rValue = nodeInputs.Literals[expr.GetRightValue().GetVar()] if rValue == nil { return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetRightValue().GetVar()) } } else { - rPrim = expr.GetRightValue().GetPrimitive() + rPrim = &core.Scalar{Value: &core.Scalar_Primitive{Primitive: expr.GetRightValue().GetPrimitive()}} } if lValue != nil && rValue != nil { diff --git a/pkg/controller/nodes/branch/evaluator_test.go b/pkg/controller/nodes/branch/evaluator_test.go index 895b73194..a64031ae2 100644 --- a/pkg/controller/nodes/branch/evaluator_test.go +++ b/pkg/controller/nodes/branch/evaluator_test.go @@ -56,6 +56,16 @@ func createUnaryConjunction(l *core.ComparisonExpression, op core.ConjunctionExp } } +func getNoneOperand() *core.Operand { + return &core.Operand{ + Val: &core.Operand_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_NoneType{NoneType: &core.Void{}}, + }, + }, + } +} + func TestEvaluateComparison(t *testing.T) { t.Run("ComparePrimitives", func(t *testing.T) { // Compare primitives @@ -100,6 +110,80 @@ func TestEvaluateComparison(t *testing.T) { assert.NoError(t, err) assert.False(t, v) }) + t.Run("CompareNoneAndLiteral", func(t *testing.T) { + // Compare lVal -> None and rVal -> literal + exp := &core.ComparisonExpression{ + LeftValue: getNoneOperand(), + Operator: core.ComparisonExpression_EQ, + RightValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: coreutils.MustMakePrimitive(1), + }, + }, + } + v, err := EvaluateComparison(exp, nil) + assert.NoError(t, err) + assert.False(t, v) + }) + t.Run("CompareLiteralAndNone", func(t *testing.T) { + // Compare lVal -> literal and rVal -> None + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Primitive{ + Primitive: coreutils.MustMakePrimitive(1), + }, + }, + Operator: core.ComparisonExpression_NEQ, + RightValue: getNoneOperand(), + } + v, err := EvaluateComparison(exp, nil) + assert.NoError(t, err) + assert.True(t, v) + }) + t.Run("CompareUnionLiteralAndNone", func(t *testing.T) { + // Compare lVal -> literal and rVal -> None + exp := &core.ComparisonExpression{ + LeftValue: &core.Operand{ + Val: &core.Operand_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Union{ + Union: &core.Union{ + Value: &core.Literal{ + Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(1)}}}, + }, + }, + }, + }, + }, + }, + Operator: core.ComparisonExpression_NEQ, + RightValue: getNoneOperand(), + } + v, err := EvaluateComparison(exp, nil) + assert.NoError(t, err) + assert.True(t, v) + }) + t.Run("CompareNoneAndNone", func(t *testing.T) { + // Compare lVal -> None and rVal -> None + exp := &core.ComparisonExpression{ + LeftValue: getNoneOperand(), + Operator: core.ComparisonExpression_EQ, + RightValue: getNoneOperand(), + } + v, err := EvaluateComparison(exp, nil) + assert.NoError(t, err) + assert.True(t, v) + }) + t.Run("CompareNoneAndNoneWithError", func(t *testing.T) { + // Compare lVal -> None and rVal -> None + exp := &core.ComparisonExpression{ + LeftValue: getNoneOperand(), + Operator: core.ComparisonExpression_GTE, + RightValue: getNoneOperand(), + } + _, err := EvaluateComparison(exp, nil) + assert.Error(t, err) + }) t.Run("CompareLiteralAndPrimitive", func(t *testing.T) { // Compare lVal -> literal and rVal -> primitive