Skip to content

Commit

Permalink
feat: add erf and hyperbolic functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mimizh2418 committed Dec 4, 2024
1 parent 952cbb1 commit 23962a6
Show file tree
Hide file tree
Showing 6 changed files with 350 additions and 9 deletions.
11 changes: 11 additions & 0 deletions include/suboptimal/autodiff/Expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,22 @@ ExpressionPtr log(const ExpressionPtr& x);
ExpressionPtr pow(const ExpressionPtr& base, const ExpressionPtr& exponent);
ExpressionPtr hypot(const ExpressionPtr& x, const ExpressionPtr& y);

ExpressionPtr erf(const ExpressionPtr& x);

ExpressionPtr sin(const ExpressionPtr& x);
ExpressionPtr cos(const ExpressionPtr& x);
ExpressionPtr tan(const ExpressionPtr& x);

ExpressionPtr asin(const ExpressionPtr& x);
ExpressionPtr acos(const ExpressionPtr& x);
ExpressionPtr atan(const ExpressionPtr& x);
ExpressionPtr atan2(const ExpressionPtr& y, const ExpressionPtr& x);

ExpressionPtr sinh(const ExpressionPtr& x);
ExpressionPtr cosh(const ExpressionPtr& x);
ExpressionPtr tanh(const ExpressionPtr& x);

ExpressionPtr asinh(const ExpressionPtr& x);
ExpressionPtr acosh(const ExpressionPtr& x);
ExpressionPtr atanh(const ExpressionPtr& x);
} // namespace suboptimal
11 changes: 11 additions & 0 deletions include/suboptimal/autodiff/Variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,12 @@ Variable hypot(const X& x, const Y& y) {
}
}

Variable erf(const Variable& x);

Variable sin(const Variable& x);
Variable cos(const Variable& x);
Variable tan(const Variable& x);

Variable asin(const Variable& x);
Variable acos(const Variable& x);
Variable atan(const Variable& x);
Expand All @@ -232,6 +235,14 @@ Variable atan2(const Y& y, const X& x) {
}
}

Variable sinh(const Variable& x);
Variable cosh(const Variable& x);
Variable tanh(const Variable& x);

Variable asinh(const Variable& x);
Variable acosh(const Variable& x);
Variable atanh(const Variable& x);

