Skip to content

Commit

Permalink
[Nonlinear] fix splatting with a univariate operator (#2221)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Jun 22, 2023
1 parent 0214039 commit 95036fb
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/Nonlinear/parse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ function _parse_expression(stack, data, expr, x, parent_index)
if length(x.args) == 2 && !isexpr(x.args[2], :...)
_parse_univariate_expression(stack, data, expr, x, parent_index)
else
# The call is either n-ary, or it is a splat, in which case we
# cannot tell just yet whether the expression is unary or nary.
# Punt to multivariate and try to recover later.
_parse_multivariate_expression(stack, data, expr, x, parent_index)
end
elseif isexpr(x, :comparison)
Expand Down Expand Up @@ -177,8 +180,15 @@ function _parse_multivariate_expression(
@assert isexpr(x, :call)
id = get(data.operators.multivariate_operator_to_id, x.args[1], nothing)
if id === nothing
@assert x.args[1] in data.operators.comparison_operators
_parse_inequality_expression(stack, data, expr, x, parent_index)
if haskey(data.operators.univariate_operator_to_id, x.args[1])
# It may also be a unary variate operator with splatting.
_parse_univariate_expression(stack, data, expr, x, parent_index)
elseif x.args[1] in data.operators.comparison_operators
# Or it may be a binary (in)equality operator.
_parse_inequality_expression(stack, data, expr, x, parent_index)
else
throw(MOI.UnsupportedNonlinearOperator(x.args[1]))
end
return
end
push!(expr.nodes, Node(NODE_CALL_MULTIVARIATE, id, parent_index))
Expand Down
19 changes: 19 additions & 0 deletions test/Nonlinear/Nonlinear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,25 @@ function test_ListOfSupportedNonlinearOperators()
return
end

function test_parse_univariate_splatting()
model = MOI.Nonlinear.Model()
MOI.Nonlinear.register_operator(model, :f, 1, x -> 2x)
x = [MOI.VariableIndex(1)]
@test MOI.Nonlinear.parse_expression(model, :(f($x...))) ==
MOI.Nonlinear.parse_expression(model, :(f($(x[1]))))
return
end

function test_parse_unsupported_operator()
model = MOI.Nonlinear.Model()
x = [MOI.VariableIndex(1)]
@test_throws(
MOI.UnsupportedNonlinearOperator(:f),
MOI.Nonlinear.parse_expression(model, :(f($x...))),
)
return
end

end

TestNonlinear.runtests()
Expand Down

0 comments on commit 95036fb

Please sign in to comment.