Skip to content

Commit

Permalink
fix: stop constantly updating expression graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
mimizh2418 committed Dec 2, 2024
1 parent 717acce commit fe53416
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 3 deletions.
5 changes: 5 additions & 0 deletions include/suboptimal/autodiff/Variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ struct Variable {
*/
static Variable Constant(double value);

/**
* Updates the stored expression graph of the variable
*/
void updateGraph() const;

/**
* Updates the value of the variable, traversing the expression tree and updating all expressions and variables this
* variable depends on
Expand Down
5 changes: 4 additions & 1 deletion src/autodiff/Expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ void Expression::updateValue() {
// Expression represents either a constant or an independent variable, so no update is needed
return;
}
updateChildren();

if (children.empty() && (lhs != nullptr || rhs != nullptr)) {
updateChildren();
}

for (const auto expr : std::ranges::reverse_view(children)) {
if (expr->isUnary()) {
Expand Down
4 changes: 4 additions & 0 deletions src/autodiff/Variable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ Variable Variable::Constant(double value) {
return std::make_shared<Expression>(value);
}

void Variable::updateGraph() const {
expr->updateChildren();
}

void Variable::updateValue() const {
expr->updateValue();
}
Expand Down
4 changes: 2 additions & 2 deletions src/autodiff/derivatives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

namespace suboptimal {
Gradient::Gradient(const Variable& var, const Eigen::Ref<const VectorXv>& wrt) : var{var}, wrt{wrt}, value(wrt.size()) {
var.updateGraph();

std::ranges::for_each(wrt, [](const Variable& v) { v.expr->adjoint = 0.0; });
var.expr->updateAdjoints();
for (int i = 0; i < wrt.size(); i++) {
Expand Down Expand Up @@ -39,8 +41,6 @@ VectorXv Gradient::getExpr() {
}
}

var.expr->updateChildren();

// Initialize adjoint expressions
std::ranges::for_each(var.expr->children,
[](Expression* expr) { expr->adjoint_expr = std::make_shared<Expression>(0.0); });
Expand Down

0 comments on commit fe53416

Please sign in to comment.