Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New Root Solver: ITP Method #544

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/argmin/src/solver/brent/brentroot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub struct BrentRoot<F> {

impl<F: ArgminFloat> BrentRoot<F> {
/// Constructor
/// The values `min` and `max` must bracketing the root of the function.
/// The values `min` and `max` must be bracketing the root of the function.
/// The parameter `tol` specifies the relative error to be targeted.
pub fn new(min: F, max: F, tol: F) -> Self {
BrentRoot {
Expand Down
354 changes: 354 additions & 0 deletions crates/argmin/src/solver/itp/itp_method.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,354 @@
// Copyright 2018-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use crate::core::{
ArgminFloat, CostFunction, Error, IterState, Problem, Solver, State, TerminationReason, KV,
};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use thiserror::Error;

/// Error to be thrown if ITP method is initialized with improper parameters.
#[derive(Debug, Error)]
pub enum ItpRootError {
/// f(min) and f(max) must have different signs
#[error("ItpRoot error: f(min) and f(max) must have different signs.")]
WrongSign,
/// tol must be positive
#[error("ItpRoot error: tol must be positive.")]
NegativeTol,
/// tol must be nonzero
#[error("ItpRoot error: tol must be nonzero.")]
ZeroTol,
/// max must be larger than min
#[error("ItpRoot error: max must be larger than min.")]
MinLargerThanMax,
}

/// # ITP method
///
/// A root-finding algorithm, short for "interpolate, truncate, and project",
/// that achieves superlinear convergence while retaining the worst-case
/// performance of the bisection method.
///
/// ## Requirements on the optimization problem
///
/// The optimization problem is required to implement [`CostFunction`].
///
/// ## Reference
///
/// [ITP Method]: https://en.wikipedia.org/wiki/ITP_Method
/// [An Enhancement of the Bisection Method Average Performance Preserving Minmax Optimality]: https://dl.acm.org/doi/10.1145/3423597
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct ItpRoot<F> {
/// required relative accuracy
tol: F,
/// tuned hyper-parameter 1, controls truncation size
kappa1: F,
/// tuned hyper-parameter 2, controls truncation size
kappa2: F,
/// tuned hyper-parameter slack variable, controls projection interval size
n0: F,
/// left boundary of current interval
a: F,
/// right boundary of current interval
b: F,
/// function value at `a`
fa: F,
/// function value at `b`
fb: F,
/// iteration counter, used in the projection step
j: F,
/// n_(1/2), a preprocessing variable
n1o2: F,
/// a preprocessing variable
nmax: F,
}

impl<F: ArgminFloat> ItpRoot<F> {
/// Constructor
/// The values `min` and `max` must be bracketing the root of the function.
/// The parameter `tol` specifies the relative error to be targeted.
/// The values `kappa1` and `kappa2` are hyper-parameters tuning the truncation size.
/// The parameter `n0` is a hyper-parameter slack variable controlling the projection interval
/// size.
pub fn new(min: F, max: F, tol: F, kappa1: F, kappa2: F, n0: F) -> Self {
ItpRoot {
tol,
kappa1,
kappa2,
n0,
a: min,
b: max,
fa: F::nan(),
fb: F::nan(),
j: F::zero(),
n1o2: F::nan(),
nmax: F::nan(),
}
}

/// Constructor with default hyperparameters
/// The values `min` and `max` must be bracketing the root of the function.
/// The parameter `tol` specifies the relative error to be targeted.
/// kappa1 is defaulted to 0.2 / (max - min).
/// kappa2 is defaulted to 2.0.
/// n0 is defaulted to 1.0.
pub fn from_defaults(min: F, max: F, tol: F) -> Self {
Self::new(
min,
max,
tol,
// kappa1, suggested from paper
float!(0.2) / (max - min),
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering, maybe we want to have this function return Result and verify max - min cannot be zero so we don't panic here. Not sure though, maybe not worth it right now.

// kappa2
float!(2.0),
// n0
float!(1.0),
)
}
}

impl<O, F> Solver<O, IterState<F, (), (), (), (), F>> for ItpRoot<F>
where
O: CostFunction<Param = F, Output = F>,
F: ArgminFloat,
{
fn name(&self) -> &str {
"ItpRoot"
}

fn init(
&mut self,
problem: &mut Problem<O>,
// ItpRoot maintains its own state
state: IterState<F, (), (), (), (), F>,
) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
self.fa = problem.cost(&self.a)?;
self.fb = problem.cost(&self.b)?;
if self.fa * self.fb > float!(0.0) {
return Err(ItpRootError::WrongSign.into());
}
if self.tol < F::zero() {
return Err(ItpRootError::NegativeTol.into());
}
// This helps ensure the log evaluation is stable
if self.a > self.b {
return Err(ItpRootError::MinLargerThanMax.into());
}
// It's important to check this to verify n1o2 doesn't panic
if self.tol.is_zero() {
return Err(ItpRootError::ZeroTol.into());
}

// Preprocessing
self.n1o2 = ((self.b - self.a) / (float!(2.0) * self.tol)).log2();
self.nmax = self.n1o2 + self.n0;

Ok((state.param(self.b).cost(self.fb.abs()), None))
}

fn next_iter(
&mut self,
problem: &mut Problem<O>,
// ItpRoot maintains its own state
state: IterState<F, (), (), (), (), F>,
) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
// Note: the headers here match the steps outlined on Wikipedia, with variable names to
// match for clarity.

// Calculating Parameters
let b_minus_a = self.b - self.a;
let x1o2 = (self.a + self.b) * float!(0.5);
let r = self.tol * float!(2.0).powf(self.nmax + self.j) - b_minus_a * float!(0.5);
let delta = self.kappa1 * b_minus_a.powf(self.kappa2);

// Interpolation
let xf = (self.fb * self.a - self.fa * self.b) / (self.fb - self.fa);

// Truncation
let bisect_falsi_delta = x1o2 - xf;
let sigma = bisect_falsi_delta.signum();
let xt = if delta <= bisect_falsi_delta.abs() {
xf + sigma * delta
} else {
x1o2
};

// Projection
let xitp = if (xt - x1o2).abs() <= r {
xt
} else {
x1o2 - sigma * r
};

// Updating Interval
let fitp = problem.cost(&xitp)?;
if fitp > F::zero() {
self.b = xitp;
self.fb = fitp;
} else if fitp < F::zero() {
self.a = xitp;
self.fa = fitp;
} else {
self.a = xitp;
self.b = xitp;
}
self.j = self.j + float!(1.0);

// Solver loop termination
if (self.b - self.a) <= (float!(2.0) * self.tol) {
let sol = (self.a + self.b) * float!(0.5);
// TODO: This function evaluation serves no purpose other than to serve argmin's cost
// method on the state. It feels wasteful.
let f_sol = problem.cost(&sol)?;
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This bothers me, but maybe someone smarter than me might know a clever way to not do this.

return Ok((
state
.terminate_with(TerminationReason::SolverConverged)
.param(sol)
.cost(f_sol.abs()),
None,
));
}

Ok((state.param(self.b).cost(self.fb.abs()), None))
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::core::Executor;
use approx::assert_relative_eq;

#[derive(Clone)]
struct Quadratic {}

impl CostFunction for Quadratic {
type Param = f64;
type Output = f64;

fn cost(&self, param: &Self::Param) -> Result<Self::Output, Error> {
Ok(param.powi(2) - 1.0) // x^2 - 1
}
}

// This polynomial matches what is explored in the Wikipedia example
#[derive(Clone)]
struct Polynomial {}

impl CostFunction for Polynomial {
type Param = f64;
type Output = f64;

fn cost(&self, param: &Self::Param) -> Result<Self::Output, Error> {
Ok(param.powi(3) - param - 2.0) // x^3 - x - 2
}
}

#[test]
fn test_itp_negative_tol() {
let min: f64 = 0.0;
let max: f64 = 2.0;
let tol: f64 = -1e-6;

let mut solver: ItpRoot<f64> = ItpRoot::from_defaults(min, max, tol);
let mut problem: Problem<Quadratic> = Problem::new(Quadratic {});

let result: Result<(IterState<f64, (), (), (), (), f64>, Option<KV>), Error> =
solver.init(&mut problem, IterState::new());

// Check if the initialization fails and we get the correct error message
assert!(result.is_err());
assert_eq!(
result.err().unwrap().to_string(),
"ItpRoot error: tol must be positive."
);
}

#[test]
fn test_itp_invalid_range() {
let min: f64 = 2.0;
let max: f64 = 3.0;
let tol: f64 = 1e-6;

let mut solver: ItpRoot<f64> = ItpRoot::from_defaults(min, max, tol);
let mut problem: Problem<Quadratic> = Problem::new(Quadratic {});

let result: Result<(IterState<f64, (), (), (), (), f64>, Option<KV>), Error> =
solver.init(&mut problem, IterState::new());

// Check if the initialization fails and we get the correct error message
assert!(result.is_err());
assert_eq!(
result.err().unwrap().to_string(),
"ItpRoot error: f(min) and f(max) must have different signs."
);
}

#[test]
fn test_itp_valid_range() {
let min: f64 = 0.0;
let max: f64 = 2.0;
let tol: f64 = 1e-6;

let mut solver: ItpRoot<f64> = ItpRoot::from_defaults(min, max, tol);
let mut problem: Problem<Quadratic> = Problem::new(Quadratic {});

let result: Result<(IterState<f64, (), (), (), (), f64>, Option<KV>), Error> =
solver.init(&mut problem, IterState::new());

// Check if the initialization is successful
assert!(result.is_ok());
}

#[test]
fn test_itp_find_quadratic_root() {
let min: f64 = 0.0;
let max: f64 = 2.0;
let tol: f64 = 1e-6;
let init_param: f64 = 1.5;

let solver: ItpRoot<f64> = ItpRoot::from_defaults(min, max, tol);
let problem: Quadratic = Quadratic {};

let res = Executor::new(problem, solver)
.configure(|state| state.param(init_param).max_iters(100))
.run()
.unwrap();

// Check if the result is close to the real root
assert_relative_eq!(res.state.best_param.unwrap(), 1.0, epsilon = tol);
}

#[test]
fn test_itp_find_polynomial_root() {
let min: f64 = 1.0;
let max: f64 = 2.0;
let tol: f64 = 0.0005;
let kappa1: f64 = 0.1;
let kappa2: f64 = 2.0;
let n0: f64 = 1.0;
let init_param: f64 = 1.5;

let solver: ItpRoot<f64> = ItpRoot::new(min, max, tol, kappa1, kappa2, n0);
let problem: Polynomial = Polynomial {};

let res = Executor::new(problem, solver)
.configure(|state| state.param(init_param).max_iters(100))
.run()
.unwrap();

// Check if the result is close to the real root
assert_relative_eq!(
res.state.best_param.unwrap(),
1.52138301273268,
epsilon = tol
);
}
}
23 changes: 23 additions & 0 deletions crates/argmin/src/solver/itp/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright 2018-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! # ITP method
//!
//! ## ItpRoot
//!
//! A root-finding algorithm, short for "interpolate, truncate, and project",
//! that achieves superlinear convergence while retaining the worst-case
//! performance of the bisection method.
//!
//! ### References
//!
//! [ITP Method]: https://en.wikipedia.org/wiki/ITP_Method
//! [An Enhancement of the Bisection Method Average Performance Preserving Minmax Optimality]: https://dl.acm.org/doi/10.1145/3423597

mod itp_method;

pub use itp_method::ItpRoot;
1 change: 1 addition & 0 deletions crates/argmin/src/solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub mod conjugategradient;
pub mod gaussnewton;
pub mod goldensectionsearch;
pub mod gradientdescent;
pub mod itp;
pub mod landweber;
pub mod linesearch;
pub mod neldermead;
Expand Down
Loading