Skip to content

Commit

Permalink
refactor: rename ExpressionType
Browse files Browse the repository at this point in the history
  • Loading branch information
mimizh2418 committed Nov 30, 2024
1 parent 8229359 commit df6670e
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 48 deletions.
12 changes: 6 additions & 6 deletions include/suboptimal/autodiff/Expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <memory>
#include <vector>

#include "suboptimal/autodiff/ExpressionType.h"
#include "suboptimal/autodiff/Linearity.h"

namespace suboptimal {
struct Expression;
Expand Down Expand Up @@ -37,7 +37,7 @@ struct Expression {
AdjointExprFunc lhs_adjoint_expr_func = nullptr; // Function giving the adjoint expression of the LHS expression
AdjointExprFunc rhs_adjoint_expr_func = nullptr; // Function giving the adjoint expression of the RHS expression

ExpressionType type = ExpressionType::Constant;
Linearity linearity = Linearity::Constant;

int wrt_index = -1;

Expand All @@ -48,18 +48,18 @@ struct Expression {
/**
* Constructs a nullary expression
*/
explicit Expression(double value, ExpressionType type = ExpressionType::Constant);
explicit Expression(double value, Linearity linearity = Linearity::Constant);

/**
* Constructs a unary expression
*/
Expression(ExpressionType type, ValueFunc value_func, AdjointValueFunc adjoint_value_func,
Expression(Linearity linearity, ValueFunc value_func, AdjointValueFunc adjoint_value_func,
AdjointExprFunc adjoint_expr_func, ExpressionPtr arg);

/**
* Constructs a binary expression
*/
Expression(ExpressionType type, ValueFunc valueFunc, AdjointValueFunc lhs_adjoint_value_func,
Expression(Linearity linearity, ValueFunc valueFunc, AdjointValueFunc lhs_adjoint_value_func,
AdjointValueFunc rhs_adjoint_value_func, AdjointExprFunc lhs_adjoint_expr_func,
AdjointExprFunc rhs_adjoint_expr_func, ExpressionPtr lhs, ExpressionPtr rhs);

Expand All @@ -71,7 +71,7 @@ struct Expression {
/**
* Checks if the value of this expression is constant
*/
bool isConstant() const { return isIndependent() && type == ExpressionType::Constant; }
bool isConstant() const { return isIndependent() && linearity == Linearity::Constant; }

/**
* Checks if the expression represents a unary operation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <string>

namespace suboptimal {
enum class ExpressionType : int {
enum class Linearity : int {
// The expression is a constant
Constant,
// The expression contains linear and constant terms
Expand All @@ -16,9 +16,9 @@ enum class ExpressionType : int {
Nonlinear
};

constexpr std::string toString(const ExpressionType& exprType) {
using enum ExpressionType;
switch (exprType) {
constexpr std::string toString(const Linearity& linearity) {
using enum Linearity;
switch (linearity) {
case Constant:
return "constant";
case Linear:
Expand Down
5 changes: 3 additions & 2 deletions include/suboptimal/autodiff/Variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <Eigen/SparseCore>

#include "suboptimal/autodiff/Expression.h"
#include "suboptimal/autodiff/Linearity.h"

namespace suboptimal {
struct Variable;
Expand All @@ -31,7 +32,7 @@ using Matrix4v = Eigen::Matrix4<Variable>;
* An autodiff variable. Essentially just a nicer wrapper around Expression
*/
struct Variable {
ExpressionPtr expr = std::make_shared<Expression>(0.0, ExpressionType::Linear);
ExpressionPtr expr = std::make_shared<Expression>(0.0, Linearity::Linear);

/**
* Constructs an independent variable with initial value 0
Expand Down Expand Up @@ -79,7 +80,7 @@ struct Variable {
* Gets the degree of the expression this variable represents
* @return the type of the expression
*/
ExpressionType getType() const { return expr->type; }
Linearity getLinearity() const { return expr->linearity; }

/**
* Checks if the expression this variable represents is independent of other expressions
Expand Down
66 changes: 33 additions & 33 deletions src/autodiff/Expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@
#include <vector>

namespace suboptimal {
Expression::Expression(const double value, const ExpressionType type) : value{value}, type{type} {}
Expression::Expression(const double value, const Linearity linearity) : value{value}, linearity{linearity} {}

Expression::Expression(const ExpressionType type, const ValueFunc value_func, const AdjointValueFunc adjoint_value_func,
Expression::Expression(const Linearity linearity, const ValueFunc value_func, const AdjointValueFunc adjoint_value_func,
const AdjointExprFunc adjoint_expr_func, const ExpressionPtr arg) // NOLINT
: value{value_func(arg->value, 0.0)},
lhs{arg},
value_func{value_func},
lhs_adjoint_value{adjoint_value_func},
lhs_adjoint_expr_func{adjoint_expr_func},
type{type} {}
linearity{linearity} {}

Expression::Expression(const ExpressionType type, const ValueFunc valueFunc,
Expression::Expression(const Linearity linearity, const ValueFunc valueFunc,
const AdjointValueFunc lhs_adjoint_value_func, const AdjointValueFunc rhs_adjoint_value_func,
const AdjointExprFunc lhs_adjoint_expr_func, const AdjointExprFunc rhs_adjoint_expr_func,
const ExpressionPtr lhs, // NOLINT
Expand All @@ -33,7 +33,7 @@ Expression::Expression(const ExpressionType type, const ValueFunc valueFunc,
rhs_adjoint_value_func{rhs_adjoint_value_func},
lhs_adjoint_expr_func{lhs_adjoint_expr_func},
rhs_adjoint_expr_func{rhs_adjoint_expr_func},
type{type} {}
linearity{linearity} {}

void Expression::updateChildren() {
children.clear();
Expand Down Expand Up @@ -109,7 +109,7 @@ ExpressionPtr operator-(const ExpressionPtr& x) {
}

return std::make_shared<Expression>(
x->type, [](const double val, double) { return -val; },
x->linearity, [](const double val, double) { return -val; },
[](double, double, const double parent_adjoint) { return -parent_adjoint; },
[](const ExpressionPtr&, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) { return -parent_adjoint; },
x);
Expand All @@ -127,7 +127,7 @@ ExpressionPtr operator+(const ExpressionPtr& lhs, const ExpressionPtr& rhs) {
}

return std::make_shared<Expression>(
std::max(lhs->type, rhs->type), [](const double lhs_val, const double rhs_val) { return lhs_val + rhs_val; },
std::max(lhs->linearity, rhs->linearity), [](const double lhs_val, const double rhs_val) { return lhs_val + rhs_val; },
[](double, double, const double parent_adjoint) { return parent_adjoint; },
[](double, double, const double parent_adjoint) { return parent_adjoint; },
[](const ExpressionPtr&, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) { return parent_adjoint; },
Expand All @@ -147,7 +147,7 @@ ExpressionPtr operator-(const ExpressionPtr& lhs, const ExpressionPtr& rhs) {
}

return std::make_shared<Expression>(
std::max(lhs->type, rhs->type), [](const double lhs_val, const double rhs_val) { return lhs_val - rhs_val; },
std::max(lhs->linearity, rhs->linearity), [](const double lhs_val, const double rhs_val) { return lhs_val - rhs_val; },
[](double, double, const double parent_adjoint) { return parent_adjoint; },
[](double, double, const double parent_adjoint) { return -parent_adjoint; },
[](const ExpressionPtr&, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) { return parent_adjoint; },
Expand All @@ -169,19 +169,19 @@ ExpressionPtr operator*(const ExpressionPtr& lhs, const ExpressionPtr& rhs) {
return std::make_shared<Expression>(lhs->value * rhs->value);
}

ExpressionType type;
Linearity linearity;
if (lhs->isConstant()) {
type = rhs->type;
linearity = rhs->linearity;
} else if (rhs->isConstant()) {
type = lhs->type;
} else if (lhs->type == ExpressionType::Linear && rhs->type == ExpressionType::Linear) {
type = ExpressionType::Quadratic;
linearity = lhs->linearity;
} else if (lhs->linearity == Linearity::Linear && rhs->linearity == Linearity::Linear) {
linearity = Linearity::Quadratic;
} else {
type = ExpressionType::Nonlinear;
linearity = Linearity::Nonlinear;
}

return std::make_shared<Expression>(
type, [](const double lhs_val, const double rhs_val) { return lhs_val * rhs_val; },
linearity, [](const double lhs_val, const double rhs_val) { return lhs_val * rhs_val; },
[](double, const double rhs_val, const double parent_adjoint) { return rhs_val * parent_adjoint; },
[](const double lhs_val, double, const double parent_adjoint) { return lhs_val * parent_adjoint; },
[](const ExpressionPtr&, const ExpressionPtr& rhs_expr, const ExpressionPtr& parent_adjoint) {
Expand Down Expand Up @@ -209,7 +209,7 @@ ExpressionPtr operator/(const ExpressionPtr& lhs, const ExpressionPtr& rhs) {
}

return std::make_shared<Expression>(
rhs->isConstant() ? lhs->type : ExpressionType::Nonlinear,
rhs->isConstant() ? lhs->linearity : Linearity::Nonlinear,
[](const double lhs_val, const double rhs_val) { return lhs_val / rhs_val; },
[](double, const double rhs_val, const double parent_adjoint) { return parent_adjoint / rhs_val; },
[](const double lhs_val, const double rhs_val, const double parent_adjoint) {
Expand All @@ -233,7 +233,7 @@ ExpressionPtr abs(const ExpressionPtr& x) {
}

return std::make_shared<Expression>(
ExpressionType::Nonlinear, [](const double val, double) { return std::abs(val); },
Linearity::Nonlinear, [](const double val, double) { return std::abs(val); },
[](const double val, double, const double parent_adjoint) { return parent_adjoint * val / std::abs(val); },
[](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) {
return parent_adjoint * expr / suboptimal::abs(expr);
Expand All @@ -250,7 +250,7 @@ ExpressionPtr sqrt(const ExpressionPtr& x) {
}

return std::make_shared<Expression>(
ExpressionType::Nonlinear, [](const double val, double) { return std::sqrt(val); },
Linearity::Nonlinear, [](const double val, double) { return std::sqrt(val); },
[](const double val, double, const double parent_adjoint) { return parent_adjoint * 0.5 / std::sqrt(val); },
[](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) {
return parent_adjoint * std::make_shared<Expression>(0.5) / suboptimal::sqrt(expr);
Expand All @@ -267,7 +267,7 @@ ExpressionPtr exp(const ExpressionPtr& x) {
}

return std::make_shared<Expression>(
ExpressionType::Nonlinear, [](const double val, double) { return std::exp(val); },
Linearity::Nonlinear, [](const double val, double) { return std::exp(val); },
[](const double val, double, const double parent_adjoint) { return parent_adjoint * std::exp(val); },
[](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) {
return parent_adjoint * suboptimal::exp(expr);
Expand All @@ -284,7 +284,7 @@ ExpressionPtr log(const ExpressionPtr& x) {
}

return std::make_shared<Expression>(
ExpressionType::Nonlinear, [](const double val, double) { return std::log(val); },
Linearity::Nonlinear, [](const double val, double) { return std::log(val); },
[](const double val, double, const double parent_adjoint) { return parent_adjoint / val; },
[](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) {
return parent_adjoint / expr;
Expand All @@ -306,15 +306,15 @@ ExpressionPtr pow(const ExpressionPtr& base, const ExpressionPtr& exponent) {
return std::make_shared<Expression>(std::pow(base->value, exponent->value));
}

ExpressionType type;
if (base->type == ExpressionType::Linear && exponent->constEquals(2.0)) {
type = ExpressionType::Quadratic;
Linearity linearity;
if (base->linearity == Linearity::Linear && exponent->constEquals(2.0)) {
linearity = Linearity::Quadratic;
} else {
type = ExpressionType::Nonlinear;
linearity = Linearity::Nonlinear;
}

return std::make_shared<Expression>(
type, [](const double base_val, const double exp_val) { return std::pow(base_val, exp_val); },
linearity, [](const double base_val, const double exp_val) { return std::pow(base_val, exp_val); },
[](const double base_val, const double exp_val, const double parent_adjoint) {
return parent_adjoint * exp_val * std::pow(base_val, exp_val - 1);
},
Expand Down Expand Up @@ -342,7 +342,7 @@ ExpressionPtr hypot(const ExpressionPtr& x, const ExpressionPtr& y) {
}

return std::make_shared<Expression>(
ExpressionType::Nonlinear, [](const double x_val, const double y_val) { return std::hypot(x_val, y_val); },
Linearity::Nonlinear, [](const double x_val, const double y_val) { return std::hypot(x_val, y_val); },
[](const double x_val, const double y_val, const double parent_adjoint) {
return parent_adjoint * x_val / std::hypot(x_val, y_val);
},
Expand All @@ -367,7 +367,7 @@ ExpressionPtr sin(const ExpressionPtr& x) {
}

return std::make_shared<Expression>(
ExpressionType::Nonlinear, [](const double val, double) { return std::sin(val); },
Linearity::Nonlinear, [](const double val, double) { return std::sin(val); },
[](const double val, double, const double parent_adjoint) { return parent_adjoint * std::cos(val); },
[](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) {
return parent_adjoint * suboptimal::cos(expr);
Expand All @@ -381,7 +381,7 @@ ExpressionPtr cos(const ExpressionPtr& x) {
}

return std::make_shared<Expression>(
ExpressionType::Nonlinear, [](const double val, double) { return std::cos(val); },
Linearity::Nonlinear, [](const double val, double) { return std::cos(val); },
[](const double val, double, const double parent_adjoint) { return -parent_adjoint * std::sin(val); },
[](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) {
return -parent_adjoint * suboptimal::sin(expr);
Expand All @@ -398,7 +398,7 @@ ExpressionPtr tan(const ExpressionPtr& x) {
}

return std::make_shared<Expression>(
ExpressionType::Nonlinear, [](const double val, double) { return std::tan(val); },
Linearity::Nonlinear, [](const double val, double) { return std::tan(val); },
[](const double val, double, const double parent_adjoint) {
return parent_adjoint / (std::cos(val) * std::cos(val));
},
Expand All @@ -417,7 +417,7 @@ ExpressionPtr asin(const ExpressionPtr& x) {
}

return std::make_shared<Expression>(
ExpressionType::Nonlinear, [](const double val, double) { return std::asin(val); },
Linearity::Nonlinear, [](const double val, double) { return std::asin(val); },
[](const double val, double, const double parent_adjoint) { return parent_adjoint / std::sqrt(1 - val * val); },
[](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) {
return parent_adjoint / suboptimal::sqrt(std::make_shared<Expression>(1.0) - expr * expr);
Expand All @@ -431,7 +431,7 @@ ExpressionPtr acos(const ExpressionPtr& x) {
}

return std::make_shared<Expression>(
ExpressionType::Nonlinear, [](const double val, double) { return std::acos(val); },
Linearity::Nonlinear, [](const double val, double) { return std::acos(val); },
[](const double val, double, const double parent_adjoint) { return -parent_adjoint / std::sqrt(1 - val * val); },
[](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) {
return -parent_adjoint / suboptimal::sqrt(std::make_shared<Expression>(1.0) - expr * expr);
Expand All @@ -448,7 +448,7 @@ ExpressionPtr atan(const ExpressionPtr& x) {
}

return std::make_shared<Expression>(
ExpressionType::Nonlinear, [](const double val, double) { return std::atan(val); },
Linearity::Nonlinear, [](const double val, double) { return std::atan(val); },
[](const double val, double, const double parent_adjoint) { return parent_adjoint / (1 + val * val); },
[](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) {
return parent_adjoint / (std::make_shared<Expression>(1.0) + expr * expr);
Expand All @@ -465,7 +465,7 @@ ExpressionPtr atan2(const ExpressionPtr& y, const ExpressionPtr& x) {
}

return std::make_shared<Expression>(
ExpressionType::Nonlinear, [](const double y_val, const double x_val) { return std::atan2(y_val, x_val); },
Linearity::Nonlinear, [](const double y_val, const double x_val) { return std::atan2(y_val, x_val); },
[](const double y_val, const double x_val, const double parent_adjoint) {
return parent_adjoint * x_val / (y_val * y_val + x_val * x_val);
},
Expand Down
2 changes: 1 addition & 1 deletion src/autodiff/Variable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <utility>

namespace suboptimal {
Variable::Variable(double value) : expr{std::make_shared<Expression>(value, ExpressionType::Linear)} {}
Variable::Variable(double value) : expr{std::make_shared<Expression>(value, Linearity::Linear)} {}

Variable::Variable(const ExpressionPtr& expr) : expr{expr} {}

Expand Down
4 changes: 2 additions & 2 deletions test/autodiff/Variable_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ using namespace suboptimal;
TEST_CASE("Autodiff - Variable constructor", "[autodiff]") {
const Variable x{1.0};
CHECK(x.getValue() == 1.0);
CHECK(x.getType() == ExpressionType::Linear);
CHECK(x.getLinearity() == Linearity::Linear);

const Variable y{};
CHECK(y.getValue() == 0.0);
CHECK(y.getType() == ExpressionType::Linear);
CHECK(y.getLinearity() == Linearity::Linear);
}

TEST_CASE("Autodiff - Variable basic arithmetic", "[autodiff]") {
Expand Down

0 comments on commit df6670e

Please sign in to comment.