Skip to content

Commit

Permalink
feat: separate getting derivative values and exprs, cache <=linear de…
Browse files Browse the repository at this point in the history
…rivatives
  • Loading branch information
mimizh2418 committed Nov 30, 2024
1 parent df6670e commit 885f601
Show file tree
Hide file tree
Showing 8 changed files with 322 additions and 99 deletions.
2 changes: 1 addition & 1 deletion examples/derivatives/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
int main() {
suboptimal::Variable x{};
const suboptimal::Variable y = 1.0 / (1.0 + suboptimal::exp(-x)); // Sigmoid function
const suboptimal::Variable dydx = suboptimal::derivative(y, x);
suboptimal::Derivative dydx{y, x};

constexpr double min = -1.0;
constexpr double max = 1.0;
Expand Down
7 changes: 6 additions & 1 deletion include/suboptimal/autodiff/Expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct Expression {

ValueFunc value_func = nullptr; // Function giving the value of the expression

AdjointValueFunc lhs_adjoint_value = nullptr; // Function giving the adjoint value of the LHS expression
AdjointValueFunc lhs_adjoint_value_func = nullptr; // Function giving the adjoint value of the LHS expression
AdjointValueFunc rhs_adjoint_value_func = nullptr; // Function giving the adjoint value of the RHS expression

AdjointExprFunc lhs_adjoint_expr_func = nullptr; // Function giving the adjoint expression of the LHS expression
Expand Down Expand Up @@ -98,6 +98,11 @@ struct Expression {
* depends on
*/
void updateValue();

/**
* Updates the adjoint values of this expression graph
*/
void updateAdjoints();
};

// Arithmetic operator overloads
Expand Down
132 changes: 110 additions & 22 deletions include/suboptimal/autodiff/derivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,130 @@

#pragma once

#include <vector>

#include <Eigen/Core>

#include "suboptimal/autodiff/Variable.h"

namespace suboptimal {
/**
* Computes the gradient of a variable with respect to a set of variables
* @param var the variable to compute the gradient of
* @param wrt the variables to compute the gradient with respect to
* @return a vector of variables representing the gradient of var with respect to wrt
* Class for computing and storing the gradient of a variable. Caches results of linear and constant expressions and
* only recomputes quadratic and nonlinear expressions.
*/
VectorXv gradient(const Variable& var, const Eigen::Ref<const VectorXv>& wrt);
class Gradient {
public:
/**
* Constructs a gradient object
* @param var the variable to compute the gradient of
* @param wrt the vector of variables to compute the gradient with respect to
*/
Gradient(const Variable& var, const Eigen::Ref<const VectorXv>& wrt);

/**
* Gets the value of the gradient based on the current value of wrt
*/
const Eigen::SparseVector<double>& getValue();

/**
* Gets a vector of variables representing the gradient algebraically
*/
VectorXv getExpr();

private:
Variable var;
VectorXv wrt;
Eigen::SparseVector<double> value;

friend class Jacobian;
};

/**
* Computes the derivative of a variable with respect to another variable
* @param var the variable to compute the derivative of
* @param wrt the variable to compute the derivative with respect to
* @return a variable representing the derivative of var with respect to wrt
* Class for computing and storing the derivative of a variable with respect to another variable. Caches results of
* linear and constant expressions and only recomputes quadratic and nonlinear expressions.
*/
Variable derivative(const Variable& var, const Variable& wrt);
class Derivative {
public:
/**
* Constructs a Derivative object
* @param var the variable to compute the derivative of
* @param wrt the variable to compute the derivative with respect to
*/
Derivative(const Variable& var, const Variable& wrt);

/**
* Gets the value of the derivative based on the current value of wrt
*/
double getValue();

/**
* Gets a variable representing the derivative algebraically
*/
Variable getExpr();

private:
Variable var;
Variable wrt;
Gradient gradient;
};

/**
* Computes the Jacobian of a vector of variables with respect to another vector of variables
* @param vars the vector variables to compute the Jacobian of
* @param wrt the vector of variables to compute the Jacobian with respect to
* @return an n x m matrix of variables representing the Jacobian of vars with respect to wrt, where n is the length of
* vars and m is the length of wrt
* Class for computing and storing the Jacobian of a vector of variables. Caches results of linear and constant
* expressions and only recomputes quadratic and nonlinear expressions.
*/
MatrixXv jacobian(const Eigen::Ref<const VectorXv>& vars, const Eigen::Ref<const VectorXv>& wrt);
class Jacobian {
public:
/**
* Constructs a Jacobian object
* @param vars the variable to compute the gradient of
* @param wrt the vector of variables to compute the gradient with respect to
*/
Jacobian(const Eigen::Ref<const VectorXv>& vars, const Eigen::Ref<const VectorXv>& wrt);

/**
* Gets the value of the Jacobian based on the current value of wrt
*/
const Eigen::SparseMatrix<double>& getValue();

/**
* Gets a matrix of variables representing the Jacobian algebraically
*/
MatrixXv getExpr();

private:
VectorXv vars;
VectorXv wrt;
Eigen::SparseMatrix<double> value;

std::vector<Gradient> gradients;
std::vector<int> nonlinear_rows;
std::vector<Eigen::Triplet<double>> cache;
};

/**
* Computes the Hessian of a variable with respect to a set of variables
* @param var the variable to compute the Hessian of
* @param wrt the variables to compute the Hessian with respect to
* @return an n x n matrix of variables representing the Hessian of var with respect to wrt, where n is the length of
* wrt
* Class for computing and storing the Hessian of a variable. Caches results of linear and constant expressions and
* only recomputes quadratic and nonlinear expressions.
*/
MatrixXv hessian(const Variable& var, const Eigen::Ref<const VectorXv>& wrt);
class Hessian {
public:
/**
* Constructs a Hessian object
* @param var the variable to compute the Hessian of
* @param wrt the vector of variables to compute the Hessian with respect to
*/
Hessian(const Variable& var, const Eigen::Ref<const VectorXv>& wrt);

/**
* Gets the value of the Hessian based on the current value of wrt
*/
const Eigen::SparseMatrix<double>& getValue();

/**
* Gets a matrix of variables representing the Hessian algebraically
*/
MatrixXv getExpr();

private:
Jacobian jacobian;
};
} // namespace suboptimal
33 changes: 29 additions & 4 deletions src/autodiff/Expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Expression::Expression(const Linearity linearity, const ValueFunc value_func, co
: value{value_func(arg->value, 0.0)},
lhs{arg},
value_func{value_func},
lhs_adjoint_value{adjoint_value_func},
lhs_adjoint_value_func{adjoint_value_func},
lhs_adjoint_expr_func{adjoint_expr_func},
linearity{linearity} {}

Expand All @@ -29,7 +29,7 @@ Expression::Expression(const Linearity linearity, const ValueFunc valueFunc,
lhs{lhs},
rhs{rhs},
value_func{valueFunc},
lhs_adjoint_value{lhs_adjoint_value_func},
lhs_adjoint_value_func{lhs_adjoint_value_func},
rhs_adjoint_value_func{rhs_adjoint_value_func},
lhs_adjoint_expr_func{lhs_adjoint_expr_func},
rhs_adjoint_expr_func{rhs_adjoint_expr_func},
Expand Down Expand Up @@ -96,6 +96,29 @@ void Expression::updateValue() {
}
}

void Expression::updateAdjoints() {
if (isConstant()) {
// Constants always have an adjoint of 0
adjoint = 0.0;
return;
}
updateValue();

// Initialize adjoints
std::ranges::for_each(children, [](Expression* expr) { expr->adjoint = 0.0; });
adjoint = 1.0;

for (const auto expr : children) {
if (expr->lhs != nullptr && !expr->lhs->isConstant()) {
const double expr_rhs = expr->rhs != nullptr ? expr->rhs->value : 0.0;
expr->lhs->adjoint += expr->lhs_adjoint_value_func(expr->lhs->value, expr_rhs, expr->adjoint);
}
if (expr->rhs != nullptr && !expr->rhs->isConstant()) {
expr->rhs->adjoint += expr->rhs_adjoint_value_func(expr->lhs->value, expr->rhs->value, expr->adjoint);
}
}
}

// operator overloading boilerplate :skull:
// TODO operator null checks

Expand Down Expand Up @@ -127,7 +150,8 @@ ExpressionPtr operator+(const ExpressionPtr& lhs, const ExpressionPtr& rhs) {
}

return std::make_shared<Expression>(
std::max(lhs->linearity, rhs->linearity), [](const double lhs_val, const double rhs_val) { return lhs_val + rhs_val; },
std::max(lhs->linearity, rhs->linearity),
[](const double lhs_val, const double rhs_val) { return lhs_val + rhs_val; },
[](double, double, const double parent_adjoint) { return parent_adjoint; },
[](double, double, const double parent_adjoint) { return parent_adjoint; },
[](const ExpressionPtr&, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) { return parent_adjoint; },
Expand All @@ -147,7 +171,8 @@ ExpressionPtr operator-(const ExpressionPtr& lhs, const ExpressionPtr& rhs) {
}

return std::make_shared<Expression>(
std::max(lhs->linearity, rhs->linearity), [](const double lhs_val, const double rhs_val) { return lhs_val - rhs_val; },
std::max(lhs->linearity, rhs->linearity),
[](const double lhs_val, const double rhs_val) { return lhs_val - rhs_val; },
[](double, double, const double parent_adjoint) { return parent_adjoint; },
[](double, double, const double parent_adjoint) { return -parent_adjoint; },
[](const ExpressionPtr&, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) { return parent_adjoint; },
Expand Down
Loading

0 comments on commit 885f601

Please sign in to comment.