Skip to content

Commit

Permalink
feat: dynamic autodiff integration
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 12, 2024
1 parent 65111cf commit edcb9ed
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DynamicDiff = "7317a516-7a03-4707-b902-c6dba1468ba0"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
Expand Down Expand Up @@ -48,6 +49,7 @@ Dates = "1"
DifferentiationInterface = "0.5, 0.6"
DispatchDoctor = "^0.4.17"
Distributed = "<0.0.1, 1"
DynamicDiff = "0.2"
DynamicExpressions = "1.6.0"
DynamicQuantities = "1"
Enzyme = "0.12, 0.13"
Expand Down
1 change: 1 addition & 0 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ using DynamicExpressions: with_type_parameters
LogitDistLoss,
QuantileLoss,
LogCoshLoss
using DynamicAutodiff: D
using Compat: @compat, Fix

#! format: off
Expand Down
7 changes: 2 additions & 5 deletions src/TemplateExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module TemplateExpressionModule

using Random: AbstractRNG
using Compat: Fix
using DynamicAutodiff: DynamicAutodiff
using DispatchDoctor: @unstable, @stable
using StyledStrings: @styled_str, annotatedstring
using DynamicExpressions:
Expand Down Expand Up @@ -39,9 +40,6 @@ using ..MutateModule: MutateModule as MM
using ..PopMemberModule: PopMember
using ..ComposableExpressionModule: ComposableExpression, ValidVector

# TODO: Modify `D` once DynamicAutodiff is registered
# import DynamicAutodiff: D

"""
TemplateStructure{K,E,NF} <: Function
Expand Down Expand Up @@ -94,10 +92,9 @@ struct ArgumentRecorder{F} <: Function
end
(f::ArgumentRecorder)(args...) = f.f(args...)

# TODO: Modify `D` once DynamicAutodiff is registered
# We pass through the derivative operators, since
# we just want to record the number of arguments.
# DA.D(f::ArgumentRecorder, _::Integer) = f
DynamicAutodiff.D(f::ArgumentRecorder, ::Integer) = f

"""Infers number of features used by each subexpression, by passing in test data."""
function infer_variable_constraints(::Val{K}, combiner::F) where {K,F}
Expand Down

0 comments on commit edcb9ed

Please sign in to comment.