From 7af653fbc4d97d9791610c650a303b7e54712a9c Mon Sep 17 00:00:00 2001 From: Stefan Kroboth Date: Wed, 17 Jan 2024 08:24:05 +0100 Subject: [PATCH] Added residuals handling to GaussNewton, fixed residuals related issues throughout the codebase --- argmin/src/core/checkpointing/file.rs | 5 +- argmin/src/core/executor.rs | 24 ++- argmin/src/core/observers/mod.rs | 8 +- argmin/src/core/result.rs | 12 +- argmin/src/core/solver.rs | 12 +- argmin/src/core/state/iterstate.rs | 203 ++++++++++++------ argmin/src/core/state/linearprogramstate.rs | 4 +- argmin/src/core/state/populationstate.rs | 4 +- argmin/src/solver/conjugategradient/cg.rs | 3 +- .../solver/conjugategradient/nonlinear_cg.rs | 2 +- .../solver/gaussnewton/gaussnewton_method.rs | 12 +- .../solver/gradientdescent/steepestdescent.rs | 2 +- argmin/src/solver/linesearch/backtracking.rs | 10 +- argmin/src/solver/neldermead/mod.rs | 10 +- argmin/src/solver/newton/newton_cg.rs | 11 +- argmin/src/solver/newton/newton_method.rs | 6 +- argmin/src/solver/quasinewton/bfgs.rs | 10 +- argmin/src/solver/quasinewton/dfp.rs | 10 +- argmin/src/solver/quasinewton/lbfgs.rs | 8 +- argmin/src/solver/quasinewton/sr1.rs | 10 +- .../src/solver/quasinewton/sr1_trustregion.rs | 10 +- argmin/src/solver/simulatedannealing/mod.rs | 4 +- argmin/src/solver/trustregion/cauchypoint.rs | 4 +- argmin/src/solver/trustregion/dogleg.rs | 4 +- argmin/src/solver/trustregion/steihaug.rs | 6 +- .../solver/trustregion/trustregion_method.rs | 4 +- media/book/src/implementing_solver.md | 6 +- 27 files changed, 246 insertions(+), 158 deletions(-) diff --git a/argmin/src/core/checkpointing/file.rs b/argmin/src/core/checkpointing/file.rs index 5ac02cff2..d5d4cebd3 100644 --- a/argmin/src/core/checkpointing/file.rs +++ b/argmin/src/core/checkpointing/file.rs @@ -181,11 +181,12 @@ mod tests { #[allow(clippy::type_complexity)] fn test_save() { let solver = TestSolver::new(); - let state: IterState, (), (), (), f64> = IterState::new().param(vec![1.0f64, 0.0]); + let state: IterState, (), (), (), (), f64> = + IterState::new().param(vec![1.0f64, 0.0]); let check = FileCheckpoint::new("checkpoints", "solver", CheckpointingFrequency::Always); check.save_cond(&solver, &state, 20).unwrap(); - let _loaded: Option<(TestSolver, IterState, (), (), (), f64>)> = + let _loaded: Option<(TestSolver, IterState, (), (), (), (), f64>)> = check.load().unwrap(); } } diff --git a/argmin/src/core/executor.rs b/argmin/src/core/executor.rs index d11767f71..843b80713 100644 --- a/argmin/src/core/executor.rs +++ b/argmin/src/core/executor.rs @@ -408,8 +408,9 @@ mod tests { let problem = TestProblem::new(); let solver = TestSolver::new(); - let mut executor = Executor::new(problem, solver) - .configure(|config: IterState, (), (), (), f64>| config.param(vec![0.0, 0.0])); + let mut executor = Executor::new(problem, solver).configure( + |config: IterState, (), (), (), (), f64>| config.param(vec![0.0, 0.0]), + ); // 1) Parameter vector changes, but not cost (continues to be `Inf`) let new_param = vec![1.0, 1.0]; @@ -492,8 +493,9 @@ mod tests { // 4) `-Inf` is better than `Inf` let solver = TestSolver {}; - let mut executor = Executor::new(problem, solver) - .configure(|config: IterState, (), (), (), f64>| config.param(vec![0.0, 0.0])); + let mut executor = Executor::new(problem, solver).configure( + |config: IterState, (), (), (), (), f64>| config.param(vec![0.0, 0.0]), + ); let new_param = vec![1.0, 1.0]; let new_cost = std::f64::NEG_INFINITY; @@ -584,7 +586,7 @@ mod tests { } // Implement Solver for OptimizationAlgorithm - impl Solver> for OptimizationAlgorithm + impl Solver> for OptimizationAlgorithm where O: CostFunction, P: Clone, @@ -596,8 +598,8 @@ mod tests { fn init( &mut self, _problem: &mut Problem, - state: IterState, - ) -> Result<(IterState, Option), Error> { + state: IterState, + ) -> Result<(IterState, Option), Error> { self.internal_state = 1; Ok((state, None)) } @@ -606,21 +608,21 @@ mod tests { fn next_iter( &mut self, _problem: &mut Problem, - state: IterState, - ) -> Result<(IterState, Option), Error> { + state: IterState, + ) -> Result<(IterState, Option), Error> { self.internal_state += 1; Ok((state, None)) } // Avoid terminating early because param does not change - fn terminate(&mut self, _state: &IterState) -> TerminationStatus { + fn terminate(&mut self, _state: &IterState) -> TerminationStatus { TerminationStatus::NotTerminated } // Avoid terminating early because param does not change fn terminate_internal( &mut self, - state: &IterState, + state: &IterState, ) -> TerminationStatus { if state.get_iter() >= state.get_max_iters() { TerminationStatus::Terminated(TerminationReason::MaxItersReached) diff --git a/argmin/src/core/observers/mod.rs b/argmin/src/core/observers/mod.rs index 9b26fd945..ae2487c5b 100644 --- a/argmin/src/core/observers/mod.rs +++ b/argmin/src/core/observers/mod.rs @@ -198,7 +198,7 @@ impl Observers { /// use argmin::core::observers::Observers; /// use argmin::core::IterState; /// - /// let observers: Observers, (), (), (), f64>> = Observers::new(); + /// let observers: Observers, (), (), (), (), f64>> = Observers::new(); /// # assert!(observers.is_empty()); /// ``` pub fn new() -> Self { @@ -214,7 +214,7 @@ impl Observers { /// use argmin_observer_slog::SlogLogger; /// use argmin::core::IterState; /// - /// let mut observers: Observers, (), (), (), f64>> = Observers::new(); + /// let mut observers: Observers, (), (), (), (), f64>> = Observers::new(); /// /// let logger = SlogLogger::term(); /// observers.push(logger, ObserverMode::Always); @@ -237,7 +237,7 @@ impl Observers { /// use argmin::core::observers::Observers; /// use argmin::core::IterState; /// - /// let observers: Observers, (), (), (), f64>> = Observers::new(); + /// let observers: Observers, (), (), (), (), f64>> = Observers::new(); /// assert!(observers.is_empty()); /// ``` pub fn is_empty(&self) -> bool { @@ -375,7 +375,7 @@ mod tests { let storages = [test_stor_1, test_stor_2, test_stor_3, test_stor_4]; - type TState = IterState, (), (), (), f64>; + type TState = IterState, (), (), (), (), f64>; let mut obs: Observers = Observers::new(); obs.push(test_obs_1, ObserverMode::Never) diff --git a/argmin/src/core/result.rs b/argmin/src/core/result.rs index 00cfbe391..95f5a7214 100644 --- a/argmin/src/core/result.rs +++ b/argmin/src/core/result.rs @@ -39,7 +39,7 @@ impl OptimizationResult { /// # struct SomeSolver {} /// # /// let rosenbrock = Rosenbrock::new(); - /// let state: IterState, (), (), (), f64> = IterState::new(); + /// let state: IterState, (), (), (), (), f64> = IterState::new(); /// let solver = SomeSolver {}; /// /// let result = OptimizationResult::new(Problem::new(rosenbrock), solver, state); @@ -65,7 +65,7 @@ impl OptimizationResult { /// # struct Rosenbrock {} /// # let solver = (); /// # - /// # let state: IterState, (), (), (), f64> = IterState::new(); + /// # let state: IterState, (), (), (), (), f64> = IterState::new(); /// # /// # let result = OptimizationResult::new(Problem::new(Rosenbrock {}), solver, state); /// # @@ -85,7 +85,7 @@ impl OptimizationResult { /// # struct Rosenbrock {} /// # let solver = (); /// # - /// # let state: IterState, (), (), (), f64> = IterState::new(); + /// # let state: IterState, (), (), (), (), f64> = IterState::new(); /// # /// # let result = OptimizationResult::new(Problem::new(Rosenbrock {}), solver, state); /// # @@ -105,11 +105,11 @@ impl OptimizationResult { /// # struct Rosenbrock {} /// # let solver = (); /// # - /// # let state: IterState, (), (), (), f64> = IterState::new(); + /// # let state: IterState, (), (), (), (), f64> = IterState::new(); /// # /// # let result = OptimizationResult::new(Problem::new(Rosenbrock {}), solver, state); /// # - /// let state: &IterState, (), (), (), f64> = result.state(); + /// let state: &IterState, (), (), (), (), f64> = result.state(); /// ``` pub fn state(&self) -> &I { &self.state @@ -199,7 +199,7 @@ mod tests { send_sync_test!( optimizationresult, - OptimizationResult> + OptimizationResult> ); // TODO: More tests, in particular the checking that the output is as intended. diff --git a/argmin/src/core/solver.rs b/argmin/src/core/solver.rs index a81514ec7..a09915830 100644 --- a/argmin/src/core/solver.rs +++ b/argmin/src/core/solver.rs @@ -30,7 +30,7 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K /// #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] /// struct OptimizationAlgorithm {} /// -/// impl Solver> for OptimizationAlgorithm +/// impl Solver> for OptimizationAlgorithm /// where /// O: CostFunction, /// P: Clone, @@ -41,8 +41,8 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K /// fn init( /// &mut self, /// problem: &mut Problem, -/// state: IterState, -/// ) -> Result<(IterState, Option), Error> { +/// state: IterState, +/// ) -> Result<(IterState, Option), Error> { /// // Initialize algorithm, update `state`. /// // Implementing this method is optional. /// Ok((state, None)) @@ -51,14 +51,14 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K /// fn next_iter( /// &mut self, /// problem: &mut Problem, -/// state: IterState, -/// ) -> Result<(IterState, Option), Error> { +/// state: IterState, +/// ) -> Result<(IterState, Option), Error> { /// // Compute single iteration of algorithm, update `state`. /// // Implementing this method is required. /// Ok((state, None)) /// } /// -/// fn terminate(&mut self, state: &IterState) -> TerminationStatus { +/// fn terminate(&mut self, state: &IterState) -> TerminationStatus { /// // Check if stopping criteria are met. /// // Implementing this method is optional. /// TerminationStatus::NotTerminated diff --git a/argmin/src/core/state/iterstate.rs b/argmin/src/core/state/iterstate.rs index 5b640446d..1fb7560e1 100644 --- a/argmin/src/core/state/iterstate.rs +++ b/argmin/src/core/state/iterstate.rs @@ -72,6 +72,8 @@ pub struct IterState { pub prev_jacobian: Option, /// Value of residuals from recent call to apply pub residuals: Option, + /// Value of residuals from previous call to apply + pub prev_residuals: Option, /// Current iteration pub iter: u64, /// Iteration number of last best cost @@ -98,7 +100,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State}; - /// # let state: IterState, (), (), (), f64, ()> = IterState::new(); + /// # let state: IterState, (), (), (), (), f64> = IterState::new(); /// # let param_old = vec![1.0f64, 2.0f64]; /// # let state = state.param(param_old); /// # assert!(state.prev_param.is_none()); @@ -124,7 +126,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State}; - /// # let state: IterState<(), Vec, (), (), f64, ()> = IterState::new(); + /// # let state: IterState<(), Vec, (), (), (), f64> = IterState::new(); /// # let grad_old = vec![1.0f64, 2.0f64]; /// # let state = state.gradient(grad_old); /// # assert!(state.prev_grad.is_none()); @@ -150,7 +152,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State}; - /// # let state: IterState<(), (), (), Vec, f64, ()> = IterState::new(); + /// # let state: IterState<(), (), (), Vec, (), f64> = IterState::new(); /// # let hessian_old = vec![1.0f64, 2.0f64]; /// # let state = state.hessian(hessian_old); /// # assert!(state.prev_hessian.is_none()); @@ -176,7 +178,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State}; - /// # let state: IterState<(), (), (), Vec, f64, ()>> = IterState::new(); + /// # let state: IterState<(), (), (), Vec, (), f64> = IterState::new(); /// # let inv_hessian_old = vec![1.0f64, 2.0f64]; /// # let state = state.inv_hessian(inv_hessian_old); /// # assert!(state.prev_inv_hessian.is_none()); @@ -202,7 +204,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State}; - /// # let state: IterState<(), (), Vec, (), f64, ()> = IterState::new(); + /// # let state: IterState<(), (), Vec, (), (), f64> = IterState::new(); /// # let jacobian_old = vec![1.0f64, 2.0f64]; /// # let state = state.jacobian(jacobian_old); /// # assert!(state.prev_jacobian.is_none()); @@ -229,7 +231,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State}; - /// # let state: IterState<(), (), Vec, (), f64, ()> = IterState::new(); + /// # let state: IterState<(), (), Vec, (), (), f64> = IterState::new(); /// # let cost_old = 1.0f64; /// # let state = state.cost(cost_old); /// # assert_eq!(state.prev_cost.to_ne_bytes(), f64::INFINITY.to_ne_bytes()); @@ -255,7 +257,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let state: IterState, (), (), (), f64, ()> = IterState::new(); + /// # let state: IterState, (), (), (), (), f64> = IterState::new(); /// # assert_eq!(state.target_cost.to_ne_bytes(), f64::NEG_INFINITY.to_ne_bytes()); /// let state = state.target_cost(0.0); /// # assert_eq!(state.target_cost.to_ne_bytes(), 0.0f64.to_ne_bytes()); @@ -272,7 +274,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let state: IterState, (), (), (), f64, ()> = IterState::new(); + /// # let state: IterState, (), (), (), (), f64> = IterState::new(); /// # assert_eq!(state.max_iters, std::u64::MAX); /// let state = state.max_iters(1000); /// # assert_eq!(state.max_iters, 1000); @@ -283,13 +285,39 @@ where self } + /// Set residuals. This shifts the stored residuals to the previous residuals. + /// + /// # Example + /// + /// ``` + /// # use argmin::core::{IterState, State}; + /// # let state: IterState<(), (), (), (), Vec, f64> = IterState::new(); + /// # let residuals_old = vec![1.0f64, 2.0f64]; + /// # let state = state.residuals(residuals_old); + /// # assert!(state.prev_residuals.is_none()); + /// # assert_eq!(state.residuals.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); + /// # assert_eq!(state.residuals.as_ref().unwrap()[1].to_ne_bytes(), 2.0f64.to_ne_bytes()); + /// # let residuals = vec![0.0f64, 3.0f64]; + /// let state = state.residuals(residuals); + /// # assert_eq!(state.prev_residuals.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); + /// # assert_eq!(state.prev_residuals.as_ref().unwrap()[1].to_ne_bytes(), 2.0f64.to_ne_bytes()); + /// # assert_eq!(state.residuals.as_ref().unwrap()[0].to_ne_bytes(), 0.0f64.to_ne_bytes()); + /// # assert_eq!(state.residuals.as_ref().unwrap()[1].to_ne_bytes(), 3.0f64.to_ne_bytes()); + /// ``` + #[must_use] + pub fn residuals(mut self, residuals: R) -> Self { + std::mem::swap(&mut self.prev_residuals, &mut self.residuals); + self.residuals = Some(residuals); + self + } + /// Returns the current cost function value /// /// # Example /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let state: IterState, (), (), (), f64, ()> = IterState::new(); + /// # let state: IterState, (), (), (), (), f64> = IterState::new(); /// # let state = state.cost(2.0); /// let cost = state.get_cost(); /// # assert_eq!(cost.to_ne_bytes(), 2.0f64.to_ne_bytes()); @@ -304,7 +332,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64, ()> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # state.prev_cost = 2.0; /// let prev_cost = state.get_prev_cost(); /// # assert_eq!(prev_cost.to_ne_bytes(), 2.0f64.to_ne_bytes()); @@ -319,7 +347,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64, ()> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # state.best_cost = 2.0; /// let best_cost = state.get_best_cost(); /// # assert_eq!(best_cost.to_ne_bytes(), 2.0f64.to_ne_bytes()); @@ -334,7 +362,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64, ()> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # state.prev_best_cost = 2.0; /// let prev_best_cost = state.get_prev_best_cost(); /// # assert_eq!(prev_best_cost.to_ne_bytes(), 2.0f64.to_ne_bytes()); @@ -349,7 +377,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64, ()> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # assert_eq!(state.target_cost.to_ne_bytes(), std::f64::NEG_INFINITY.to_ne_bytes()); /// # state.target_cost = 0.0; /// let target_cost = state.get_target_cost(); @@ -365,7 +393,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64, ()> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # assert!(state.take_param().is_none()); /// # let mut state = state.param(vec![1.0, 2.0]); /// # assert_eq!(state.param.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); @@ -386,7 +414,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64, ()> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # assert!(state.prev_param.is_none()); /// # state.prev_param = Some(vec![1.0, 2.0]); /// # assert_eq!(state.prev_param.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); @@ -405,7 +433,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64, ()> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # assert!(state.take_prev_param().is_none()); /// # state.prev_param = Some(vec![1.0, 2.0]); /// # assert_eq!(state.prev_param.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); @@ -426,7 +454,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64, ()> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # assert!(state.prev_best_param.is_none()); /// # state.prev_best_param = Some(vec![1.0, 2.0]); /// # assert_eq!(state.prev_best_param.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); @@ -445,7 +473,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64, ()> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # assert!(state.take_best_param().is_none()); /// # state.best_param = Some(vec![1.0, 2.0]); /// # assert_eq!(state.best_param.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); @@ -466,7 +494,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64, ()> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # assert!(state.take_prev_best_param().is_none()); /// # state.prev_best_param = Some(vec![1.0, 2.0]); /// # assert_eq!(state.prev_best_param.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); @@ -487,7 +515,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), Vec, (), (), f64, ()> = IterState::new(); + /// # let mut state: IterState<(), Vec, (), (), (), f64> = IterState::new(); /// # assert!(state.grad.is_none()); /// # assert!(state.get_gradient().is_none()); /// # state.grad = Some(vec![1.0, 2.0]); @@ -507,7 +535,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), Vec, (), (), f64, ()> = IterState::new(); + /// # let mut state: IterState<(), Vec, (), (), (), f64> = IterState::new(); /// # assert!(state.take_gradient().is_none()); /// # state.grad = Some(vec![1.0, 2.0]); /// # assert_eq!(state.grad.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); @@ -528,7 +556,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), Vec, (), (), f64, ()> = IterState::new(); + /// # let mut state: IterState<(), Vec, (), (), (), f64> = IterState::new(); /// # assert!(state.prev_grad.is_none()); /// # assert!(state.get_prev_gradient().is_none()); /// # state.prev_grad = Some(vec![1.0, 2.0]); @@ -548,7 +576,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), Vec, (), (), f64, ()> = IterState::new(); + /// # let mut state: IterState<(), Vec, (), (), (), f64> = IterState::new(); /// # assert!(state.take_prev_gradient().is_none()); /// # state.prev_grad = Some(vec![1.0, 2.0]); /// # assert_eq!(state.prev_grad.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); @@ -569,7 +597,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), (), (), Vec>, f64, ()> = IterState::new(); + /// # let mut state: IterState<(), (), (), Vec>, (), f64> = IterState::new(); /// # assert!(state.hessian.is_none()); /// # assert!(state.get_hessian().is_none()); /// # state.hessian = Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); @@ -593,7 +621,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), (), (), Vec>, f64, ()> = IterState::new(); + /// # let mut state: IterState<(), (), (), Vec>, (), f64> = IterState::new(); /// # assert!(state.hessian.is_none()); /// # assert!(state.take_hessian().is_none()); /// # state.hessian = Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); @@ -619,7 +647,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), (), (), Vec>, f64, ()> = IterState::new(); + /// # let mut state: IterState<(), (), (), Vec>, (), f64> = IterState::new(); /// # assert!(state.prev_hessian.is_none()); /// # assert!(state.get_prev_hessian().is_none()); /// # state.prev_hessian = Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); @@ -643,7 +671,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), (), (), Vec>, f64, ()> = IterState::new(); + /// # let mut state: IterState<(), (), (), Vec>, (), f64> = IterState::new(); /// # assert!(state.prev_hessian.is_none()); /// # assert!(state.take_prev_hessian().is_none()); /// # state.prev_hessian = Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); @@ -669,7 +697,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), (), (), Vec>, f64, ()> = IterState::new(); + /// # let mut state: IterState<(), (), (), Vec>, (), f64> = IterState::new(); /// # assert!(state.inv_hessian.is_none()); /// # assert!(state.get_inv_hessian().is_none()); /// # state.inv_hessian = Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); @@ -693,7 +721,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), (), (), Vec>, f64, ()> = IterState::new(); + /// # let mut state: IterState<(), (), (), Vec>, (), f64> = IterState::new(); /// # assert!(state.inv_hessian.is_none()); /// # assert!(state.take_inv_hessian().is_none()); /// # state.inv_hessian = Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); @@ -719,7 +747,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), (), (), Vec>, f64, ()> = IterState::new(); + /// # let mut state: IterState<(), (), (), Vec>, (), f64> = IterState::new(); /// # assert!(state.prev_inv_hessian.is_none()); /// # assert!(state.get_prev_inv_hessian().is_none()); /// # state.prev_inv_hessian = Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); @@ -743,7 +771,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), (), (), Vec>, f64, ()> = IterState::new(); + /// # let mut state: IterState<(), (), (), Vec>, (), f64> = IterState::new(); /// # assert!(state.prev_inv_hessian.is_none()); /// # assert!(state.take_prev_inv_hessian().is_none()); /// # state.prev_inv_hessian = Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); @@ -769,7 +797,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), (), Vec>, (), f64, ()> = IterState::new(); + /// # let mut state: IterState<(), (), Vec>, (), (), f64> = IterState::new(); /// # assert!(state.jacobian.is_none()); /// # assert!(state.get_jacobian().is_none()); /// # state.jacobian = Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); @@ -793,7 +821,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), (), Vec>, (), f64, ()> = IterState::new(); + /// # let mut state: IterState<(), (), Vec>, (), (), f64> = IterState::new(); /// # assert!(state.jacobian.is_none()); /// # assert!(state.take_jacobian().is_none()); /// # state.jacobian = Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); @@ -819,7 +847,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), (), Vec>, (), f64, ()> = IterState::new(); + /// # let mut state: IterState<(), (), Vec>, (), (), f64> = IterState::new(); /// # assert!(state.prev_jacobian.is_none()); /// # assert!(state.get_prev_jacobian().is_none()); /// # state.prev_jacobian = Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); @@ -843,7 +871,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), (), Vec>, (), f64, ()> = IterState::new(); + /// # let mut state: IterState<(), (), Vec>, (), (), f64> = IterState::new(); /// # assert!(state.prev_jacobian.is_none()); /// # assert!(state.take_prev_jacobian().is_none()); /// # state.prev_jacobian = Some(vec![vec![1.0, 2.0], vec![3.0, 4.0]]); @@ -863,13 +891,13 @@ where self.prev_jacobian.take() } - /// Returns a reference to previous parameter vector + /// Returns a reference to the residuals /// /// # Example /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), (), (), (), f64, Vec> = IterState::new(); + /// # let mut state: IterState<(), (), (), (), Vec, f64> = IterState::new(); /// # assert!(state.residuals.is_none()); /// # state.residuals = Some(vec![1.0, 2.0]); /// # assert_eq!(state.residuals.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); @@ -882,13 +910,13 @@ where self.residuals.as_ref() } - /// Moves the previous parameter vector out and replaces it internally with `None` + /// Moves the residuals out and replaces it internally with `None` /// /// # Example /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState<(), (), (), (), f64, Vec> = IterState::new(); + /// # let mut state: IterState<(), (), (), (), Vec, f64> = IterState::new(); /// # assert!(state.take_residuals().is_none()); /// # state.residuals = Some(vec![1.0, 2.0]); /// # assert_eq!(state.residuals.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); @@ -902,6 +930,46 @@ where pub fn take_residuals(&mut self) -> Option { self.residuals.take() } + + /// Returns a reference to the previous residuals + /// + /// # Example + /// + /// ``` + /// # use argmin::core::{IterState, State, ArgminFloat}; + /// # let mut state: IterState<(), (), (), (), Vec, f64> = IterState::new(); + /// # assert!(state.residuals.is_none()); + /// # state.residuals = Some(vec![1.0, 2.0]); + /// # assert_eq!(state.residuals.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); + /// # assert_eq!(state.residuals.as_ref().unwrap()[1].to_ne_bytes(), 2.0f64.to_ne_bytes()); + /// let residuals = state.get_residuals(); // Option<&R> + /// # assert_eq!(residuals.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); + /// # assert_eq!(residuals.as_ref().unwrap()[1].to_ne_bytes(), 2.0f64.to_ne_bytes()); + /// ``` + pub fn get_prev_residuals(&self) -> Option<&R> { + self.prev_residuals.as_ref() + } + + /// Moves the previous residuals out and replaces it internally with `None` + /// + /// # Example + /// + /// ``` + /// # use argmin::core::{IterState, State, ArgminFloat}; + /// # let mut state: IterState<(), (), (), (), Vec, f64> = IterState::new(); + /// # assert!(state.take_residuals().is_none()); + /// # state.residuals = Some(vec![1.0, 2.0]); + /// # assert_eq!(state.residuals.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); + /// # assert_eq!(state.residuals.as_ref().unwrap()[1].to_ne_bytes(), 2.0f64.to_ne_bytes()); + /// let residuals = state.take_residuals(); // Option + /// # assert!(state.take_residuals().is_none()); + /// # assert!(state.residuals.is_none()); + /// # assert_eq!(residuals.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); + /// # assert_eq!(residuals.as_ref().unwrap()[1].to_ne_bytes(), 2.0f64.to_ne_bytes()); + /// ``` + pub fn take_prev_residuals(&mut self) -> Option { + self.prev_residuals.take() + } } impl State for IterState @@ -922,7 +990,7 @@ where /// # extern crate instant; /// # use instant; /// # use argmin::core::{IterState, State, ArgminFloat, TerminationStatus}; - /// let state: IterState, Vec, Vec>, Vec>, f64> = IterState::new(); + /// let state: IterState, Vec, Vec>, Vec>, Vec, f64> = IterState::new(); /// # assert!(state.param.is_none()); /// # assert!(state.prev_param.is_none()); /// # assert!(state.best_param.is_none()); @@ -967,6 +1035,7 @@ where jacobian: None, prev_jacobian: None, residuals: None, + prev_residuals: None, iter: 0, last_best_iter: 0, max_iters: std::u64::MAX, @@ -983,7 +1052,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// let mut state: IterState, (), (), (), f64> = IterState::new(); + /// let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// /// // Simulating a new, better parameter vector /// state.best_param = Some(vec![1.0f64]); @@ -1005,7 +1074,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// let mut state: IterState, (), (), (), f64> = IterState::new(); + /// let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// /// // Simulating a new, better parameter vector /// state.best_param = Some(vec![1.0f64]); @@ -1048,7 +1117,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # assert!(state.param.is_none()); /// # state.param = Some(vec![1.0, 2.0]); /// # assert_eq!(state.param.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); @@ -1067,7 +1136,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # assert!(state.best_param.is_none()); /// # state.best_param = Some(vec![1.0, 2.0]); /// # assert_eq!(state.best_param.as_ref().unwrap()[0].to_ne_bytes(), 1.0f64.to_ne_bytes()); @@ -1086,7 +1155,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat, TerminationReason, TerminationStatus}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # assert_eq!(state.termination_status, TerminationStatus::NotTerminated); /// let state = state.terminate_with(TerminationReason::MaxItersReached); /// # assert_eq!(state.termination_status, TerminationStatus::Terminated(TerminationReason::MaxItersReached)); @@ -1104,7 +1173,7 @@ where /// # extern crate instant; /// # use instant; /// # use argmin::core::{IterState, State, ArgminFloat, TerminationReason}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// let state = state.time(Some(instant::Duration::new(0, 12))); /// # assert_eq!(state.time.unwrap(), instant::Duration::new(0, 12)); /// ``` @@ -1119,7 +1188,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # state.cost = 12.0; /// let cost = state.get_cost(); /// # assert_eq!(cost.to_ne_bytes(), 12.0f64.to_ne_bytes()); @@ -1134,7 +1203,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # state.best_cost = 12.0; /// let best_cost = state.get_best_cost(); /// # assert_eq!(best_cost.to_ne_bytes(), 12.0f64.to_ne_bytes()); @@ -1149,7 +1218,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # state.target_cost = 12.0; /// let target_cost = state.get_target_cost(); /// # assert_eq!(target_cost.to_ne_bytes(), 12.0f64.to_ne_bytes()); @@ -1164,7 +1233,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # state.iter = 12; /// let iter = state.get_iter(); /// # assert_eq!(iter, 12); @@ -1179,7 +1248,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # state.last_best_iter = 12; /// let last_best_iter = state.get_last_best_iter(); /// # assert_eq!(last_best_iter, 12); @@ -1194,7 +1263,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # state.max_iters = 12; /// let max_iters = state.get_max_iters(); /// # assert_eq!(max_iters, 12); @@ -1209,7 +1278,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat, TerminationStatus}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// let termination_status = state.get_termination_status(); /// # assert_eq!(*termination_status, TerminationStatus::NotTerminated); /// ``` @@ -1223,7 +1292,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat, TerminationReason}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// let termination_reason = state.get_termination_reason(); /// # assert_eq!(termination_reason, None); /// ``` @@ -1242,7 +1311,7 @@ where /// # extern crate instant; /// # use instant; /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// let time = state.get_time(); /// # assert_eq!(time.unwrap(), instant::Duration::new(0, 0)); /// ``` @@ -1256,7 +1325,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # assert_eq!(state.iter, 0); /// state.increment_iter(); /// # assert_eq!(state.iter, 1); @@ -1270,7 +1339,7 @@ where /// ``` /// # use std::collections::HashMap; /// # use argmin::core::{Problem, IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # assert_eq!(state.counts, HashMap::new()); /// # state.counts.insert("test2".to_string(), 10u64); /// # @@ -1300,7 +1369,7 @@ where /// ``` /// # use std::collections::HashMap; /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # assert_eq!(state.counts, HashMap::new()); /// # state.counts.insert("test2".to_string(), 10u64); /// let counts = state.get_func_counts(); @@ -1319,7 +1388,7 @@ where /// /// ``` /// # use argmin::core::{IterState, State, ArgminFloat}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # let mut state: IterState, (), (), (), (), f64> = IterState::new(); /// # state.last_best_iter = 12; /// # state.iter = 12; /// let is_best = state.is_best(); @@ -1342,9 +1411,10 @@ mod tests { #[allow(clippy::type_complexity)] fn test_iterstate() { let param = vec![1.0f64, 2.0]; + let residuals = vec![1.0f64, 2.0]; let cost: f64 = 42.0; - let mut state: IterState, Vec, Vec, Vec>, f64> = + let mut state: IterState, Vec, Vec, Vec>, Vec, f64> = IterState::new(); assert!(state.get_param().is_none()); @@ -1382,10 +1452,14 @@ mod tests { assert!(state.get_prev_inv_hessian().is_none()); assert!(state.get_jacobian().is_none()); assert!(state.get_prev_jacobian().is_none()); + assert!(state.get_residuals().is_none()); + assert!(state.get_prev_residuals().is_none()); assert_eq!(state.get_iter(), 0); assert!(state.is_best()); + state = state.residuals(param.clone()); + assert_eq!(state.get_max_iters(), std::u64::MAX); let func_counts = state.get_func_counts().clone(); assert!(!func_counts.contains_key("cost_count")); @@ -1412,6 +1486,13 @@ mod tests { assert_eq!(*state.get_param().unwrap(), new_param); assert_eq!(*state.get_prev_param().unwrap(), param); + let new_residuals = vec![2.0, 1.0]; + + state = state.residuals(new_residuals.clone()); + + assert_eq!(*state.get_residuals().unwrap(), new_residuals); + assert_eq!(*state.get_prev_residuals().unwrap(), residuals); + let new_cost: f64 = 21.0; let mut state = state.cost(new_cost); @@ -1516,6 +1597,8 @@ mod tests { assert_eq!(state.take_prev_inv_hessian().unwrap(), inv_hessian); assert_eq!(state.take_jacobian().unwrap(), new_jacobian); assert_eq!(state.take_prev_jacobian().unwrap(), jacobian); + assert_eq!(*state.get_residuals().unwrap(), new_residuals); + assert_eq!(*state.get_prev_residuals().unwrap(), residuals); let func_counts = state.get_func_counts().clone(); assert!(!func_counts.contains_key("cost_count")); assert!(!func_counts.contains_key("operator_count")); diff --git a/argmin/src/core/state/linearprogramstate.rs b/argmin/src/core/state/linearprogramstate.rs index e9b005ba0..c6593b863 100644 --- a/argmin/src/core/state/linearprogramstate.rs +++ b/argmin/src/core/state/linearprogramstate.rs @@ -456,8 +456,8 @@ where /// # Example /// /// ``` - /// # use argmin::core::{IterState, State, ArgminFloat, TerminationReason}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # use argmin::core::{LinearProgramState, State, ArgminFloat, TerminationReason}; + /// # let mut state: LinearProgramState, f64> = LinearProgramState::new(); /// let termination_reason = state.get_termination_reason(); /// # assert_eq!(termination_reason, None); /// ``` diff --git a/argmin/src/core/state/populationstate.rs b/argmin/src/core/state/populationstate.rs index b3f57ebe8..28525a23e 100644 --- a/argmin/src/core/state/populationstate.rs +++ b/argmin/src/core/state/populationstate.rs @@ -735,8 +735,8 @@ where /// # Example /// /// ``` - /// # use argmin::core::{IterState, State, ArgminFloat, TerminationReason}; - /// # let mut state: IterState, (), (), (), f64> = IterState::new(); + /// # use argmin::core::{PopulationState, State, ArgminFloat, TerminationReason}; + /// # let mut state: PopulationState, f64> = PopulationState::new(); /// let termination_reason = state.get_termination_reason(); /// # assert_eq!(termination_reason, None); /// ``` diff --git a/argmin/src/solver/conjugategradient/cg.rs b/argmin/src/solver/conjugategradient/cg.rs index 79470bbd1..2822c6666 100644 --- a/argmin/src/solver/conjugategradient/cg.rs +++ b/argmin/src/solver/conjugategradient/cg.rs @@ -229,7 +229,8 @@ mod tests { #[test] fn test_init() { let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]); - let state: IterState, (), (), (), f64> = IterState::new().param(vec![3.0, 4.0]); + let state: IterState, (), (), (), (), f64> = + IterState::new().param(vec![3.0, 4.0]); let (state_out, kv) = cg .init(&mut Problem::new(TestProblem::new()), state.clone()) .unwrap(); diff --git a/argmin/src/solver/conjugategradient/nonlinear_cg.rs b/argmin/src/solver/conjugategradient/nonlinear_cg.rs index d42278f91..c0bc5f686 100644 --- a/argmin/src/solver/conjugategradient/nonlinear_cg.rs +++ b/argmin/src/solver/conjugategradient/nonlinear_cg.rs @@ -334,7 +334,7 @@ mod tests { let beta_method = PolakRibiere::new(); let mut nlcg: NonlinearConjugateGradient, _, _, f64> = NonlinearConjugateGradient::new(linesearch, beta_method); - let state: IterState, Vec, (), (), f64> = + let state: IterState, Vec, (), (), (), f64> = IterState::new().param(vec![3.0, 4.0]); let (state_out, kv) = nlcg .init(&mut Problem::new(TestProblem::new()), state.clone()) diff --git a/argmin/src/solver/gaussnewton/gaussnewton_method.rs b/argmin/src/solver/gaussnewton/gaussnewton_method.rs index e0f0ead8b..5b7b9b246 100644 --- a/argmin/src/solver/gaussnewton/gaussnewton_method.rs +++ b/argmin/src/solver/gaussnewton/gaussnewton_method.rs @@ -109,16 +109,16 @@ impl Default for GaussNewton { } } -impl Solver> for GaussNewton +impl Solver> for GaussNewton where - O: Operator + Jacobian, + O: Operator + Jacobian, P: Clone + ArgminSub + ArgminMul, - U: ArgminL2Norm, + R: ArgminL2Norm, J: Clone + ArgminTranspose + ArgminInv + ArgminDot - + ArgminDot + + ArgminDot + ArgminDot, F: ArgminFloat, { @@ -148,7 +148,9 @@ where let new_param = param.sub(&p.mul(&self.gamma)); - Ok((state.param(new_param).cost(residuals.l2_norm()), None)) + let cost = residuals.l2_norm(); + + Ok((state.param(new_param).residuals(residuals).cost(cost), None)) } fn terminate(&mut self, state: &IterState) -> TerminationStatus { diff --git a/argmin/src/solver/gradientdescent/steepestdescent.rs b/argmin/src/solver/gradientdescent/steepestdescent.rs index 735c1095d..06f90952e 100644 --- a/argmin/src/solver/gradientdescent/steepestdescent.rs +++ b/argmin/src/solver/gradientdescent/steepestdescent.rs @@ -55,7 +55,7 @@ where O: CostFunction + Gradient, P: Clone + SerializeAlias + DeserializeOwnedAlias, G: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminMul, - L: Clone + LineSearch + Solver>, + L: Clone + LineSearch + Solver>, F: ArgminFloat, { const NAME: &'static str = "Steepest Descent"; diff --git a/argmin/src/solver/linesearch/backtracking.rs b/argmin/src/solver/linesearch/backtracking.rs index 8bad3e5d9..8562adcf2 100644 --- a/argmin/src/solver/linesearch/backtracking.rs +++ b/argmin/src/solver/linesearch/backtracking.rs @@ -130,7 +130,7 @@ impl BacktrackingLineSearch where P: ArgminScaledAdd, L: LineSearchCondition, - IterState: State, + IterState: State, F: ArgminFloat, { /// Perform a single backtracking step @@ -591,10 +591,10 @@ mod tests { assert_eq!( , Vec, ArmijoCondition, f64> as Solver< TestProblem, - IterState, Vec, (), (), f64>, + IterState, Vec, (), (), (), f64>, >>::terminate( &mut ls, - &IterState::, Vec, (), (), f64>::new().param(init_param) + &IterState::, Vec, (), (), (), f64>::new().param(init_param) ), TerminationStatus::Terminated(TerminationReason::SolverConverged) ); @@ -605,10 +605,10 @@ mod tests { assert_eq!( , Vec, ArmijoCondition, f64> as Solver< TestProblem, - IterState, Vec, (), (), f64>, + IterState, Vec, (), (), (), f64>, >>::terminate( &mut ls, - &IterState::, Vec, (), (), f64>::new().param(init_param) + &IterState::, Vec, (), (), (), f64>::new().param(init_param) ), TerminationStatus::NotTerminated ); diff --git a/argmin/src/solver/neldermead/mod.rs b/argmin/src/solver/neldermead/mod.rs index f5f95a0da..f530cbd9e 100644 --- a/argmin/src/solver/neldermead/mod.rs +++ b/argmin/src/solver/neldermead/mod.rs @@ -690,7 +690,7 @@ mod tests { (vec![-0.5, 2.0], 0.5f64.powi(2) + 2.0f64.powi(2)), ]; let mut nm: NelderMead<_, f64> = NelderMead::new(params); - let state: IterState, (), (), (), f64> = IterState::new(); + let state: IterState, (), (), (), (), f64> = IterState::new(); let problem = MwProblem {}; let (state_out, kv) = nm.init(&mut Problem::new(problem), state).unwrap(); @@ -721,7 +721,7 @@ mod tests { fn test_next_iter_reflection() { let params: Vec> = vec![vec![-1.0, 0.0], vec![-0.1, 0.65], vec![-0.1, -0.95]]; let mut nm: NelderMead<_, f64> = NelderMead::new(params); - let state: IterState, (), (), (), f64> = IterState::new(); + let state: IterState, (), (), (), (), f64> = IterState::new(); let mut problem = Problem::new(MwProblem {}); let (state, _) = nm.init(&mut problem, state).unwrap(); @@ -764,7 +764,7 @@ mod tests { vec![-1.0, -1.0 - f64::EPSILON], ]; let mut nm: NelderMead<_, f64> = NelderMead::new(params); - let state: IterState, (), (), (), f64> = IterState::new(); + let state: IterState, (), (), (), (), f64> = IterState::new(); let mut problem = Problem::new(MwProblem {}); let (state, _) = nm.init(&mut problem, state).unwrap(); @@ -800,7 +800,7 @@ mod tests { fn test_next_iter_contraction_outside() { let params: Vec> = vec![vec![-1.1, 0.0], vec![-0.1, 1.0], vec![-0.1, -0.5]]; let mut nm: NelderMead<_, f64> = NelderMead::new(params); - let state: IterState, (), (), (), f64> = IterState::new(); + let state: IterState, (), (), (), (), f64> = IterState::new(); let mut problem = Problem::new(MwProblem {}); let (state, _) = nm.init(&mut problem, state).unwrap(); @@ -836,7 +836,7 @@ mod tests { fn test_next_iter_contraction_inside() { let params: Vec> = vec![vec![-1.0, 0.0], vec![0.0, 1.0], vec![0.0, -0.5]]; let mut nm: NelderMead<_, f64> = NelderMead::new(params); - let state: IterState, (), (), (), f64> = IterState::new(); + let state: IterState, (), (), (), (), f64> = IterState::new(); let mut problem = Problem::new(MwProblem {}); let (state, _) = nm.init(&mut problem, state).unwrap(); diff --git a/argmin/src/solver/newton/newton_cg.rs b/argmin/src/solver/newton/newton_cg.rs index 5d9345003..c747028d1 100644 --- a/argmin/src/solver/newton/newton_cg.rs +++ b/argmin/src/solver/newton/newton_cg.rs @@ -106,7 +106,7 @@ where } } -impl Solver> for NewtonCG +impl Solver> for NewtonCG where O: Gradient + Hessian, P: Clone @@ -120,17 +120,16 @@ where + ArgminZeroLike, G: SerializeAlias + DeserializeOwnedAlias + ArgminL2Norm + ArgminMul, H: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminDot, - L: Clone + LineSearch + Solver>, + L: Clone + LineSearch + Solver>, F: ArgminFloat + ArgminL2Norm, - R: Clone + SerializeAlias + DeserializeOwnedAlias, { const NAME: &'static str = "Newton-CG"; fn next_iter( &mut self, problem: &mut Problem, - mut state: IterState, - ) -> Result<(IterState, Option), Error> { + mut state: IterState, + ) -> Result<(IterState, Option), Error> { let param = state.take_param().ok_or_else(argmin_error_closure!( NotInitialized, concat!( @@ -211,7 +210,7 @@ where )) } - fn terminate(&mut self, state: &IterState) -> TerminationStatus { + fn terminate(&mut self, state: &IterState) -> TerminationStatus { if (state.get_cost() - state.get_prev_cost()).abs() < self.tol { TerminationStatus::Terminated(TerminationReason::SolverConverged) } else { diff --git a/argmin/src/solver/newton/newton_method.rs b/argmin/src/solver/newton/newton_method.rs index bd495cb8b..945988b2e 100644 --- a/argmin/src/solver/newton/newton_method.rs +++ b/argmin/src/solver/newton/newton_method.rs @@ -85,7 +85,7 @@ where } } -impl Solver> for Newton +impl Solver> for Newton where O: Gradient + Hessian, P: Clone + ArgminScaledSub, @@ -97,8 +97,8 @@ where fn next_iter( &mut self, problem: &mut Problem, - mut state: IterState, - ) -> Result<(IterState, Option), Error> { + mut state: IterState, + ) -> Result<(IterState, Option), Error> { let param = state.take_param().ok_or_else(argmin_error_closure!( NotInitialized, concat!( diff --git a/argmin/src/solver/quasinewton/bfgs.rs b/argmin/src/solver/quasinewton/bfgs.rs index cdf668dff..c1f549c3f 100644 --- a/argmin/src/solver/quasinewton/bfgs.rs +++ b/argmin/src/solver/quasinewton/bfgs.rs @@ -398,7 +398,7 @@ mod tests { let mut bfgs: BFGS<_, f64> = BFGS::new(linesearch); // Forgot to initialize the parameter vector - let state: IterState, Vec, (), Vec>, f64> = IterState::new(); + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new(); let problem = TestProblem::new(); let res = bfgs.init(&mut Problem::new(problem), state); assert_error!( @@ -411,7 +411,7 @@ mod tests { ); // Forgot initial inverse Hessian guess - let state: IterState, Vec, (), Vec>, f64> = + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new().param(param.clone()); let problem = TestProblem::new(); let res = bfgs.init(&mut Problem::new(problem), state); @@ -426,7 +426,7 @@ mod tests { ); // All good. - let state: IterState, Vec, (), Vec>, f64> = IterState::new() + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new() .param(param.clone()) .inv_hessian(inv_hessian.clone()); let problem = TestProblem::new(); @@ -468,7 +468,7 @@ mod tests { let mut bfgs: BFGS<_, f64> = BFGS::new(linesearch); - let state: IterState, Vec, (), Vec>, f64> = IterState::new() + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new() .param(param) .inv_hessian(inv_hessian) .cost(1234.0); @@ -491,7 +491,7 @@ mod tests { let mut bfgs: BFGS<_, f64> = BFGS::new(linesearch); - let state: IterState, Vec, (), Vec>, f64> = IterState::new() + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new() .param(param) .inv_hessian(inv_hessian) .gradient(gradient.clone()); diff --git a/argmin/src/solver/quasinewton/dfp.rs b/argmin/src/solver/quasinewton/dfp.rs index f7352be31..f4d6ffa57 100644 --- a/argmin/src/solver/quasinewton/dfp.rs +++ b/argmin/src/solver/quasinewton/dfp.rs @@ -316,7 +316,7 @@ mod tests { let mut dfp: DFP<_, f64> = DFP::new(linesearch); // Forgot to initialize the parameter vector - let state: IterState, Vec, (), Vec>, f64> = IterState::new(); + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new(); let problem = TestProblem::new(); let res = dfp.init(&mut Problem::new(problem), state); assert_error!( @@ -329,7 +329,7 @@ mod tests { ); // Forgot initial inverse Hessian guess - let state: IterState, Vec, (), Vec>, f64> = + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new().param(param.clone()); let problem = TestProblem::new(); let res = dfp.init(&mut Problem::new(problem), state); @@ -344,7 +344,7 @@ mod tests { ); // All good. - let state: IterState, Vec, (), Vec>, f64> = IterState::new() + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new() .param(param.clone()) .inv_hessian(inv_hessian.clone()); let problem = TestProblem::new(); @@ -386,7 +386,7 @@ mod tests { let mut dfp: DFP<_, f64> = DFP::new(linesearch); - let state: IterState, Vec, (), Vec>, f64> = IterState::new() + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new() .param(param) .inv_hessian(inv_hessian) .cost(1234.0); @@ -409,7 +409,7 @@ mod tests { let mut dfp: DFP<_, f64> = DFP::new(linesearch); - let state: IterState, Vec, (), Vec>, f64> = IterState::new() + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new() .param(param) .inv_hessian(inv_hessian) .gradient(gradient.clone()); diff --git a/argmin/src/solver/quasinewton/lbfgs.rs b/argmin/src/solver/quasinewton/lbfgs.rs index 242dc20ea..7596db603 100644 --- a/argmin/src/solver/quasinewton/lbfgs.rs +++ b/argmin/src/solver/quasinewton/lbfgs.rs @@ -611,7 +611,7 @@ mod tests { let mut lbfgs: LBFGS<_, Vec, Vec, f64> = LBFGS::new(linesearch, 3); // Forgot to initialize the parameter vector - let state: IterState, Vec, (), (), f64> = IterState::new(); + let state: IterState, Vec, (), (), (), f64> = IterState::new(); let problem = TestProblem::new(); let res = lbfgs.init(&mut Problem::new(problem), state); assert_error!( @@ -624,7 +624,7 @@ mod tests { ); // All good. - let state: IterState, Vec, (), (), f64> = + let state: IterState, Vec, (), (), (), f64> = IterState::new().param(param.clone()); let problem = TestProblem::new(); let (mut state_out, kv) = lbfgs.init(&mut Problem::new(problem), state).unwrap(); @@ -654,7 +654,7 @@ mod tests { let mut lbfgs: LBFGS<_, Vec, Vec, f64> = LBFGS::new(linesearch, 3); - let state: IterState, Vec, (), (), f64> = + let state: IterState, Vec, (), (), (), f64> = IterState::new().param(param).cost(1234.0); let problem = TestProblem::new(); @@ -674,7 +674,7 @@ mod tests { let mut lbfgs: LBFGS<_, Vec, Vec, f64> = LBFGS::new(linesearch, 3); - let state: IterState, Vec, (), (), f64> = + let state: IterState, Vec, (), (), (), f64> = IterState::new().param(param).gradient(gradient.clone()); let problem = TestProblem::new(); diff --git a/argmin/src/solver/quasinewton/sr1.rs b/argmin/src/solver/quasinewton/sr1.rs index 65b5a1504..0049d53e3 100644 --- a/argmin/src/solver/quasinewton/sr1.rs +++ b/argmin/src/solver/quasinewton/sr1.rs @@ -424,7 +424,7 @@ mod tests { let mut sr1: SR1<_, f64> = SR1::new(linesearch); // Forgot to initialize the parameter vector - let state: IterState, Vec, (), Vec>, f64> = IterState::new(); + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new(); let problem = TestProblem::new(); let res = sr1.init(&mut Problem::new(problem), state); assert_error!( @@ -437,7 +437,7 @@ mod tests { ); // Forgot initial inverse Hessian guess - let state: IterState, Vec, (), Vec>, f64> = + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new().param(param.clone()); let problem = TestProblem::new(); let res = sr1.init(&mut Problem::new(problem), state); @@ -452,7 +452,7 @@ mod tests { ); // All good. - let state: IterState, Vec, (), Vec>, f64> = IterState::new() + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new() .param(param.clone()) .inv_hessian(inv_hessian.clone()); let problem = TestProblem::new(); @@ -494,7 +494,7 @@ mod tests { let mut sr1: SR1<_, f64> = SR1::new(linesearch); - let state: IterState, Vec, (), Vec>, f64> = IterState::new() + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new() .param(param) .inv_hessian(inv_hessian) .cost(1234.0); @@ -517,7 +517,7 @@ mod tests { let mut sr1: SR1<_, f64> = SR1::new(linesearch); - let state: IterState, Vec, (), Vec>, f64> = IterState::new() + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new() .param(param) .inv_hessian(inv_hessian) .gradient(gradient.clone()); diff --git a/argmin/src/solver/quasinewton/sr1_trustregion.rs b/argmin/src/solver/quasinewton/sr1_trustregion.rs index 63432451a..09b92cc70 100644 --- a/argmin/src/solver/quasinewton/sr1_trustregion.rs +++ b/argmin/src/solver/quasinewton/sr1_trustregion.rs @@ -452,7 +452,7 @@ mod tests { let mut sr1: SR1TrustRegion<_, f64> = SR1TrustRegion::new(subproblem); // Forgot to initialize the parameter vector - let state: IterState, Vec, (), Vec>, f64> = IterState::new(); + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new(); let problem = TestProblem::new(); let res = sr1.init(&mut Problem::new(problem), state); assert_error!( @@ -465,7 +465,7 @@ mod tests { ); // All good. - let state: IterState, Vec, (), Vec>, f64> = + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new().param(param.clone()); let problem = TestProblem::new(); let (mut state_out, kv) = sr1.init(&mut Problem::new(problem), state).unwrap(); @@ -495,7 +495,7 @@ mod tests { let mut sr1: SR1TrustRegion<_, f64> = SR1TrustRegion::new(subproblem); - let state: IterState, Vec, (), Vec>, f64> = + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new().param(param).cost(1234.0); let problem = TestProblem::new(); @@ -515,7 +515,7 @@ mod tests { let mut sr1: SR1TrustRegion<_, f64> = SR1TrustRegion::new(subproblem); - let state: IterState, Vec, (), Vec>, f64> = + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new().param(param).gradient(gradient.clone()); let problem = TestProblem::new(); @@ -540,7 +540,7 @@ mod tests { let mut sr1: SR1TrustRegion<_, f64> = SR1TrustRegion::new(subproblem); - let state: IterState, Vec, (), Vec>, f64> = IterState::new() + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new() .param(param) .gradient(gradient) .hessian(hessian.clone()); diff --git a/argmin/src/solver/simulatedannealing/mod.rs b/argmin/src/solver/simulatedannealing/mod.rs index 62d05f7b2..d677739e5 100644 --- a/argmin/src/solver/simulatedannealing/mod.rs +++ b/argmin/src/solver/simulatedannealing/mod.rs @@ -843,7 +843,7 @@ mod tests { .with_reannealing_best(reanneal_best); // Forgot to initialize the parameter vector - let state: IterState, (), (), (), f64> = IterState::new(); + let state: IterState, (), (), (), (), f64> = IterState::new(); let problem = TestProblem::new(); let res = sa.init(&mut Problem::new(problem), state); assert_error!( @@ -856,7 +856,7 @@ mod tests { ); // All good. - let state: IterState, (), (), (), f64> = IterState::new().param(param.clone()); + let state: IterState, (), (), (), (), f64> = IterState::new().param(param.clone()); let problem = TestProblem::new(); let (mut state_out, kv) = sa.init(&mut Problem::new(problem), state).unwrap(); diff --git a/argmin/src/solver/trustregion/cauchypoint.rs b/argmin/src/solver/trustregion/cauchypoint.rs index a852afa56..b46a9d577 100644 --- a/argmin/src/solver/trustregion/cauchypoint.rs +++ b/argmin/src/solver/trustregion/cauchypoint.rs @@ -153,7 +153,7 @@ mod tests { cp.set_radius(1.0); // Forgot to initialize the parameter vector - let state: IterState, Vec, (), Vec>, f64> = IterState::new(); + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new(); let problem = TestProblem::new(); let res = cp.next_iter(&mut Problem::new(problem), state); assert_error!( @@ -166,7 +166,7 @@ mod tests { ); // All good. - let state: IterState, Vec, (), Vec>, f64> = + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new().param(param); let problem = TestProblem::new(); let (mut state_out, kv) = cp.next_iter(&mut Problem::new(problem), state).unwrap(); diff --git a/argmin/src/solver/trustregion/dogleg.rs b/argmin/src/solver/trustregion/dogleg.rs index 8c11f3d25..5305eac83 100644 --- a/argmin/src/solver/trustregion/dogleg.rs +++ b/argmin/src/solver/trustregion/dogleg.rs @@ -212,7 +212,7 @@ mod tests { dl.set_radius(1.0); // Forgot to initialize the parameter vector - let state: IterState, Array1, (), Array2, f64> = IterState::new(); + let state: IterState, Array1, (), Array2, (), f64> = IterState::new(); let problem = TestProblem {}; let res = dl.next_iter(&mut Problem::new(problem), state); assert_error!( @@ -225,7 +225,7 @@ mod tests { ); // All good. - let state: IterState, Array1, (), Array2, f64> = + let state: IterState, Array1, (), Array2, (), f64> = IterState::new().param(param); let problem = TestProblem {}; let (mut state_out, kv) = dl.next_iter(&mut Problem::new(problem), state).unwrap(); diff --git a/argmin/src/solver/trustregion/steihaug.rs b/argmin/src/solver/trustregion/steihaug.rs index 59f1488aa..512d9eaf4 100644 --- a/argmin/src/solver/trustregion/steihaug.rs +++ b/argmin/src/solver/trustregion/steihaug.rs @@ -404,7 +404,7 @@ mod tests { sh.set_radius(1.0); // Forgot to initialize gradient - let state: IterState, Vec, (), Vec>, f64> = IterState::new(); + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new(); let problem = TestProblem::new(); let res = sh.init(&mut Problem::new(problem), state); assert_error!( @@ -417,7 +417,7 @@ mod tests { ); // Forgot to initialize Hessian - let state: IterState, Vec, (), Vec>, f64> = + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new().gradient(grad.clone()); let problem = TestProblem::new(); let res = sh.init(&mut Problem::new(problem), state); @@ -431,7 +431,7 @@ mod tests { ); // All good. - let state: IterState, Vec, (), Vec>, f64> = + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new().gradient(grad.clone()).hessian(hessian); let problem = TestProblem::new(); let (mut state_out, kv) = sh.init(&mut Problem::new(problem), state).unwrap(); diff --git a/argmin/src/solver/trustregion/trustregion_method.rs b/argmin/src/solver/trustregion/trustregion_method.rs index 6f73cdac1..03e776d47 100644 --- a/argmin/src/solver/trustregion/trustregion_method.rs +++ b/argmin/src/solver/trustregion/trustregion_method.rs @@ -402,7 +402,7 @@ mod tests { let mut tr: TrustRegion<_, f64> = TrustRegion::new(cp); // Forgot to initialize parameter vector - let state: IterState, Vec, (), Vec>, f64> = IterState::new(); + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new(); let problem = TestProblem::new(); let res = tr.init(&mut Problem::new(problem), state); assert_error!( @@ -415,7 +415,7 @@ mod tests { ); // All good. - let state: IterState, Vec, (), Vec>, f64> = + let state: IterState, Vec, (), Vec>, (), f64> = IterState::new().param(param.clone()); let problem = TestProblem::new(); let (mut state_out, kv) = tr.init(&mut Problem::new(problem), state).unwrap(); diff --git a/media/book/src/implementing_solver.md b/media/book/src/implementing_solver.md index 7291eeea7..8255c48b0 100644 --- a/media/book/src/implementing_solver.md +++ b/media/book/src/implementing_solver.md @@ -85,7 +85,7 @@ impl Landweber { } } -impl Solver> for Landweber +impl Solver> for Landweber where // The Landweber solver requires `O` to implement `Gradient`. // `P` and `G` indicate the types of the parameter vector and gradient, @@ -111,8 +111,8 @@ where // vector, gradient, Hessian and cost function value of the current, // previous and best iteration as well as current iteration number, and // many more. - mut state: IterState, - ) -> Result<(IterState, Option), Error> { + mut state: IterState, + ) -> Result<(IterState, Option), Error> { // First we obtain the current parameter vector from the `state` struct (`x_k`). // Landweber requires an initial parameter vector. Return an error if this was // not provided by the user.