diff --git a/include/suboptimal/autodiff/Variable.h b/include/suboptimal/autodiff/Variable.h index 9b9369c..be17817 100644 --- a/include/suboptimal/autodiff/Variable.h +++ b/include/suboptimal/autodiff/Variable.h @@ -66,6 +66,11 @@ struct Variable { */ static Variable Constant(double value); + /** + * Updates the stored expression graph of the variable + */ + void updateGraph() const; + /** * Updates the value of the variable, traversing the expression tree and updating all expressions and variables this * variable depends on diff --git a/src/autodiff/Expression.cpp b/src/autodiff/Expression.cpp index 7cedfe2..ef969cc 100644 --- a/src/autodiff/Expression.cpp +++ b/src/autodiff/Expression.cpp @@ -85,7 +85,10 @@ void Expression::updateValue() { // Expression represents either a constant or an independent variable, so no update is needed return; } - updateChildren(); + + if (children.empty() && (lhs != nullptr || rhs != nullptr)) { + updateChildren(); + } for (const auto expr : std::ranges::reverse_view(children)) { if (expr->isUnary()) { diff --git a/src/autodiff/Variable.cpp b/src/autodiff/Variable.cpp index 95116c0..2210cd8 100644 --- a/src/autodiff/Variable.cpp +++ b/src/autodiff/Variable.cpp @@ -16,6 +16,10 @@ Variable Variable::Constant(double value) { return std::make_shared(value); } +void Variable::updateGraph() const { + expr->updateChildren(); +} + void Variable::updateValue() const { expr->updateValue(); } diff --git a/src/autodiff/derivatives.cpp b/src/autodiff/derivatives.cpp index c64b353..bbe00a3 100644 --- a/src/autodiff/derivatives.cpp +++ b/src/autodiff/derivatives.cpp @@ -10,6 +10,8 @@ namespace suboptimal { Gradient::Gradient(const Variable& var, const Eigen::Ref& wrt) : var{var}, wrt{wrt}, value(wrt.size()) { + var.updateGraph(); + std::ranges::for_each(wrt, [](const Variable& v) { v.expr->adjoint = 0.0; }); var.expr->updateAdjoints(); for (int i = 0; i < wrt.size(); i++) { @@ -39,8 +41,6 @@ VectorXv Gradient::getExpr() { } } - var.expr->updateChildren(); - // Initialize adjoint expressions std::ranges::for_each(var.expr->children, [](Expression* expr) { expr->adjoint_expr = std::make_shared(0.0); });