Skip to content

Commit

Permalink
refactor: reduce compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 26, 2024
1 parent 7a446b1 commit 26d9b01
Showing 1 changed file with 38 additions and 17 deletions.
55 changes: 38 additions & 17 deletions src/TemplateExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,31 +155,52 @@ end
# we just want to record the number of arguments.
DynamicDiff.D(f::ArgumentRecorder, ::Integer) = f

function check_combiner_applicability(
@nospecialize(combiner),
@nospecialize(dummy_expressions),
@nospecialize(dummy_valid_vectors),
@nospecialize(dummy_params)
)
base_error_msg = (
"Your template structure's `combine` function must accept",
"\t1. A `NamedTuple` of `ComposableExpression`s (or `ArgumentRecorder`s)",
"\t2. A tuple of `ValidVector`s",
)

if dummy_params === nothing
if !applicable(combiner, dummy_expressions, dummy_valid_vectors)
throw(ArgumentError(join(base_error_msg, '\n')))
end
else
if !applicable(combiner, dummy_expressions, dummy_valid_vectors, dummy_params)
throw(ArgumentError(join((base_error_msg..., "\t3. A `ParamVector`"), '\n')))
end
end
return nothing
end

"""Infers number of features used by each subexpression, by passing in test data."""
function infer_variable_constraints(::Val{K}, num_parameters, combiner::F) where {K,F}
function infer_variable_constraints(
::Val{K}, num_parameters, @nospecialize(combiner)
) where {K}
variable_constraints = NamedTuple{K}(map(_ -> Ref(-1), K))
# Now, we need to evaluate the `combine` function to see how many
# features are used for each function call. If unset, we record it.
# If set, we validate.
inner = Fix{1}(_record_composable_expression!, variable_constraints)
dummy_expressions = NamedTuple{K}(map(k -> ArgumentRecorder(Fix{1}(inner, Val(k))), K))
dummy_valid_vectors = Base.Iterators.repeated(ValidVector(ones(Float64, 1), true))
dummy_params =
num_parameters === nothing ? nothing : ParamVector(ones(Float64, num_parameters))

# This is like the (; f, g) in the structure function
_dummy_expressions = NamedTuple{K}(map(k -> ArgumentRecorder(Fix{1}(inner, Val(k))), K))

# This part is like the (x1, x2, x3) in the structure function
_dummy_valid_vectors = Base.Iterators.repeated(ValidVector(ones(Float64, 1), true))
check_combiner_applicability(
combiner, dummy_expressions, dummy_valid_vectors, dummy_params
)

# This part is like the params in the structure function
_extra_args = if num_parameters === nothing
()
# Actually call the combiner
if dummy_params === nothing
combiner(dummy_expressions, dummy_valid_vectors)
else
(ParamVector(ones(Float64, num_parameters)),)
combiner(dummy_expressions, dummy_valid_vectors, dummy_params)
end

# Now, we actually call the structure function
combiner(_dummy_expressions, _dummy_valid_vectors, _extra_args...)
# TODO: Add a helpful error message for the user if they forget to set `num_parameters`

inferred = NamedTuple{K}(map(x -> x[], values(variable_constraints)))
if any(==(-1), values(inferred))
failed_keys = filter(k -> inferred[k] == -1, K)
Expand Down

0 comments on commit 26d9b01

Please sign in to comment.