From fdefd9a0109c08bc6ea83b9fe075e4a675548ba8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 10 Dec 2024 20:49:31 +0000 Subject: [PATCH 1/2] refactor: reduce allocations from type assertion --- src/InterfaceDynamicExpressions.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/InterfaceDynamicExpressions.jl b/src/InterfaceDynamicExpressions.jl index 28a284f72..b30957064 100644 --- a/src/InterfaceDynamicExpressions.jl +++ b/src/InterfaceDynamicExpressions.jl @@ -79,9 +79,8 @@ which speed up evaluation significantly. function expected_array_type(X::AbstractArray, ::Type) return typeof(similar(X, axes(X, 2))) end -function expected_array_type(X::AbstractArray, ::Type, ::Val{:eval_grad_tree_array}) - return typeof(X) -end +expected_array_type(X::AbstractArray, ::Type, ::Val{:eval_grad_tree_array}) = typeof(X) +expected_array_type(::Matrix{T}, ::Type) where {T} = Vector{T} """ eval_diff_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::AbstractOptions, direction::Int) From 652ea0b7330e1745e523fa8462ebabcdba1b8f24 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Tue, 10 Dec 2024 20:49:54 -0800 Subject: [PATCH 2/2] fix: ambiguity in TemplateExpression --- src/TemplateExpression.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 484a4fdb2..889af9e4e 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -372,7 +372,10 @@ end end end ) -@unstable IDE.expected_array_type(::AbstractMatrix, ::Type{<:TemplateExpression}) = Any +@unstable begin + IDE.expected_array_type(::AbstractArray, ::Type{<:TemplateExpression}) = Any + IDE.expected_array_type(::Matrix{T}, ::Type{<:TemplateExpression}) where {T} = Any +end function DA.violates_dimensional_constraints( @nospecialize(tree::TemplateExpression),