From 717accea21398085ed5526bdc4d72a4c30cde0bf Mon Sep 17 00:00:00 2001 From: Alvin Zhang <41vin2h4n9@gmail.com> Date: Mon, 2 Dec 2024 11:58:14 -0800 Subject: [PATCH] feat: add constant factory to Variable and simplify Variable construction --- include/suboptimal/autodiff/Variable.h | 48 +++++++++++++++----------- src/autodiff/Variable.cpp | 26 ++++++++------ src/autodiff/derivatives.cpp | 4 +-- 3 files changed, 44 insertions(+), 34 deletions(-) diff --git a/include/suboptimal/autodiff/Variable.h b/include/suboptimal/autodiff/Variable.h index 9c1f295..9b9369c 100644 --- a/include/suboptimal/autodiff/Variable.h +++ b/include/suboptimal/autodiff/Variable.h @@ -60,6 +60,12 @@ struct Variable { */ Variable(const ExpressionPtr& expr); // NOLINT + /** + * Creates a constant variable + * @param value the value of the constant + */ + static Variable Constant(double value); + /** * Updates the value of the variable, traversing the expression tree and updating all expressions and variables this * variable depends on @@ -129,11 +135,11 @@ template requires std::same_as || std::same_as Variable operator+(const LHS& lhs, const RHS& rhs) { if constexpr (std::is_arithmetic_v) { - return {std::make_shared(lhs) + rhs.expr}; + return std::make_shared(lhs) + rhs.expr; } else if constexpr (std::is_arithmetic_v) { - return {lhs.expr + std::make_shared(rhs)}; + return lhs.expr + std::make_shared(rhs); } else { - return {lhs.expr + rhs.expr}; + return lhs.expr + rhs.expr; } } @@ -141,11 +147,11 @@ template requires std::same_as || std::same_as Variable operator-(const LHS& lhs, const RHS& rhs) { if constexpr (std::is_arithmetic_v) { - return {std::make_shared(lhs) - rhs.expr}; + return std::make_shared(lhs) - rhs.expr; } else if constexpr (std::is_arithmetic_v) { - return {lhs.expr - std::make_shared(rhs)}; + return lhs.expr - std::make_shared(rhs); } else { - return {lhs.expr - rhs.expr}; + return lhs.expr - rhs.expr; } } @@ -153,11 +159,11 @@ template requires std::same_as || std::same_as Variable operator*(const LHS& lhs, const RHS& rhs) { if constexpr (std::is_arithmetic_v) { - return {std::make_shared(lhs) * rhs.expr}; + return std::make_shared(lhs) * rhs.expr; } else if constexpr (std::is_arithmetic_v) { - return {lhs.expr * std::make_shared(rhs)}; + return lhs.expr * std::make_shared(rhs); } else { - return {lhs.expr * rhs.expr}; + return lhs.expr * rhs.expr; } } @@ -165,11 +171,11 @@ template requires std::same_as || std::same_as Variable operator/(const LHS& lhs, const RHS& rhs) { if constexpr (std::is_arithmetic_v) { - return {std::make_shared(lhs) / rhs.expr}; + return std::make_shared(lhs) / rhs.expr; } else if constexpr (std::is_arithmetic_v) { - return {lhs.expr / std::make_shared(rhs)}; + return lhs.expr / std::make_shared(rhs); } else { - return {lhs.expr / rhs.expr}; + return lhs.expr / rhs.expr; } } @@ -182,11 +188,11 @@ template requires std::same_as || std::same_as Variable pow(const Base& base, const Exp& exponent) { if constexpr (std::is_arithmetic_v) { - return {pow(std::make_shared(base), exponent.expr)}; + return pow(std::make_shared(base), exponent.expr); } else if constexpr (std::is_arithmetic_v) { - return {pow(base.expr, std::make_shared(exponent))}; + return pow(base.expr, std::make_shared(exponent)); } else { - return {pow(base.expr, exponent.expr)}; + return pow(base.expr, exponent.expr); } } @@ -194,11 +200,11 @@ template requires std::same_as || std::same_as Variable hypot(const X& x, const Y& y) { if constexpr (std::is_arithmetic_v) { - return {hypot(std::make_shared(x), y.expr)}; + return hypot(std::make_shared(x), y.expr); } else if constexpr (std::is_arithmetic_v) { - return {hypot(x.expr, std::make_shared(y))}; + return hypot(x.expr, std::make_shared(y)); } else { - return {hypot(x.expr, y.expr)}; + return hypot(x.expr, y.expr); } } @@ -213,11 +219,11 @@ template requires std::same_as || std::same_as Variable atan2(const Y& y, const X& x) { if constexpr (std::is_arithmetic_v) { - return {atan2(std::make_shared(y), x.expr)}; + return atan2(std::make_shared(y), x.expr); } else if constexpr (std::is_arithmetic_v) { - return {atan2(y.expr, std::make_shared(x))}; + return atan2(y.expr, std::make_shared(x)); } else { - return {atan2(y.expr, x.expr)}; + return atan2(y.expr, x.expr); } } diff --git a/src/autodiff/Variable.cpp b/src/autodiff/Variable.cpp index ef66ba0..95116c0 100644 --- a/src/autodiff/Variable.cpp +++ b/src/autodiff/Variable.cpp @@ -12,6 +12,10 @@ Variable::Variable(const ExpressionPtr& expr) : expr{expr} {} Variable::Variable(ExpressionPtr&& expr) : expr{std::move(expr)} {} +Variable Variable::Constant(double value) { + return std::make_shared(value); +} + void Variable::updateValue() const { expr->updateValue(); } @@ -34,45 +38,45 @@ Variable operator+(const Variable& x) { } Variable operator-(const Variable& x) { - return {-x.expr}; + return -x.expr; } Variable abs(const Variable& x) { - return {abs(x.expr)}; + return abs(x.expr); } Variable sqrt(const Variable& x) { - return {sqrt(x.expr)}; + return sqrt(x.expr); } Variable exp(const Variable& x) { - return {exp(x.expr)}; + return exp(x.expr); } Variable log(const Variable& x) { - return {log(x.expr)}; + return log(x.expr); } Variable sin(const Variable& x) { - return {sin(x.expr)}; + return sin(x.expr); } Variable cos(const Variable& x) { - return {cos(x.expr)}; + return cos(x.expr); } Variable tan(const Variable& x) { - return {tan(x.expr)}; + return tan(x.expr); } Variable asin(const Variable& x) { - return {asin(x.expr)}; + return asin(x.expr); } Variable acos(const Variable& x) { - return {acos(x.expr)}; + return acos(x.expr); } Variable atan(const Variable& x) { - return {atan(x.expr)}; + return atan(x.expr); } } // namespace suboptimal diff --git a/src/autodiff/derivatives.cpp b/src/autodiff/derivatives.cpp index ad67f5c..c64b353 100644 --- a/src/autodiff/derivatives.cpp +++ b/src/autodiff/derivatives.cpp @@ -35,7 +35,7 @@ VectorXv Gradient::getExpr() { VectorXv grad{wrt.size()}; std::ranges::for_each(grad, [](Variable& v) { v.expr = std::make_shared(0.0); }); for (Eigen::SparseVector::InnerIterator it(value); it; ++it) { - grad(it.index()) = Variable{std::make_shared(it.value())}; + grad(it.index()) = Variable::Constant(it.value()); } } @@ -64,7 +64,7 @@ VectorXv Gradient::getExpr() { grad(i) = Variable{wrt(i).expr->adjoint_expr}; } else { // var is not dependent on wrt(i) - grad(i) = Variable{std::make_shared(0.0)}; + grad(i) = Variable::Constant(0.0); } }