/**
* Returns a matrix holding the values of the Variables in the input matrix
* @param var_mat the input matrix of Variables
Expand Down
121 changes: 121 additions & 0 deletions src/autodiff/Expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,26 @@ ExpressionPtr hypot(const ExpressionPtr& x, const ExpressionPtr& y) {
x, y);
}

ExpressionPtr erf(const ExpressionPtr& x) {
if (x->constEquals(0.0)) {
return x;
}
if (x->isConstant()) {
return std::make_shared<Expression>(std::erf(x->value));
}

return std::make_shared<Expression>(
Linearity::Nonlinear, [](const double val, double) { return std::erf(val); },
[](const double val, double, const double parent_adjoint) {
return parent_adjoint * 2.0 * std::exp(-val * val) / std::sqrt(std::numbers::pi);
},
[](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) {
return parent_adjoint * std::make_shared<Expression>(2.0) * suboptimal::exp(-expr * expr) /
std::make_shared<Expression>(std::sqrt(std::numbers::pi));
},
x);
}

ExpressionPtr sin(const ExpressionPtr& x) {
if (x->constEquals(0.0)) {
return x;
Expand Down Expand Up @@ -508,4 +528,105 @@ ExpressionPtr atan2(const ExpressionPtr& y, const ExpressionPtr& x) {
},
y, x);
}

ExpressionPtr sinh(const ExpressionPtr& x) {
if (x->constEquals(0.0)) {
return x;
}
if (x->isConstant()) {
return std::make_shared<Expression>(std::sinh(x->value));
}

return std::make_shared<Expression>(
Linearity::Nonlinear, [](const double val, double) { return std::sinh(val); },
[](const double val, double, const double parent_adjoint) { return parent_adjoint * std::cosh(val); },
[](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) {
return parent_adjoint * suboptimal::cosh(expr);
},
x);
}

ExpressionPtr cosh(const ExpressionPtr& x) {
if (x->isConstant()) {
return std::make_shared<Expression>(std::cosh(x->value));
}

return std::make_shared<Expression>(
Linearity::Nonlinear, [](const double val, double) { return std::cosh(val); },
[](const double val, double, const double parent_adjoint) { return parent_adjoint * std::sinh(val); },
[](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) {
return parent_adjoint * suboptimal::sinh(expr);
},
x);
}

ExpressionPtr tanh(const ExpressionPtr& x) {
if (x->constEquals(0.0)) {
return x;
}
if (x->isConstant()) {
return std::make_shared<Expression>(std::tanh(x->value));
}

return std::make_shared<Expression>(
Linearity::Nonlinear, [](const double val, double) { return std::tanh(val); },
[](const double val, double, const double parent_adjoint) {
return parent_adjoint * (1 - std::tanh(val) * std::tanh(val));
},
[](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) {
return parent_adjoint * (std::make_shared<Expression>(1.0) - suboptimal::tanh(expr) * suboptimal::tanh(expr));
},
x);
}

ExpressionPtr asinh(const ExpressionPtr& x) {
if (x->constEquals(0.0)) {
return x;
}
if (x->isConstant()) {
return std::make_shared<Expression>(std::asinh(x->value));
}

return std::make_shared<Expression>(
Linearity::Nonlinear, [](const double val, double) { return std::asinh(val); },
[](const double val, double, const double parent_adjoint) { return parent_adjoint / std::sqrt(1 + val * val); },
[](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) {
return parent_adjoint / suboptimal::sqrt(std::make_shared<Expression>(1.0) + expr * expr);
},
x);
}

ExpressionPtr acosh(const ExpressionPtr& x) {
if (x->isConstant()) {
return std::make_shared<Expression>(std::acosh(x->value));
}

return std::make_shared<Expression>(
Linearity::Nonlinear, [](const double val, double) { return std::acosh(val); },
[](const double val, double, const double parent_adjoint) {
return parent_adjoint / (std::sqrt(val - 1) * std::sqrt(val + 1));
},
[](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) {
return parent_adjoint / (suboptimal::sqrt(expr - std::make_shared<Expression>(1.0)) *
suboptimal::sqrt(expr + std::make_shared<Expression>(1.0)));
},
x);
}

ExpressionPtr atanh(const ExpressionPtr& x) {
if (x->constEquals(0.0)) {
return x;
}
if (x->isConstant()) {
return std::make_shared<Expression>(std::atanh(x->value));
}

return std::make_shared<Expression>(
Linearity::Nonlinear, [](const double val, double) { return std::atanh(val); },
[](const double val, double, const double parent_adjoint) { return parent_adjoint / (1 - val * val); },
[](const ExpressionPtr& expr, const ExpressionPtr&, const ExpressionPtr& parent_adjoint) {
return parent_adjoint / (std::make_shared<Expression>(1.0) - expr * expr);
},
x);
}
} // namespace suboptimal
29 changes: 29 additions & 0 deletions src/autodiff/Variable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ Variable log(const Variable& x) {
return log(x.expr);
}

Variable erf(const Variable& x) {
return erf(x.expr);
}

Variable sin(const Variable& x) {
return sin(x.expr);
}
Expand All @@ -83,4 +87,29 @@ Variable acos(const Variable& x) {
Variable atan(const Variable& x) {
return atan(x.expr);
}

Variable sinh(const Variable& x) {
return sinh(x.expr);
}

Variable cosh(const Variable& x) {
return cosh(x.expr);
}

Variable tanh(const Variable& x) {
return tanh(x.expr);
}

Variable asinh(const Variable& x) {
return asinh(x.expr);
}

Variable acosh(const Variable& x) {
return acosh(x.expr);

}

Variable atanh(const Variable& x) {
return atanh(x.expr);
}
} // namespace suboptimal
132 changes: 123 additions & 9 deletions test/autodiff/Variable_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ TEST_CASE("Autodiff - Variable basic arithmetic", "[autodiff]") {
}
}

TEST_CASE("Autodiff - Variable STL functions", "[autodiff]") {
TEST_CASE("Autodiff - Basic Variable STL functions", "[autodiff]") {
Variable x{};
Variable y{};
Variable f{};
Expand All @@ -66,17 +66,16 @@ TEST_CASE("Autodiff - Variable STL functions", "[autodiff]") {
x.setValue(x_val);
y.setValue(y_val);

f = suboptimal::sin(x) + suboptimal::cos(y) +
suboptimal::atan2(x, 5) * suboptimal::sqrt(y) / (suboptimal::exp(x) + suboptimal::hypot(suboptimal::log(x), y));
f = suboptimal::erf(suboptimal::sqrt(suboptimal::abs(-suboptimal::pow(x, 3.0)))) / suboptimal::exp(x) +
suboptimal::hypot(suboptimal::log(x), y);
const double f_val =
std::sin(x_val) + std::cos(y_val) +
std::atan2(x_val, 5) * std::sqrt(y_val) / (std::exp(x_val) + std::hypot(std::log(x_val), y_val));
std::erf(std::sqrt(std::abs(-std::pow(x_val, 3.0)))) / std::exp(x_val) + std::hypot(std::log(x_val), y_val);
CHECK_THAT(f.getValue(), Catch::Matchers::WithinAbs(f_val, 1e-9));
}

SECTION("Value update") {
f = suboptimal::sin(x) + suboptimal::cos(y) +
suboptimal::atan2(x, 5) * suboptimal::sqrt(y) / (suboptimal::exp(x) + suboptimal::hypot(suboptimal::log(x), y));
f = suboptimal::erf(suboptimal::sqrt(suboptimal::abs(-suboptimal::pow(x, 3.0)))) / suboptimal::exp(x) +
suboptimal::hypot(suboptimal::log(x), y);

const double x_val = GENERATE(take(5, random(0.0, 100.0)));
const double y_val = GENERATE(take(5, random(0.0, 100.0)));
Expand All @@ -85,8 +84,123 @@ TEST_CASE("Autodiff - Variable STL functions", "[autodiff]") {
y.setValue(y_val);

const double f_val =
std::sin(x_val) + std::cos(y_val) +
std::atan2(x_val, 5) * std::sqrt(y_val) / (std::exp(x_val) + std::hypot(std::log(x_val), y_val));
std::erf(std::sqrt(std::abs(-std::pow(x_val, 3.0)))) / std::exp(x_val) + std::hypot(std::log(x_val), y_val);
CHECK_THAT(f.getValue(), Catch::Matchers::WithinAbs(f_val, 1e-9));
}
}

TEST_CASE("Autodiff - Variable trig functions", "[autodiff]") {
Variable x{};
Variable y{};
Variable f{};

SECTION("Initial value") {
const double x_val = GENERATE(take(5, random(0.0, 100.0)));
const double y_val = GENERATE(take(5, random(0.0, 100.0)));

x.setValue(x_val);
y.setValue(y_val);

f = suboptimal::sin(x * y) + suboptimal::cos(x) + suboptimal::tan(y);
const double f_val = std::sin(x_val * y_val) + std::cos(x_val) + std::tan(y_val);
CHECK_THAT(f.getValue(), Catch::Matchers::WithinAbs(f_val, 1e-9));
}
}

TEST_CASE("Autodiff - Variable inverse trig functions", "[autodiff]") {
Variable x{};
Variable y{};
Variable f{};

SECTION("Initial value") {
const double x_val = GENERATE(take(5, random(-1.0, 1.0)));
const double y_val = GENERATE(take(5, random(-1.0, 1.0)));

x.setValue(x_val);
y.setValue(y_val);

f = suboptimal::asin(x) + suboptimal::acos(y) + suboptimal::atan(x * y);
const double f_val = std::asin(x_val) + std::acos(y_val) + std::atan(x_val * y_val);
CHECK_THAT(f.getValue(), Catch::Matchers::WithinAbs(f_val, 1e-9));
}

SECTION("Value update") {
f = suboptimal::asin(x) + suboptimal::acos(y) + suboptimal::atan(x * y);

const double x_val = GENERATE(take(5, random(-1.0, 1.0)));
const double y_val = GENERATE(take(5, random(-1.0, 1.0)));

x.setValue(x_val);
y.setValue(y_val);

const double f_val = std::asin(x_val) + std::acos(y_val) + std::atan(x_val * y_val);
CHECK_THAT(f.getValue(), Catch::Matchers::WithinAbs(f_val, 1e-9));
}
}

TEST_CASE("Autodiff - Variable hyperbolic trig functions", "[autodiff]") {
Variable x{};
Variable y{};
Variable f{};

SECTION("Initial value") {
const double x_val = GENERATE(take(5, random(-100.0, 100.0)));
const double y_val = GENERATE(take(5, random(-100.0, 100.0)));

x.setValue(x_val);
y.setValue(y_val);

f = suboptimal::sinh(x) + suboptimal::cosh(y) + suboptimal::tanh(x * y);
const double f_val = std::sinh(x_val) + std::cosh(y_val) + std::tanh(x_val * y_val);
CHECK_THAT(f.getValue(), Catch::Matchers::WithinAbs(f_val, 1e-9));
}

SECTION("Value update") {
f = suboptimal::sinh(x) + suboptimal::cosh(y) + suboptimal::tanh(x * y);

const double x_val = GENERATE(take(5, random(-100.0, 100.0)));
const double y_val = GENERATE(take(5, random(-100.0, 100.0)));

x.setValue(x_val);
y.setValue(y_val);

const double f_val = std::sinh(x_val) + std::cosh(y_val) + std::tanh(x_val * y_val);
CHECK_THAT(f.getValue(), Catch::Matchers::WithinAbs(f_val, 1e-9));
}
}

TEST_CASE("Autodiff - Variable inverse hyperbolic trig functions") {
Variable x{};
Variable y{};
Variable z{};
Variable f{};

SECTION("Initial value") {
const double x_val = GENERATE(take(5, random(-100.0, 100.0)));
const double y_val = GENERATE(take(5, random(1.0, 100.0)));
const double z_val = GENERATE(take(5, random(-1.0, 1.0)));

x.setValue(x_val);
y.setValue(y_val);
z.setValue(z_val);

f = suboptimal::asinh(x) + suboptimal::acosh(y) + suboptimal::atanh(z);
const double f_val = std::asinh(x_val) + std::acosh(y_val) + std::atanh(z_val);
CHECK_THAT(f.getValue(), Catch::Matchers::WithinAbs(f_val, 1e-9));
}

SECTION("Value update") {
f = suboptimal::asinh(x) + suboptimal::acosh(y) + suboptimal::atanh(z);

const double x_val = GENERATE(take(5, random(-100.0, 100.0)));
const double y_val = GENERATE(take(5, random(1.0, 100.0)));
const double z_val = GENERATE(take(5, random(-1.0, 1.0)));

x.setValue(x_val);
y.setValue(y_val);
z.setValue(z_val);

const double f_val = std::asinh(x_val) + std::acosh(y_val) + std::atanh(z_val);
CHECK_THAT(f.getValue(), Catch::Matchers::WithinAbs(f_val, 1e-9));
}
}
Expand Down
Loading

0 comments on commit 23962a6

Please sign in to comment.