Skip to content

Commit

Permalink
Added residuals handling to GaussNewton, fixed residuals related issu…
Browse files Browse the repository at this point in the history
…es throughout the codebase
  • Loading branch information
stefan-k committed Jan 17, 2024
1 parent c3180ca commit 7af653f
Show file tree
Hide file tree
Showing 27 changed files with 246 additions and 158 deletions.
5 changes: 3 additions & 2 deletions argmin/src/core/checkpointing/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,12 @@ mod tests {
#[allow(clippy::type_complexity)]
fn test_save() {
let solver = TestSolver::new();
let state: IterState<Vec<f64>, (), (), (), f64> = IterState::new().param(vec![1.0f64, 0.0]);
let state: IterState<Vec<f64>, (), (), (), (), f64> =
IterState::new().param(vec![1.0f64, 0.0]);
let check = FileCheckpoint::new("checkpoints", "solver", CheckpointingFrequency::Always);
check.save_cond(&solver, &state, 20).unwrap();

let _loaded: Option<(TestSolver, IterState<Vec<f64>, (), (), (), f64>)> =
let _loaded: Option<(TestSolver, IterState<Vec<f64>, (), (), (), (), f64>)> =
check.load().unwrap();
}
}
24 changes: 13 additions & 11 deletions argmin/src/core/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,9 @@ mod tests {
let problem = TestProblem::new();
let solver = TestSolver::new();

let mut executor = Executor::new(problem, solver)
.configure(|config: IterState<Vec<f64>, (), (), (), f64>| config.param(vec![0.0, 0.0]));
let mut executor = Executor::new(problem, solver).configure(
|config: IterState<Vec<f64>, (), (), (), (), f64>| config.param(vec![0.0, 0.0]),
);

// 1) Parameter vector changes, but not cost (continues to be `Inf`)
let new_param = vec![1.0, 1.0];
Expand Down Expand Up @@ -492,8 +493,9 @@ mod tests {

// 4) `-Inf` is better than `Inf`
let solver = TestSolver {};
let mut executor = Executor::new(problem, solver)
.configure(|config: IterState<Vec<f64>, (), (), (), f64>| config.param(vec![0.0, 0.0]));
let mut executor = Executor::new(problem, solver).configure(
|config: IterState<Vec<f64>, (), (), (), (), f64>| config.param(vec![0.0, 0.0]),
);

let new_param = vec![1.0, 1.0];
let new_cost = std::f64::NEG_INFINITY;
Expand Down Expand Up @@ -584,7 +586,7 @@ mod tests {
}

// Implement Solver for OptimizationAlgorithm
impl<O, P, F> Solver<O, IterState<P, (), (), (), F>> for OptimizationAlgorithm
impl<O, P, F> Solver<O, IterState<P, (), (), (), (), F>> for OptimizationAlgorithm
where
O: CostFunction<Param = P, Output = F>,
P: Clone,
Expand All @@ -596,8 +598,8 @@ mod tests {
fn init(
&mut self,
_problem: &mut Problem<O>,
state: IterState<P, (), (), (), F>,
) -> Result<(IterState<P, (), (), (), F>, Option<KV>), Error> {
state: IterState<P, (), (), (), (), F>,
) -> Result<(IterState<P, (), (), (), (), F>, Option<KV>), Error> {
self.internal_state = 1;
Ok((state, None))
}
Expand All @@ -606,21 +608,21 @@ mod tests {
fn next_iter(
&mut self,
_problem: &mut Problem<O>,
state: IterState<P, (), (), (), F>,
) -> Result<(IterState<P, (), (), (), F>, Option<KV>), Error> {
state: IterState<P, (), (), (), (), F>,
) -> Result<(IterState<P, (), (), (), (), F>, Option<KV>), Error> {
self.internal_state += 1;
Ok((state, None))
}

// Avoid terminating early because param does not change
fn terminate(&mut self, _state: &IterState<P, (), (), (), F>) -> TerminationStatus {
fn terminate(&mut self, _state: &IterState<P, (), (), (), (), F>) -> TerminationStatus {
TerminationStatus::NotTerminated
}

// Avoid terminating early because param does not change
fn terminate_internal(
&mut self,
state: &IterState<P, (), (), (), F>,
state: &IterState<P, (), (), (), (), F>,
) -> TerminationStatus {
if state.get_iter() >= state.get_max_iters() {
TerminationStatus::Terminated(TerminationReason::MaxItersReached)
Expand Down
8 changes: 4 additions & 4 deletions argmin/src/core/observers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ impl<I> Observers<I> {
/// use argmin::core::observers::Observers;
/// use argmin::core::IterState;
///
/// let observers: Observers<IterState<Vec<f64>, (), (), (), f64>> = Observers::new();
/// let observers: Observers<IterState<Vec<f64>, (), (), (), (), f64>> = Observers::new();
/// # assert!(observers.is_empty());
/// ```
pub fn new() -> Self {
Expand All @@ -214,7 +214,7 @@ impl<I> Observers<I> {
/// use argmin_observer_slog::SlogLogger;
/// use argmin::core::IterState;
///
/// let mut observers: Observers<IterState<Vec<f64>, (), (), (), f64>> = Observers::new();
/// let mut observers: Observers<IterState<Vec<f64>, (), (), (), (), f64>> = Observers::new();
///
/// let logger = SlogLogger::term();
/// observers.push(logger, ObserverMode::Always);
Expand All @@ -237,7 +237,7 @@ impl<I> Observers<I> {
/// use argmin::core::observers::Observers;
/// use argmin::core::IterState;
///
/// let observers: Observers<IterState<Vec<f64>, (), (), (), f64>> = Observers::new();
/// let observers: Observers<IterState<Vec<f64>, (), (), (), (), f64>> = Observers::new();
/// assert!(observers.is_empty());
/// ```
pub fn is_empty(&self) -> bool {
Expand Down Expand Up @@ -375,7 +375,7 @@ mod tests {

let storages = [test_stor_1, test_stor_2, test_stor_3, test_stor_4];

type TState = IterState<Vec<f64>, (), (), (), f64>;
type TState = IterState<Vec<f64>, (), (), (), (), f64>;

let mut obs: Observers<TState> = Observers::new();
obs.push(test_obs_1, ObserverMode::Never)
Expand Down
12 changes: 6 additions & 6 deletions argmin/src/core/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl<O, S, I> OptimizationResult<O, S, I> {
/// # struct SomeSolver {}
/// #
/// let rosenbrock = Rosenbrock::new();
/// let state: IterState<Vec<f64>, (), (), (), f64> = IterState::new();
/// let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
/// let solver = SomeSolver {};
///
/// let result = OptimizationResult::new(Problem::new(rosenbrock), solver, state);
Expand All @@ -65,7 +65,7 @@ impl<O, S, I> OptimizationResult<O, S, I> {
/// # struct Rosenbrock {}
/// # let solver = ();
/// #
/// # let state: IterState<Vec<f64>, (), (), (), f64> = IterState::new();
/// # let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
/// #
/// # let result = OptimizationResult::new(Problem::new(Rosenbrock {}), solver, state);
/// #
Expand All @@ -85,7 +85,7 @@ impl<O, S, I> OptimizationResult<O, S, I> {
/// # struct Rosenbrock {}
/// # let solver = ();
/// #
/// # let state: IterState<Vec<f64>, (), (), (), f64> = IterState::new();
/// # let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
/// #
/// # let result = OptimizationResult::new(Problem::new(Rosenbrock {}), solver, state);
/// #
Expand All @@ -105,11 +105,11 @@ impl<O, S, I> OptimizationResult<O, S, I> {
/// # struct Rosenbrock {}
/// # let solver = ();
/// #
/// # let state: IterState<Vec<f64>, (), (), (), f64> = IterState::new();
/// # let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
/// #
/// # let result = OptimizationResult::new(Problem::new(Rosenbrock {}), solver, state);
/// #
/// let state: &IterState<Vec<f64>, (), (), (), f64> = result.state();
/// let state: &IterState<Vec<f64>, (), (), (), (), f64> = result.state();
/// ```
pub fn state(&self) -> &I {
&self.state
Expand Down Expand Up @@ -199,7 +199,7 @@ mod tests {

send_sync_test!(
optimizationresult,
OptimizationResult<TestProblem, TestSolver, IterState<(), (), (), (), f64>>
OptimizationResult<TestProblem, TestSolver, IterState<(), (), (), (), (), f64>>
);

// TODO: More tests, in particular the checking that the output is as intended.
Expand Down
12 changes: 6 additions & 6 deletions argmin/src/core/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K
/// #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
/// struct OptimizationAlgorithm {}
///
/// impl<O, P, G, J, H, F> Solver<O, IterState<P, G, J, H, F>> for OptimizationAlgorithm
/// impl<O, P, G, J, H, R, F> Solver<O, IterState<P, G, J, H, R, F>> for OptimizationAlgorithm
/// where
/// O: CostFunction<Param = P, Output = F>,
/// P: Clone,
Expand All @@ -41,8 +41,8 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K
/// fn init(
/// &mut self,
/// problem: &mut Problem<O>,
/// state: IterState<P, G, J, H, F>,
/// ) -> Result<(IterState<P, G, J, H, F>, Option<KV>), Error> {
/// state: IterState<P, G, J, H, R, F>,
/// ) -> Result<(IterState<P, G, J, H, R, F>, Option<KV>), Error> {
/// // Initialize algorithm, update `state`.
/// // Implementing this method is optional.
/// Ok((state, None))
Expand All @@ -51,14 +51,14 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K
/// fn next_iter(
/// &mut self,
/// problem: &mut Problem<O>,
/// state: IterState<P, G, J, H, F>,
/// ) -> Result<(IterState<P, G, J, H, F>, Option<KV>), Error> {
/// state: IterState<P, G, J, H, R, F>,
/// ) -> Result<(IterState<P, G, J, H, R, F>, Option<KV>), Error> {
/// // Compute single iteration of algorithm, update `state`.
/// // Implementing this method is required.
/// Ok((state, None))
/// }
///
/// fn terminate(&mut self, state: &IterState<P, G, J, H, F>) -> TerminationStatus {
/// fn terminate(&mut self, state: &IterState<P, G, J, H, R, F>) -> TerminationStatus {
/// // Check if stopping criteria are met.
/// // Implementing this method is optional.
/// TerminationStatus::NotTerminated
Expand Down
Loading

0 comments on commit 7af653f

Please sign in to comment.