Skip to content

Commit

Permalink
Merge branch 'master' into cond-operator
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer authored Oct 27, 2023
2 parents ec19564 + 141987a commit a633600
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 23 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils"
[compat]
Compat = "^4.2"
DynamicExpressions = "0.13"
DynamicQuantities = "^0.6.2"
DynamicQuantities = "0.7"
JSON3 = "1"
LineSearches = "7"
LossFunctions = "0.10, 0.11"
Expand Down
3 changes: 3 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Gumbo = "708ec375-b3d6-5a57-a7ce-8257bf98657a"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"

[compat]
Documenter = "0.27"
2 changes: 1 addition & 1 deletion src/MLJInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ end
function unwrap_units_single(v::AbstractVector, ::Type{D}) where {D}
allequal(Base.Fix2(dimension_fallback, D).(v)) || error("Inconsistent units in vector.")
dims = dimension_fallback(first(v), D)
v = ustrip(v)
v = ustrip.(v)
return v, dims
end

Expand Down
49 changes: 28 additions & 21 deletions test/test_units.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using SymbolicRegression.DimensionalAnalysisModule:
import DynamicQuantities:
DEFAULT_DIM_BASE_TYPE,
Quantity,
QuantityArray,
SymbolicDimensions,
Dimensions,
@u_str,
Expand Down Expand Up @@ -185,28 +186,34 @@ end
end

@testset "With MLJ" begin
model = SRRegressor(;
binary_operators=[+, *],
unary_operators=[sqrt, cbrt, abs],
early_stop_condition=(loss, complexity) -> (loss < 1e-7 && complexity <= 6),
)
X = (; x1=randn(128) .* u"kg^3", x2=randn(128) .* u"kg^2")
y = (@. cbrt(ustrip(X.x1)) + sqrt(abs(ustrip(X.x2)))) .* u"kg"
mach = MLJ.machine(model, X, y)
MLJ.fit!(mach)
report = MLJ.report(mach)
best_idx = findfirst(report.losses .< 1e-7)
@test report.complexities[best_idx] == 6
@test any(report.equations[best_idx]) do t
t.degree == 1 && t.op == 2 # cbrt
end
@test any(report.equations[best_idx]) do t
t.degree == 1 && t.op == 1 # safe_sqrt
end
for as_quantity_array in (false, true)
model = SRRegressor(;
binary_operators=[+, *],
unary_operators=[sqrt, cbrt, abs],
early_stop_condition=(loss, complexity) -> (loss < 1e-7 && complexity <= 6),
)
X = if as_quantity_array
(; x1=randn(128) .* u"kg^3", x2=QuantityArray(randn(128) .* u"kg^2"))
else
(; x1=randn(128) .* u"kg^3", x2=randn(128) .* u"kg^2")
end
y = (@. cbrt(ustrip(X.x1)) + sqrt(abs(ustrip(X.x2)))) .* u"kg"
mach = MLJ.machine(model, X, y)
MLJ.fit!(mach)
report = MLJ.report(mach)
best_idx = findfirst(report.losses .< 1e-7)
@test report.complexities[best_idx] == 6
@test any(report.equations[best_idx]) do t
t.degree == 1 && t.op == 2 # cbrt
end
@test any(report.equations[best_idx]) do t
t.degree == 1 && t.op == 1 # safe_sqrt
end

# Prediction should have same units:
ypred = MLJ.predict(mach; rows=1:3)
@test dimension(ypred[begin]) == dimension(y[begin])
# Prediction should have same units:
ypred = MLJ.predict(mach; rows=1:3)
@test dimension(ypred[begin]) == dimension(y[begin])
end

# Multiple outputs:
model = MultitargetSRRegressor(;
Expand Down

0 comments on commit a633600

Please sign in to comment.