From df6670e85830868f5eee56ee357554090b11cf4b Mon Sep 17 00:00:00 2001 From: Alvin Zhang <41vin2h4n9@gmail.com> Date: Fri, 29 Nov 2024 17:51:18 -0800 Subject: [PATCH] refactor: rename ExpressionType --- include/suboptimal/autodiff/Expression.h | 12 ++-- .../{ExpressionType.h => Linearity.h} | 8 +-- include/suboptimal/autodiff/Variable.h | 5 +- src/autodiff/Expression.cpp | 66 +++++++++---------- src/autodiff/Variable.cpp | 2 +- test/autodiff/Variable_test.cpp | 4 +- 6 files changed, 49 insertions(+), 48 deletions(-) rename include/suboptimal/autodiff/{ExpressionType.h => Linearity.h} (79%) diff --git a/include/suboptimal/autodiff/Expression.h b/include/suboptimal/autodiff/Expression.h index b2c05c2..202229b 100644 --- a/include/suboptimal/autodiff/Expression.h +++ b/include/suboptimal/autodiff/Expression.h @@ -5,7 +5,7 @@ #include #include -#include "suboptimal/autodiff/ExpressionType.h" +#include "suboptimal/autodiff/Linearity.h" namespace suboptimal { struct Expression; @@ -37,7 +37,7 @@ struct Expression { AdjointExprFunc lhs_adjoint_expr_func = nullptr; // Function giving the adjoint expression of the LHS expression AdjointExprFunc rhs_adjoint_expr_func = nullptr; // Function giving the adjoint expression of the RHS expression - ExpressionType type = ExpressionType::Constant; + Linearity linearity = Linearity::Constant; int wrt_index = -1; @@ -48,18 +48,18 @@ struct Expression { /** * Constructs a nullary expression */ - explicit Expression(double value, ExpressionType type = ExpressionType::Constant); + explicit Expression(double value, Linearity linearity = Linearity::Constant); /** * Constructs a unary expression */ - Expression(ExpressionType type, ValueFunc value_func, AdjointValueFunc adjoint_value_func, + Expression(Linearity linearity, ValueFunc value_func, AdjointValueFunc adjoint_value_func, AdjointExprFunc adjoint_expr_func, ExpressionPtr arg); /** * Constructs a binary expression */ - Expression(ExpressionType type, ValueFunc valueFunc, AdjointValueFunc lhs_adjoint_value_func, + Expression(Linearity linearity, ValueFunc valueFunc, AdjointValueFunc lhs_adjoint_value_func, AdjointValueFunc rhs_adjoint_value_func, AdjointExprFunc lhs_adjoint_expr_func, AdjointExprFunc rhs_adjoint_expr_func, ExpressionPtr lhs, ExpressionPtr rhs); @@ -71,7 +71,7 @@ struct Expression { /** * Checks if the value of this expression is constant */ - bool isConstant() const { return isIndependent() && type == ExpressionType::Constant; } + bool isConstant() const { return isIndependent() && linearity == Linearity::Constant; } /** * Checks if the expression represents a unary operation diff --git a/include/suboptimal/autodiff/ExpressionType.h b/include/suboptimal/autodiff/Linearity.h similarity index 79% rename from include/suboptimal/autodiff/ExpressionType.h rename to include/suboptimal/autodiff/Linearity.h index e4a970d..3ee1df0 100644 --- a/include/suboptimal/autodiff/ExpressionType.h +++ b/include/suboptimal/autodiff/Linearity.h @@ -5,7 +5,7 @@ #include namespace suboptimal { -enum class ExpressionType : int { +enum class Linearity : int { // The expression is a constant Constant, // The expression contains linear and constant terms @@ -16,9 +16,9 @@ enum class ExpressionType : int { Nonlinear }; -constexpr std::string toString(const ExpressionType& exprType) { - using enum ExpressionType; - switch (exprType) { +constexpr std::string toString(const Linearity& linearity) { + using enum Linearity; + switch (linearity) { case Constant: return "constant"; case Linear: diff --git a/include/suboptimal/autodiff/Variable.h b/include/suboptimal/autodiff/Variable.h index d77c550..66d4907 100644 --- a/include/suboptimal/autodiff/Variable.h +++ b/include/suboptimal/autodiff/Variable.h @@ -9,6 +9,7 @@ #include #include "suboptimal/autodiff/Expression.h" +#include "suboptimal/autodiff/Linearity.h" namespace suboptimal { struct Variable; @@ -31,7 +32,7 @@ using Matrix4v = Eigen::Matrix4; * An autodiff variable. Essentially just a nicer wrapper around Expression */ struct Variable { - ExpressionPtr expr = std::make_shared(0.0, ExpressionType::Linear); + ExpressionPtr expr = std::make_shared(0.0, Linearity::Linear); /** * Constructs an independent variable with initial value 0 @@ -79,7 +80,7 @@ struct Variable { * Gets the degree of the expression this variable represents * @return the type of the expression */ - ExpressionType getType() const { return expr->type; } + Linearity getLinearity() const { return expr->linearity; } /** * Checks if the expression this variable represents is independent of other expressions diff --git a/src/autodiff/Expression.cpp b/src/autodiff/Expression.cpp index 9880798..fa48552 100644 --- a/src/autodiff/Expression.cpp +++ b/src/autodiff/Expression.cpp @@ -9,18 +9,18 @@ #include namespace suboptimal { -Expression::Expression(const double value, const ExpressionType type) : value{value}, type{type} {} +Expression::Expression(const double value, const Linearity linearity) : value{value}, linearity{linearity} {} -Expression::Expression(const ExpressionType type, const ValueFunc value_func, const AdjointValueFunc adjoint_value_func, +Expression::Expression(const Linearity linearity, const ValueFunc value_func, const AdjointValueFunc adjoint_value_func, const AdjointExprFunc adjoint_expr_func, const ExpressionPtr arg) // NOLINT : value{value_func(arg->value, 0.0)}, lhs{arg}, value_func{value_func}, lhs_adjoint_value{adjoint_value_func}, lhs_adjoint_expr_func{adjoint_expr_func}, - type{type} {} + linearity{linearity} {} -Expression::Expression(const ExpressionType type, const ValueFunc valueFunc, +Expression::Expression(const Linearity linearity, const ValueFunc valueFunc, const AdjointValueFunc lhs_adjoint_value_func, const AdjointValueFunc rhs_adjoint_value_func, const AdjointExprFunc lhs_adjoint_expr_func, const AdjointExprFunc rhs_adjoint_expr_func, const ExpressionPtr lhs, // NOLINT @@ -33,7 +33,7 @@ Expression::Expression(const ExpressionType type, const ValueFunc valueFunc, rhs_adjoint_value_func{rhs_adjoint_value_func}, lhs_adjoint_expr_func{lhs_adjoint_expr_func}, rhs_adjoint_expr_func{rhs_adjoint_expr_func}, - type{type} {} + linearity{linearity} {} void Expression::updateChildren() { children.clear(); @@ -109,7 +109,7 @@ ExpressionPtr operator-(const ExpressionPtr& x) { } return std::make_shared( - x->type, [](const double val, double) { return -val; }, + x->linearity, [](const double val, double) { return -val; }, [](double, double, const double parent_adjoint) { return -parent_adjoint; }, [](const ExpressionPtr&, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) { return -parent_adjoint; }, x); @@ -127,7 +127,7 @@ ExpressionPtr operator+(const ExpressionPtr& lhs, const ExpressionPtr& rhs) { } return std::make_shared( - std::max(lhs->type, rhs->type), [](const double lhs_val, const double rhs_val) { return lhs_val + rhs_val; }, + std::max(lhs->linearity, rhs->linearity), [](const double lhs_val, const double rhs_val) { return lhs_val + rhs_val; }, [](double, double, const double parent_adjoint) { return parent_adjoint; }, [](double, double, const double parent_adjoint) { return parent_adjoint; }, [](const ExpressionPtr&, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) { return parent_adjoint; }, @@ -147,7 +147,7 @@ ExpressionPtr operator-(const ExpressionPtr& lhs, const ExpressionPtr& rhs) { } return std::make_shared( - std::max(lhs->type, rhs->type), [](const double lhs_val, const double rhs_val) { return lhs_val - rhs_val; }, + std::max(lhs->linearity, rhs->linearity), [](const double lhs_val, const double rhs_val) { return lhs_val - rhs_val; }, [](double, double, const double parent_adjoint) { return parent_adjoint; }, [](double, double, const double parent_adjoint) { return -parent_adjoint; }, [](const ExpressionPtr&, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) { return parent_adjoint; }, @@ -169,19 +169,19 @@ ExpressionPtr operator*(const ExpressionPtr& lhs, const ExpressionPtr& rhs) { return std::make_shared(lhs->value * rhs->value); } - ExpressionType type; + Linearity linearity; if (lhs->isConstant()) { - type = rhs->type; + linearity = rhs->linearity; } else if (rhs->isConstant()) { - type = lhs->type; - } else if (lhs->type == ExpressionType::Linear && rhs->type == ExpressionType::Linear) { - type = ExpressionType::Quadratic; + linearity = lhs->linearity; + } else if (lhs->linearity == Linearity::Linear && rhs->linearity == Linearity::Linear) { + linearity = Linearity::Quadratic; } else { - type = ExpressionType::Nonlinear; + linearity = Linearity::Nonlinear; } return std::make_shared( - type, [](const double lhs_val, const double rhs_val) { return lhs_val * rhs_val; }, + linearity, [](const double lhs_val, const double rhs_val) { return lhs_val * rhs_val; }, [](double, const double rhs_val, const double parent_adjoint) { return rhs_val * parent_adjoint; }, [](const double lhs_val, double, const double parent_adjoint) { return lhs_val * parent_adjoint; }, [](const ExpressionPtr&, const ExpressionPtr& rhs_expr, const ExpressionPtr& parent_adjoint) { @@ -209,7 +209,7 @@ ExpressionPtr operator/(const ExpressionPtr& lhs, const ExpressionPtr& rhs) { } return std::make_shared( - rhs->isConstant() ? lhs->type : ExpressionType::Nonlinear, + rhs->isConstant() ? lhs->linearity : Linearity::Nonlinear, [](const double lhs_val, const double rhs_val) { return lhs_val / rhs_val; }, [](double, const double rhs_val, const double parent_adjoint) { return parent_adjoint / rhs_val; }, [](const double lhs_val, const double rhs_val, const double parent_adjoint) { @@ -233,7 +233,7 @@ ExpressionPtr abs(const ExpressionPtr& x) { } return std::make_shared( - ExpressionType::Nonlinear, [](const double val, double) { return std::abs(val); }, + Linearity::Nonlinear, [](const double val, double) { return std::abs(val); }, [](const double val, double, const double parent_adjoint) { return parent_adjoint * val / std::abs(val); }, [](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) { return parent_adjoint * expr / suboptimal::abs(expr); @@ -250,7 +250,7 @@ ExpressionPtr sqrt(const ExpressionPtr& x) { } return std::make_shared( - ExpressionType::Nonlinear, [](const double val, double) { return std::sqrt(val); }, + Linearity::Nonlinear, [](const double val, double) { return std::sqrt(val); }, [](const double val, double, const double parent_adjoint) { return parent_adjoint * 0.5 / std::sqrt(val); }, [](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) { return parent_adjoint * std::make_shared(0.5) / suboptimal::sqrt(expr); @@ -267,7 +267,7 @@ ExpressionPtr exp(const ExpressionPtr& x) { } return std::make_shared( - ExpressionType::Nonlinear, [](const double val, double) { return std::exp(val); }, + Linearity::Nonlinear, [](const double val, double) { return std::exp(val); }, [](const double val, double, const double parent_adjoint) { return parent_adjoint * std::exp(val); }, [](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) { return parent_adjoint * suboptimal::exp(expr); @@ -284,7 +284,7 @@ ExpressionPtr log(const ExpressionPtr& x) { } return std::make_shared( - ExpressionType::Nonlinear, [](const double val, double) { return std::log(val); }, + Linearity::Nonlinear, [](const double val, double) { return std::log(val); }, [](const double val, double, const double parent_adjoint) { return parent_adjoint / val; }, [](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) { return parent_adjoint / expr; @@ -306,15 +306,15 @@ ExpressionPtr pow(const ExpressionPtr& base, const ExpressionPtr& exponent) { return std::make_shared(std::pow(base->value, exponent->value)); } - ExpressionType type; - if (base->type == ExpressionType::Linear && exponent->constEquals(2.0)) { - type = ExpressionType::Quadratic; + Linearity linearity; + if (base->linearity == Linearity::Linear && exponent->constEquals(2.0)) { + linearity = Linearity::Quadratic; } else { - type = ExpressionType::Nonlinear; + linearity = Linearity::Nonlinear; } return std::make_shared( - type, [](const double base_val, const double exp_val) { return std::pow(base_val, exp_val); }, + linearity, [](const double base_val, const double exp_val) { return std::pow(base_val, exp_val); }, [](const double base_val, const double exp_val, const double parent_adjoint) { return parent_adjoint * exp_val * std::pow(base_val, exp_val - 1); }, @@ -342,7 +342,7 @@ ExpressionPtr hypot(const ExpressionPtr& x, const ExpressionPtr& y) { } return std::make_shared( - ExpressionType::Nonlinear, [](const double x_val, const double y_val) { return std::hypot(x_val, y_val); }, + Linearity::Nonlinear, [](const double x_val, const double y_val) { return std::hypot(x_val, y_val); }, [](const double x_val, const double y_val, const double parent_adjoint) { return parent_adjoint * x_val / std::hypot(x_val, y_val); }, @@ -367,7 +367,7 @@ ExpressionPtr sin(const ExpressionPtr& x) { } return std::make_shared( - ExpressionType::Nonlinear, [](const double val, double) { return std::sin(val); }, + Linearity::Nonlinear, [](const double val, double) { return std::sin(val); }, [](const double val, double, const double parent_adjoint) { return parent_adjoint * std::cos(val); }, [](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) { return parent_adjoint * suboptimal::cos(expr); @@ -381,7 +381,7 @@ ExpressionPtr cos(const ExpressionPtr& x) { } return std::make_shared( - ExpressionType::Nonlinear, [](const double val, double) { return std::cos(val); }, + Linearity::Nonlinear, [](const double val, double) { return std::cos(val); }, [](const double val, double, const double parent_adjoint) { return -parent_adjoint * std::sin(val); }, [](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) { return -parent_adjoint * suboptimal::sin(expr); @@ -398,7 +398,7 @@ ExpressionPtr tan(const ExpressionPtr& x) { } return std::make_shared( - ExpressionType::Nonlinear, [](const double val, double) { return std::tan(val); }, + Linearity::Nonlinear, [](const double val, double) { return std::tan(val); }, [](const double val, double, const double parent_adjoint) { return parent_adjoint / (std::cos(val) * std::cos(val)); }, @@ -417,7 +417,7 @@ ExpressionPtr asin(const ExpressionPtr& x) { } return std::make_shared( - ExpressionType::Nonlinear, [](const double val, double) { return std::asin(val); }, + Linearity::Nonlinear, [](const double val, double) { return std::asin(val); }, [](const double val, double, const double parent_adjoint) { return parent_adjoint / std::sqrt(1 - val * val); }, [](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) { return parent_adjoint / suboptimal::sqrt(std::make_shared(1.0) - expr * expr); @@ -431,7 +431,7 @@ ExpressionPtr acos(const ExpressionPtr& x) { } return std::make_shared( - ExpressionType::Nonlinear, [](const double val, double) { return std::acos(val); }, + Linearity::Nonlinear, [](const double val, double) { return std::acos(val); }, [](const double val, double, const double parent_adjoint) { return -parent_adjoint / std::sqrt(1 - val * val); }, [](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) { return -parent_adjoint / suboptimal::sqrt(std::make_shared(1.0) - expr * expr); @@ -448,7 +448,7 @@ ExpressionPtr atan(const ExpressionPtr& x) { } return std::make_shared( - ExpressionType::Nonlinear, [](const double val, double) { return std::atan(val); }, + Linearity::Nonlinear, [](const double val, double) { return std::atan(val); }, [](const double val, double, const double parent_adjoint) { return parent_adjoint / (1 + val * val); }, [](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) { return parent_adjoint / (std::make_shared(1.0) + expr * expr); @@ -465,7 +465,7 @@ ExpressionPtr atan2(const ExpressionPtr& y, const ExpressionPtr& x) { } return std::make_shared( - ExpressionType::Nonlinear, [](const double y_val, const double x_val) { return std::atan2(y_val, x_val); }, + Linearity::Nonlinear, [](const double y_val, const double x_val) { return std::atan2(y_val, x_val); }, [](const double y_val, const double x_val, const double parent_adjoint) { return parent_adjoint * x_val / (y_val * y_val + x_val * x_val); }, diff --git a/src/autodiff/Variable.cpp b/src/autodiff/Variable.cpp index a3d644b..ef66ba0 100644 --- a/src/autodiff/Variable.cpp +++ b/src/autodiff/Variable.cpp @@ -6,7 +6,7 @@ #include namespace suboptimal { -Variable::Variable(double value) : expr{std::make_shared(value, ExpressionType::Linear)} {} +Variable::Variable(double value) : expr{std::make_shared(value, Linearity::Linear)} {} Variable::Variable(const ExpressionPtr& expr) : expr{expr} {} diff --git a/test/autodiff/Variable_test.cpp b/test/autodiff/Variable_test.cpp index 708fe77..8a5bee0 100644 --- a/test/autodiff/Variable_test.cpp +++ b/test/autodiff/Variable_test.cpp @@ -16,11 +16,11 @@ using namespace suboptimal; TEST_CASE("Autodiff - Variable constructor", "[autodiff]") { const Variable x{1.0}; CHECK(x.getValue() == 1.0); - CHECK(x.getType() == ExpressionType::Linear); + CHECK(x.getLinearity() == Linearity::Linear); const Variable y{}; CHECK(y.getValue() == 0.0); - CHECK(y.getType() == ExpressionType::Linear); + CHECK(y.getLinearity() == Linearity::Linear); } TEST_CASE("Autodiff - Variable basic arithmetic", "[autodiff]") {