Skip to content

Commit

Permalink
test: try to speed up operators test
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 14, 2024
1 parent db47189 commit 235bde6
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions test/test_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,20 @@ end
binary_operators = [plus, sub, mult, /, ^, greater, logical_or, logical_and, cond]
unary_operators = [square, cube, log, log2, log10, log1p, sqrt, atanh, acosh, neg, relu]
options = Options(; binary_operators, unary_operators)

function test_part(tree, Xpart, options)
y, completed = eval_tree_array(tree, Xpart, options)
completed || return nothing
# We capture any warnings about the LoopVectorization not working
local y_turbo
eval_warnings = @capture_err begin
y_turbo, _ = eval_tree_array(tree, Xpart, options; turbo=true)
end
test_info(@test(y[1] y_turbo[1] && eval_warnings == "")) do
@info T tree X[:, seed] y y_turbo eval_warnings
end
end

for T in (Float32, Float64),
index_bin in 1:length(binary_operators),
index_una in 1:length(unary_operators)
Expand All @@ -154,16 +168,7 @@ end
X = rand(MersenneTwister(0), T, 2, 20)
for seed in 1:20
Xpart = X[:, [seed]]
y, completed = eval_tree_array(tree, Xpart, options)
completed || continue
local y_turbo
# We capture any warnings about the LoopVectorization not working
eval_warnings = @capture_err begin
y_turbo, _ = eval_tree_array(tree, Xpart, options; turbo=true)
end
test_info(@test y[1] y_turbo[1] && eval_warnings == "") do
@info T tree X[:, seed] y y_turbo eval_warnings
end
test_part(tree, Xpart, options)
end
end
end
Expand Down

0 comments on commit 235bde6

Please sign in to comment.