Skip to content

Commit

Permalink
Fixed handling of residuals in GaussNewton
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-k committed Jan 17, 2024
1 parent 2a49597 commit 95d4732
Showing 1 changed file with 136 additions and 31 deletions.
167 changes: 136 additions & 31 deletions argmin/src/solver/gaussnewton/gaussnewton_method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,29 +124,50 @@ where
{
const NAME: &'static str = "Gauss-Newton method";

fn next_iter(
fn init(
&mut self,
problem: &mut Problem<O>,
state: IterState<P, (), J, (), R, F>,
mut state: IterState<P, (), J, (), R, F>,
) -> Result<(IterState<P, (), J, (), R, F>, Option<KV>), Error> {
let param = state.get_param().ok_or_else(argmin_error_closure!(
let init_param = state.take_param().ok_or_else(argmin_error_closure!(
NotInitialized,
concat!(
"`GaussNewton` requires an initial parameter vector. ",
"Please provide an initial guess via `Executor`s `configure` method."
)
))?;
let residuals = problem.apply(param)?;
let residuals = problem.apply(&init_param)?;
let cost = residuals.l2_norm();
Ok((
state.param(init_param).residuals(residuals).cost(cost),
None,
))
}

fn next_iter(
&mut self,
problem: &mut Problem<O>,
state: IterState<P, (), J, (), R, F>,
) -> Result<(IterState<P, (), J, (), R, F>, Option<KV>), Error> {
let param = state.get_param().ok_or_else(argmin_error_closure!(
PotentialBug,
"`GaussNewton`: `param` not set"
))?;
let residuals = state.get_residuals().ok_or_else(argmin_error_closure!(
PotentialBug,
"`GaussNewton`: `residuals` not set"
))?;
let jacobian = problem.jacobian(param)?;

let p = jacobian
.clone()
.t()
.dot(&jacobian)
.inv()?
.dot(&jacobian.t().dot(&residuals));
.dot(&jacobian.t().dot(residuals));

let new_param = param.sub(&p.mul(&self.gamma));
let residuals = problem.apply(&new_param)?;

let cost = residuals.l2_norm();

Expand Down Expand Up @@ -239,7 +260,7 @@ mod tests {

#[cfg(feature = "_ndarrayl")]
#[test]
fn test_next_iter_param_not_initialized() {
fn test_init_param_not_initialized() {
use ndarray::{Array, Array1, Array2};

struct TestProblem {}
Expand All @@ -263,7 +284,7 @@ mod tests {
}

let mut gn = GaussNewton::<f64>::new();
let res = gn.next_iter(&mut Problem::new(TestProblem {}), IterState::new());
let res = gn.init(&mut Problem::new(TestProblem {}), IterState::new());
assert_error!(
res,
ArgminError,
Expand All @@ -274,10 +295,90 @@ mod tests {
);
}

#[cfg(feature = "_ndarrayl")]
#[test]
fn test_next_iter_param_not_initialized() {
use ndarray::{Array, Array1, Array2};

struct TestProblem {}

impl Operator for TestProblem {
type Param = Array1<f64>;
type Output = Array1<f64>;

fn apply(&self, _p: &Self::Param) -> Result<Self::Output, Error> {
Ok(Array1::from_vec(vec![0.5, 2.0]))
}
}

impl Jacobian for TestProblem {
type Param = Array1<f64>;
type Jacobian = Array2<f64>;

fn jacobian(&self, _p: &Self::Param) -> Result<Self::Jacobian, Error> {
Ok(Array::from_shape_vec((2, 2), vec![1f64, 2.0, 3.0, 4.0])?)
}
}

let mut gn = GaussNewton::<f64>::new();
let res = gn.next_iter(&mut Problem::new(TestProblem {}), IterState::new());
assert_error!(
res,
ArgminError,
concat!(
"Potential bug: \"`GaussNewton`: ",
"`param` not set\". This is potentially a bug. ",
"Please file a report on https://github.com/argmin-rs/argmin/issues"
)
);
}

#[cfg(feature = "_ndarrayl")]
#[test]
fn test_next_iter_residual_not_initialized() {
use ndarray::{Array, Array1, Array2};

struct TestProblem {}

impl Operator for TestProblem {
type Param = Array1<f64>;
type Output = Array1<f64>;

fn apply(&self, _p: &Self::Param) -> Result<Self::Output, Error> {
Ok(Array1::from_vec(vec![0.5, 2.0]))
}
}

impl Jacobian for TestProblem {
type Param = Array1<f64>;
type Jacobian = Array2<f64>;

fn jacobian(&self, _p: &Self::Param) -> Result<Self::Jacobian, Error> {
Ok(Array::from_shape_vec((2, 2), vec![1f64, 2.0, 3.0, 4.0])?)
}
}

let mut gn = GaussNewton::<f64>::new();
let res = gn.next_iter(
&mut Problem::new(TestProblem {}),
IterState::new().param(vec![1f64, 2.0, 3.0, 4.0].into()),
);
assert_error!(
res,
ArgminError,
concat!(
"Potential bug: \"`GaussNewton`: ",
"`residuals` not set\". This is potentially a bug. ",
"Please file a report on https://github.com/argmin-rs/argmin/issues"
)
);
}

#[cfg(feature = "_ndarrayl")]
#[test]
fn test_solver() {
use crate::core::State;
use approx::assert_relative_eq;
use ndarray::{Array, Array1, Array2};
use std::cell::RefCell;

Expand Down Expand Up @@ -316,34 +417,36 @@ mod tests {
let solver: GaussNewton<f64> = GaussNewton::new();
let init_param = Array1::from_vec(vec![0.0, 0.0]);

let param = Executor::new(problem, solver)
let state = Executor::new(problem, solver)
.configure(|config| config.param(init_param).max_iters(1))
.run()
.unwrap()
.state
.get_best_param()
.unwrap()
.clone();
.state;
let param = state.get_best_param().unwrap().clone();
assert_relative_eq!(param[0], -1.0, epsilon = f64::EPSILON.sqrt());
assert_relative_eq!(param[1], 0.25, epsilon = f64::EPSILON.sqrt());

// Assert that cost matches residual:
assert_relative_eq!(state.get_residuals().unwrap().l2_norm(), state.get_cost());

// Two iterations, starting from [0, 0], gamma = 1
let problem = Problem {
counter: RefCell::new(0),
};
let solver: GaussNewton<f64> = GaussNewton::new();
let init_param = Array1::from_vec(vec![0.0, 0.0]);

let param = Executor::new(problem, solver)
let state = Executor::new(problem, solver)
.configure(|config| config.param(init_param).max_iters(2))
.run()
.unwrap()
.state
.get_best_param()
.unwrap()
.clone();
assert_relative_eq!(param[0], -1.4, epsilon = f64::EPSILON.sqrt());
assert_relative_eq!(param[1], 0.3, epsilon = f64::EPSILON.sqrt());
.state;
let param = state.get_best_param().unwrap().clone();
assert_relative_eq!(param[0], -1.0, epsilon = f64::EPSILON.sqrt());
assert_relative_eq!(param[1], 0.25, epsilon = f64::EPSILON.sqrt());

// Assert that cost matches residual:
assert_relative_eq!(state.get_residuals().unwrap().l2_norm(), state.get_cost());

// Single iteration, starting from [0, 0], gamma = 0.5
let problem = Problem {
Expand All @@ -352,33 +455,35 @@ mod tests {
let solver: GaussNewton<f64> = GaussNewton::new().with_gamma(0.5).unwrap();
let init_param = Array1::from_vec(vec![0.0, 0.0]);

let param = Executor::new(problem, solver)
let state = Executor::new(problem, solver)
.configure(|config| config.param(init_param).max_iters(1))
.run()
.unwrap()
.state
.get_best_param()
.unwrap()
.clone();
.state;
let param = state.get_best_param().unwrap().clone();
assert_relative_eq!(param[0], -0.5, epsilon = f64::EPSILON.sqrt());
assert_relative_eq!(param[1], 0.125, epsilon = f64::EPSILON.sqrt());

// Assert that cost matches residual:
assert_relative_eq!(state.get_residuals().unwrap().l2_norm(), state.get_cost());

// Two iterations, starting from [0, 0], gamma = 0.5
let problem = Problem {
counter: RefCell::new(0),
};
let solver: GaussNewton<f64> = GaussNewton::new().with_gamma(0.5).unwrap();
let init_param = Array1::from_vec(vec![0.0, 0.0]);

let param = Executor::new(problem, solver)
let state = Executor::new(problem, solver)
.configure(|config| config.param(init_param).max_iters(2))
.run()
.unwrap()
.state
.get_best_param()
.unwrap()
.clone();
assert_relative_eq!(param[0], -0.7, epsilon = f64::EPSILON.sqrt());
assert_relative_eq!(param[1], 0.15, epsilon = f64::EPSILON.sqrt());
.state;
let param = state.get_best_param().unwrap().clone();
assert_relative_eq!(param[0], -0.5, epsilon = f64::EPSILON.sqrt());
assert_relative_eq!(param[1], 0.125, epsilon = f64::EPSILON.sqrt());

// Assert that cost matches residual:
assert_relative_eq!(state.get_residuals().unwrap().l2_norm(), state.get_cost());
}
}

0 comments on commit 95d4732

Please sign in to comment.