diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 5200b069..f12f7736 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -155,31 +155,52 @@ end # we just want to record the number of arguments. DynamicDiff.D(f::ArgumentRecorder, ::Integer) = f +function check_combiner_applicability( + @nospecialize(combiner), + @nospecialize(dummy_expressions), + @nospecialize(dummy_valid_vectors), + @nospecialize(dummy_params) +) + base_error_msg = ( + "Your template structure's `combine` function must accept", + "\t1. A `NamedTuple` of `ComposableExpression`s (or `ArgumentRecorder`s)", + "\t2. A tuple of `ValidVector`s", + ) + + if dummy_params === nothing + if !applicable(combiner, dummy_expressions, dummy_valid_vectors) + throw(ArgumentError(join(base_error_msg, '\n'))) + end + else + if !applicable(combiner, dummy_expressions, dummy_valid_vectors, dummy_params) + throw(ArgumentError(join((base_error_msg..., "\t3. A `ParamVector`"), '\n'))) + end + end + return nothing +end + """Infers number of features used by each subexpression, by passing in test data.""" -function infer_variable_constraints(::Val{K}, num_parameters, combiner::F) where {K,F} +function infer_variable_constraints( + ::Val{K}, num_parameters, @nospecialize(combiner) +) where {K} variable_constraints = NamedTuple{K}(map(_ -> Ref(-1), K)) - # Now, we need to evaluate the `combine` function to see how many - # features are used for each function call. If unset, we record it. - # If set, we validate. inner = Fix{1}(_record_composable_expression!, variable_constraints) + dummy_expressions = NamedTuple{K}(map(k -> ArgumentRecorder(Fix{1}(inner, Val(k))), K)) + dummy_valid_vectors = Base.Iterators.repeated(ValidVector(ones(Float64, 1), true)) + dummy_params = + num_parameters === nothing ? nothing : ParamVector(ones(Float64, num_parameters)) - # This is like the (; f, g) in the structure function - _dummy_expressions = NamedTuple{K}(map(k -> ArgumentRecorder(Fix{1}(inner, Val(k))), K)) - - # This part is like the (x1, x2, x3) in the structure function - _dummy_valid_vectors = Base.Iterators.repeated(ValidVector(ones(Float64, 1), true)) + check_combiner_applicability( + combiner, dummy_expressions, dummy_valid_vectors, dummy_params + ) - # This part is like the params in the structure function - _extra_args = if num_parameters === nothing - () + # Actually call the combiner + if dummy_params === nothing + combiner(dummy_expressions, dummy_valid_vectors) else - (ParamVector(ones(Float64, num_parameters)),) + combiner(dummy_expressions, dummy_valid_vectors, dummy_params) end - # Now, we actually call the structure function - combiner(_dummy_expressions, _dummy_valid_vectors, _extra_args...) - # TODO: Add a helpful error message for the user if they forget to set `num_parameters` - inferred = NamedTuple{K}(map(x -> x[], values(variable_constraints))) if any(==(-1), values(inferred)) failed_keys = filter(k -> inferred[k] == -1, K)