Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add literal_pow for composable expression #397

Merged
merged 3 commits into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/ComposableExpression.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module ComposableExpressionModule

using DispatchDoctor: @unstable
using Compat: Fix
using DynamicExpressions:
AbstractExpression,
Expression,
Expand Down Expand Up @@ -185,9 +186,13 @@ function (ex::AbstractComposableExpression)(
end
function (ex::AbstractComposableExpression{T})() where {T}
X = Matrix{T}(undef, 0, 1) # Value is irrelevant as it won't be used
out, _ = eval_tree_array(ex, X) # TODO: The valid is not used; not sure how to incorporate
return only(out)::T
out, complete = eval_tree_array(ex, X) # TODO: The valid is not used; not sure how to incorporate
y = only(out)
return complete ? y::T : nan(y)::T
end
nan(::T) where {T<:AbstractFloat} = convert(T, NaN)
nan(x) = x

function (ex::AbstractComposableExpression)(
x::AbstractComposableExpression, _xs::Vararg{AbstractComposableExpression,N}
) where {N}
Expand Down Expand Up @@ -239,6 +244,9 @@ for op in (
Base.$(op)(x::Number, y::ValidVector) = apply_operator(Base.$(op), x, y)
end
end
function Base.literal_pow(::typeof(^), x::ValidVector, ::Val{p}) where {p}
return apply_operator(Fix{1}(Fix{3}(Base.literal_pow, Val(p)), ^), x)
end

for op in (
:sin, :cos, :tan, :sinh, :cosh, :tanh, :asin, :acos,
Expand Down
39 changes: 39 additions & 0 deletions test/test_composable_expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,45 @@ end
@test expr(X) ≈ [1.0, 2.0] .- sin.([3.0, 4.0] .- [5.0, 6.0]) .+ 2.5
end

@testitem "Test literal_pow with ValidVector" tags = [:part2] begin
using SymbolicRegression: ValidVector

# Test with valid data
x = ValidVector([2.0, 3.0, 4.0], true)

# Test literal_pow with different powers
@test (x^2).x ≈ [4.0, 9.0, 16.0]
@test (x^3).x ≈ [8.0, 27.0, 64.0]

# And explicitly
@test Base.literal_pow(^, x, Val(2)).x ≈ [4.0, 9.0, 16.0]
@test Base.literal_pow(^, x, Val(3)).x ≈ [8.0, 27.0, 64.0]

# Test with invalid data
invalid_x = ValidVector([2.0, 3.0, 4.0], false)
@test (invalid_x^2).valid == false
@test Base.literal_pow(^, invalid_x, Val(2)).valid == false
end

@testitem "Test nan behavior with argument-less expressions" tags = [:part2] begin
using SymbolicRegression
using DynamicExpressions: OperatorEnum, Node

operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos))
variable_names = ["x1", "x2"]

# Test with floating point
c1 = ComposableExpression(Node{Float64}(; val=3.0); operators, variable_names)
invalid_const = (c1 / c1 - 1) / (c1 / c1 - 1) # Creates 0/0
@test isnan(invalid_const())
@test typeof(invalid_const()) === Float64

# Test with integer constant
c2 = ComposableExpression(Node{Int}(; val=0); operators, variable_names)
@test c2() == 0
@test typeof(c2()) === Int
end

@testitem "Test higher-order derivatives of safe_log with DynamicDiff" tags = [:part3] begin
using SymbolicRegression
using SymbolicRegression: D, safe_log, ValidVector
Expand Down
Loading