From e18ac71772e50203bd0dc72d4bdf42375c89c88b Mon Sep 17 00:00:00 2001 From: Stefan Kroboth Date: Thu, 18 Jan 2024 10:21:03 +0100 Subject: [PATCH 1/2] Dedicated crate for `FileCheckpoint`, removed `SerializeAlias` and `DeserializeOwnedAlias` --- Cargo.toml | 1 + argmin/Cargo.toml | 4 +- argmin/src/core/checkpointing/mod.rs | 14 ++-- argmin/src/core/executor.rs | 74 +++++++++++++++---- argmin/src/core/float.rs | 6 +- argmin/src/core/mod.rs | 3 - argmin/src/core/serialization.rs | 53 ------------- argmin/src/core/solver.rs | 2 +- argmin/src/solver/conjugategradient/beta.rs | 4 +- argmin/src/solver/conjugategradient/cg.rs | 5 +- .../solver/conjugategradient/nonlinear_cg.rs | 13 +--- .../gaussnewton/gaussnewton_linesearch.rs | 17 ++--- .../solver/gradientdescent/steepestdescent.rs | 8 +- argmin/src/solver/linesearch/backtracking.rs | 10 +-- argmin/src/solver/linesearch/hagerzhang.rs | 8 +- argmin/src/solver/linesearch/morethuente.rs | 8 +- argmin/src/solver/neldermead/mod.rs | 6 +- argmin/src/solver/newton/newton_cg.rs | 11 +-- argmin/src/solver/particleswarm/mod.rs | 6 +- argmin/src/solver/quasinewton/bfgs.rs | 24 ++---- argmin/src/solver/quasinewton/dfp.rs | 28 ++----- argmin/src/solver/quasinewton/lbfgs.rs | 11 +-- argmin/src/solver/quasinewton/sr1.rs | 21 +----- .../src/solver/quasinewton/sr1_trustregion.rs | 21 +----- argmin/src/solver/simulatedannealing/mod.rs | 6 +- argmin/src/solver/trustregion/steihaug.rs | 5 +- .../solver/trustregion/trustregion_method.rs | 18 ++--- checkpointing/file/.gitignore | 1 + checkpointing/file/Cargo.toml | 19 +++++ .../file.rs => checkpointing/file/src/lib.rs | 25 ++++--- examples/checkpoint/.gitignore | 1 + examples/checkpoint/Cargo.toml | 1 + examples/checkpoint/src/main.rs | 6 +- 33 files changed, 180 insertions(+), 260 deletions(-) delete mode 100644 argmin/src/core/serialization.rs create mode 100644 checkpointing/file/.gitignore create mode 100644 checkpointing/file/Cargo.toml rename argmin/src/core/checkpointing/file.rs => checkpointing/file/src/lib.rs (89%) create mode 100644 examples/checkpoint/.gitignore diff --git a/Cargo.toml b/Cargo.toml index bd4ca1886..6d8239dcb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ resolver = "2" members = [ "argmin", "argmin-math", + "checkpointing/*", "observers/*", "tools/spectator", diff --git a/argmin/Cargo.toml b/argmin/Cargo.toml index aff79c911..9689846b9 100644 --- a/argmin/Cargo.toml +++ b/argmin/Cargo.toml @@ -24,7 +24,6 @@ rand_xoshiro = "0.6.0" thiserror = "1.0" argmin-math = { path = "../argmin-math", version = "0.3", default-features = false, features = ["primitives"] } # optional -bincode = { version = "1.3.3", optional = true } ctrlc = { version = "3.2.4", optional = true } getrandom = { version = "0.2", optional = true } rayon = { version = "1.6.0", optional = true } @@ -39,11 +38,12 @@ ndarray-linalg = { version = "0.16", features = ["intel-mkl-static"] } argmin-math = { path = "../argmin-math", version = "0.3", features = ["vec"] } argmin-observer-slog = { path = "../observers/slog" } argmin-observer-paramwriter = { path = "../observers/paramwriter" } +argmin-checkpointing-file = { path = "../checkpointing/file" } [features] default = [] wasm-bindgen = ["instant/wasm-bindgen", "getrandom/js"] -serde1 = ["serde", "bincode", "rand_xoshiro/serde1"] +serde1 = ["serde", "rand_xoshiro/serde1"] _ndarrayl = ["argmin-math/ndarray_latest"] # When adding new features, please consider adding them to either `full` (for users) # or `_full_dev` (only for local development, testing and computing test coverage). diff --git a/argmin/src/core/checkpointing/mod.rs b/argmin/src/core/checkpointing/mod.rs index 51d369e8e..3cd134910 100644 --- a/argmin/src/core/checkpointing/mod.rs +++ b/argmin/src/core/checkpointing/mod.rs @@ -12,7 +12,8 @@ //! with a user-chosen frequency. Optimizations can then be resumed from a given checkpoint after a //! crash. //! -//! For saving checkpoints to disk, `FileCheckpoint` is provided. +//! For saving checkpoints to disk, `FileCheckpoint` is provided in the `argmin-checkpointing-file` +//! crate. //! Via the `Checkpoint` trait other checkpointing approaches can be implemented. //! //! The `CheckpointingFrequency` defines how often checkpoints are saved and can be chosen to be @@ -29,7 +30,9 @@ //! # extern crate argmin_testfunctions; //! # use argmin::core::{CostFunction, Error, Executor, Gradient, observers::ObserverMode}; //! # #[cfg(feature = "serde1")] -//! # use argmin::core::checkpointing::{FileCheckpoint, CheckpointingFrequency}; +//! use argmin::core::checkpointing::CheckpointingFrequency; +//! # #[cfg(feature = "serde1")] +//! use argmin_checkpointing_file::FileCheckpoint; //! # use argmin_observer_slog::SlogLogger; //! # use argmin::solver::landweber::Landweber; //! # use argmin_testfunctions::{rosenbrock_2d, rosenbrock_2d_derivative}; @@ -70,6 +73,7 @@ //! # //! # let iters = 35; //! # let solver = Landweber::new(0.001); +//! //! // [...] //! //! # #[cfg(feature = "serde1")] @@ -98,12 +102,6 @@ //! # } //! ``` -#[cfg(feature = "serde1")] -mod file; - -#[cfg(feature = "serde1")] -pub use crate::core::checkpointing::file::FileCheckpoint; - use crate::core::Error; use std::default::Default; use std::fmt::Display; diff --git a/argmin/src/core/executor.rs b/argmin/src/core/executor.rs index 843b80713..a796f2c4d 100644 --- a/argmin/src/core/executor.rs +++ b/argmin/src/core/executor.rs @@ -8,8 +8,7 @@ use crate::core::checkpointing::Checkpoint; use crate::core::observers::{Observe, ObserverMode, Observers}; use crate::core::{ - DeserializeOwnedAlias, Error, OptimizationResult, Problem, SerializeAlias, Solver, State, - TerminationReason, TerminationStatus, KV, + Error, OptimizationResult, Problem, Solver, State, TerminationReason, TerminationStatus, KV, }; use instant; use std::sync::atomic::{AtomicBool, Ordering}; @@ -36,7 +35,7 @@ pub struct Executor { impl Executor where S: Solver, - I: State + SerializeAlias + DeserializeOwnedAlias, + I: State, { /// Constructs an `Executor` from a user defined problem and a solver. /// @@ -313,7 +312,8 @@ where /// ``` /// # use argmin::core::{Error, Executor}; /// # #[cfg(feature = "serde1")] - /// # use argmin::core::checkpointing::{FileCheckpoint, CheckpointingFrequency}; + /// # use argmin::core::checkpointing::CheckpointingFrequency; + /// # use argmin_checkpointing_file::FileCheckpoint; /// # use argmin::core::test_utils::{TestSolver, TestProblem}; /// # /// # fn main() -> Result<(), Error> { @@ -570,17 +570,63 @@ mod tests { /// The solver's `init` should not be called when started from a checkpoint. /// See https://github.com/argmin-rs/argmin/issues/199. + // #[cfg(feature = "serde1")] #[test] #[cfg(feature = "serde1")] fn test_checkpointing_solver_initialization() { - use crate::core::checkpointing::{CheckpointingFrequency, FileCheckpoint}; - use crate::core::test_utils::TestProblem; - use crate::core::{ArgminFloat, CostFunction}; + use std::cell::RefCell; + + use crate::core::{ + checkpointing::CheckpointingFrequency, test_utils::TestProblem, ArgminFloat, + CostFunction, + }; use serde::{Deserialize, Serialize}; - // Fake optimization algorithm which holds internal state which changes over time #[derive(Clone)] - #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] + pub struct FakeCheckpoint { + pub frequency: CheckpointingFrequency, + pub solver: RefCell>, + pub state: RefCell, (), (), (), (), f64>>>, + } + + impl Checkpoint, (), (), (), (), f64>> + for FakeCheckpoint + { + fn save( + &self, + solver: &OptimizationAlgorithm, + state: &IterState, (), (), (), (), f64>, + ) -> Result<(), Error> { + *self.solver.borrow_mut() = Some(solver.clone()); + *self.state.borrow_mut() = Some(state.clone()); + Ok(()) + } + + fn load( + &self, + ) -> Result< + Option<( + OptimizationAlgorithm, + IterState, (), (), (), (), f64>, + )>, + Error, + > { + if self.solver.borrow().is_none() { + return Ok(None); + } + Ok(Some(( + self.solver.borrow().clone().unwrap(), + self.state.borrow().clone().unwrap(), + ))) + } + + fn frequency(&self) -> CheckpointingFrequency { + self.frequency + } + } + + // Fake optimization algorithm which holds internal state which changes over time + #[derive(Clone, Serialize, Deserialize)] struct OptimizationAlgorithm { pub internal_state: u64, } @@ -638,12 +684,12 @@ mod tests { // solver instance let solver = OptimizationAlgorithm { internal_state: 0 }; - // Delete old checkpointing file - let _ = std::fs::remove_file(".checkpoints/init_test.arg"); - // Create a checkpoint - let checkpoint = - FileCheckpoint::new(".checkpoints", "init_test", CheckpointingFrequency::Always); + let checkpoint = FakeCheckpoint { + frequency: CheckpointingFrequency::Always, + solver: RefCell::new(None), + state: RefCell::new(None), + }; // Create and run executor let executor = Executor::new(problem, solver) diff --git a/argmin/src/core/float.rs b/argmin/src/core/float.rs index 3e98ad555..a2aa24750 100644 --- a/argmin/src/core/float.rs +++ b/argmin/src/core/float.rs @@ -5,7 +5,7 @@ // http://opensource.org/licenses/MIT>, at your option. This file may not be // copied, modified, or distributed except according to those terms. -use crate::core::{kv::KvValue, DeserializeOwnedAlias, SendAlias, SerializeAlias}; +use crate::core::{kv::KvValue, SendAlias}; use num_traits::{Float, FloatConst, FromPrimitive, ToPrimitive}; use std::fmt::{Debug, Display}; @@ -21,8 +21,6 @@ pub trait ArgminFloat: + ToPrimitive + Debug + Display - + SerializeAlias - + DeserializeOwnedAlias + SendAlias + Into { @@ -37,8 +35,6 @@ impl ArgminFloat for I where + ToPrimitive + Debug + Display - + SerializeAlias - + DeserializeOwnedAlias + SendAlias + Into { diff --git a/argmin/src/core/mod.rs b/argmin/src/core/mod.rs index 35918762e..328e186b0 100644 --- a/argmin/src/core/mod.rs +++ b/argmin/src/core/mod.rs @@ -29,8 +29,6 @@ mod parallelization; mod problem; /// Definition of the return type of the solvers mod result; -/// Trait alias for `serde`s `Serialize` and `DeserializeOwned` -mod serialization; /// `Solver` trait mod solver; /// iteration state @@ -51,7 +49,6 @@ pub use kv::{KvValue, KV}; pub use parallelization::{SendAlias, SyncAlias}; pub use problem::{CostFunction, Gradient, Hessian, Jacobian, LinearProgram, Operator, Problem}; pub use result::OptimizationResult; -pub use serialization::{DeserializeOwnedAlias, SerializeAlias}; pub use solver::Solver; pub use state::{IterState, LinearProgramState, PopulationState, State}; pub use termination::{TerminationReason, TerminationStatus}; diff --git a/argmin/src/core/serialization.rs b/argmin/src/core/serialization.rs deleted file mode 100644 index d55922029..000000000 --- a/argmin/src/core/serialization.rs +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2018-2022 argmin developers -// -// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be -// copied, modified, or distributed except according to those terms. - -#[cfg(feature = "serde1")] -use serde::{de::DeserializeOwned, Serialize}; - -/// Trait alias for `serde`s `Serialize`. -/// -/// If the `serde1` feature is set, it acts as an alias for `Serialize` and is implemented for all -/// types which implement `Serialize`. If `serde1` is not set, it will be an "empty" trait -/// implemented for all types. -#[cfg(feature = "serde1")] -pub trait SerializeAlias: Serialize {} - -/// Trait alias for `serde`s `Serialize`. -/// -/// If the `serde1` feature is set, it acts as an alias for `Serialize` and is implemented for all -/// types which implement `Serialize`. If `serde1` is not set, it will be an "empty" trait -/// implemented for all types. -#[cfg(not(feature = "serde1"))] -pub trait SerializeAlias {} - -#[cfg(feature = "serde1")] -impl SerializeAlias for T where T: Serialize {} - -#[cfg(not(feature = "serde1"))] -impl SerializeAlias for T {} - -/// Trait alias for `serde`s `DeserializeOwned`. -/// -/// If the `serde1` feature is set, it acts as an alias for `DeserializeOwned` and is implemented -/// for all types which implement `DeserializeOwned`. If `serde1` is not set, it will be an "empty" -/// trait implemented for all types. -#[cfg(feature = "serde1")] -pub trait DeserializeOwnedAlias: DeserializeOwned {} - -/// Trait alias for `serde`s `DeserializeOwned`. -/// -/// If the `serde1` feature is set, it acts as an alias for `DeserializeOwned` and is implemented -/// for all types which implement `DeserializeOwned`. If `serde1` is not set, it will be an "empty" -/// trait implemented for all types. -#[cfg(not(feature = "serde1"))] -pub trait DeserializeOwnedAlias {} - -#[cfg(feature = "serde1")] -impl DeserializeOwnedAlias for T where T: DeserializeOwned {} - -#[cfg(not(feature = "serde1"))] -impl DeserializeOwnedAlias for T {} diff --git a/argmin/src/core/solver.rs b/argmin/src/core/solver.rs index a09915830..dc85df231 100644 --- a/argmin/src/core/solver.rs +++ b/argmin/src/core/solver.rs @@ -15,7 +15,7 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K /// ([`terminate`](`Solver::terminate`) and [`terminate_internal`](`Solver::terminate_internal`)). /// Only `next_iter` is mandatory to implement, all others provide default implementations. /// -/// A `Solver` needs to be serializable. +/// A `Solver` should be (de)serializable in order to work with checkpointing. /// /// # Example /// diff --git a/argmin/src/solver/conjugategradient/beta.rs b/argmin/src/solver/conjugategradient/beta.rs index b3cd3ef99..0fe573564 100644 --- a/argmin/src/solver/conjugategradient/beta.rs +++ b/argmin/src/solver/conjugategradient/beta.rs @@ -17,7 +17,7 @@ //! \[0\] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization. //! Springer. ISBN 0-387-30303-0. -use crate::core::{ArgminFloat, SerializeAlias}; +use crate::core::ArgminFloat; use argmin_math::{ArgminDot, ArgminL2Norm, ArgminSub}; #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; @@ -44,7 +44,7 @@ use serde::{Deserialize, Serialize}; /// } /// } /// ``` -pub trait NLCGBetaUpdate: SerializeAlias { +pub trait NLCGBetaUpdate { /// Update beta. /// /// # Parameters diff --git a/argmin/src/solver/conjugategradient/cg.rs b/argmin/src/solver/conjugategradient/cg.rs index 2822c6666..59bbdf9db 100644 --- a/argmin/src/solver/conjugategradient/cg.rs +++ b/argmin/src/solver/conjugategradient/cg.rs @@ -5,9 +5,7 @@ // http://opensource.org/licenses/MIT>, at your option. This file may not be // copied, modified, or distributed except according to those terms. -use crate::core::{ - ArgminFloat, Error, IterState, Operator, Problem, SerializeAlias, Solver, State, KV, -}; +use crate::core::{ArgminFloat, Error, IterState, Operator, Problem, Solver, State, KV}; use argmin_math::{ArgminConj, ArgminDot, ArgminL2Norm, ArgminMul, ArgminScaledAdd, ArgminSub}; #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; @@ -93,7 +91,6 @@ impl Solver> for ConjugateGradient

, P: Clone - + SerializeAlias + ArgminDot + ArgminSub + ArgminScaledAdd diff --git a/argmin/src/solver/conjugategradient/nonlinear_cg.rs b/argmin/src/solver/conjugategradient/nonlinear_cg.rs index c0bc5f686..b1cd80e43 100644 --- a/argmin/src/solver/conjugategradient/nonlinear_cg.rs +++ b/argmin/src/solver/conjugategradient/nonlinear_cg.rs @@ -6,8 +6,8 @@ // copied, modified, or distributed except according to those terms. use crate::core::{ - ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, IterState, - LineSearch, NLCGBetaUpdate, OptimizationResult, Problem, SerializeAlias, Solver, State, KV, + ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, LineSearch, NLCGBetaUpdate, + OptimizationResult, Problem, Solver, State, KV, }; use argmin_math::{ArgminAdd, ArgminDot, ArgminL2Norm, ArgminMul}; #[cfg(feature = "serde1")] @@ -122,13 +122,8 @@ impl Solver> for NonlinearConjugateGradient where O: CostFunction + Gradient, - P: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminAdd + ArgminMul, - G: Clone - + SerializeAlias - + DeserializeOwnedAlias - + ArgminMul - + ArgminDot - + ArgminL2Norm, + P: Clone + ArgminAdd + ArgminMul, + G: Clone + ArgminMul + ArgminDot + ArgminL2Norm, L: Clone + LineSearch + Solver>, B: NLCGBetaUpdate, F: ArgminFloat, diff --git a/argmin/src/solver/gaussnewton/gaussnewton_linesearch.rs b/argmin/src/solver/gaussnewton/gaussnewton_linesearch.rs index a119a329f..4b49e1ed0 100644 --- a/argmin/src/solver/gaussnewton/gaussnewton_linesearch.rs +++ b/argmin/src/solver/gaussnewton/gaussnewton_linesearch.rs @@ -6,9 +6,8 @@ // copied, modified, or distributed except according to those terms. use crate::core::{ - ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, IterState, - Jacobian, LineSearch, Operator, OptimizationResult, Problem, SerializeAlias, Solver, - TerminationReason, TerminationStatus, KV, + ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, Jacobian, LineSearch, + Operator, OptimizationResult, Problem, Solver, TerminationReason, TerminationStatus, KV, }; use argmin_math::{ArgminDot, ArgminInv, ArgminL2Norm, ArgminMul, ArgminTranspose}; #[cfg(feature = "serde1")] @@ -84,12 +83,10 @@ impl GaussNewtonLS { impl Solver> for GaussNewtonLS where O: Operator + Jacobian, - P: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminMul, - G: Clone + SerializeAlias + DeserializeOwnedAlias, + P: Clone + ArgminMul, + G: Clone, U: ArgminL2Norm, J: Clone - + SerializeAlias - + DeserializeOwnedAlias + ArgminTranspose + ArgminInv + ArgminDot @@ -97,7 +94,7 @@ where + ArgminDot, L: Clone + LineSearch + Solver, IterState>, F: ArgminFloat, - R: Clone + SerializeAlias + DeserializeOwnedAlias, + R: Clone, { const NAME: &'static str = "Gauss-Newton method with line search"; @@ -196,7 +193,7 @@ impl LineSearchProblem { impl CostFunction for LineSearchProblem where O: Operator, - P: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminL2Norm, + P: Clone + ArgminL2Norm, F: ArgminFloat, { type Param = P; @@ -210,7 +207,7 @@ where impl Gradient for LineSearchProblem where O: Operator + Jacobian, - P: Clone + SerializeAlias + DeserializeOwnedAlias, + P: Clone, J: ArgminTranspose + ArgminDot, { type Param = P; diff --git a/argmin/src/solver/gradientdescent/steepestdescent.rs b/argmin/src/solver/gradientdescent/steepestdescent.rs index 06f90952e..939ac0668 100644 --- a/argmin/src/solver/gradientdescent/steepestdescent.rs +++ b/argmin/src/solver/gradientdescent/steepestdescent.rs @@ -6,8 +6,8 @@ // copied, modified, or distributed except according to those terms. use crate::core::{ - ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, IterState, - LineSearch, OptimizationResult, Problem, SerializeAlias, Solver, State, KV, + ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, LineSearch, + OptimizationResult, Problem, Solver, State, KV, }; use argmin_math::ArgminMul; #[cfg(feature = "serde1")] @@ -53,8 +53,8 @@ impl SteepestDescent { impl Solver> for SteepestDescent where O: CostFunction + Gradient, - P: Clone + SerializeAlias + DeserializeOwnedAlias, - G: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminMul, + P: Clone, + G: Clone + ArgminMul, L: Clone + LineSearch + Solver>, F: ArgminFloat, { diff --git a/argmin/src/solver/linesearch/backtracking.rs b/argmin/src/solver/linesearch/backtracking.rs index 8562adcf2..00ffd6e22 100644 --- a/argmin/src/solver/linesearch/backtracking.rs +++ b/argmin/src/solver/linesearch/backtracking.rs @@ -6,8 +6,8 @@ // copied, modified, or distributed except according to those terms. use crate::core::{ - ArgminFloat, CostFunction, Error, Gradient, IterState, LineSearch, Problem, SerializeAlias, - Solver, State, TerminationReason, TerminationStatus, KV, + ArgminFloat, CostFunction, Error, Gradient, IterState, LineSearch, Problem, Solver, State, + TerminationReason, TerminationStatus, KV, }; use crate::solver::linesearch::condition::*; use argmin_math::ArgminScaledAdd; @@ -177,10 +177,10 @@ where impl Solver> for BacktrackingLineSearch where - P: Clone + SerializeAlias + ArgminScaledAdd, - G: SerializeAlias + ArgminScaledAdd, + P: Clone + ArgminScaledAdd, + G: ArgminScaledAdd, O: CostFunction + Gradient, - L: LineSearchCondition + SerializeAlias, + L: LineSearchCondition, F: ArgminFloat, { const NAME: &'static str = "Backtracking line search"; diff --git a/argmin/src/solver/linesearch/hagerzhang.rs b/argmin/src/solver/linesearch/hagerzhang.rs index 0bdce3b4f..6ebd0f3de 100644 --- a/argmin/src/solver/linesearch/hagerzhang.rs +++ b/argmin/src/solver/linesearch/hagerzhang.rs @@ -6,8 +6,8 @@ // copied, modified, or distributed except according to those terms. use crate::core::{ - ArgminFloat, CostFunction, Error, Gradient, IterState, LineSearch, Problem, SerializeAlias, - Solver, TerminationReason, TerminationStatus, KV, + ArgminFloat, CostFunction, Error, Gradient, IterState, LineSearch, Problem, Solver, + TerminationReason, TerminationStatus, KV, }; use argmin_math::{ArgminDot, ArgminScaledAdd}; #[cfg(feature = "serde1")] @@ -498,8 +498,8 @@ impl LineSearch for HagerZhangLineSearch { impl Solver> for HagerZhangLineSearch where O: CostFunction + Gradient, - P: Clone + SerializeAlias + ArgminDot + ArgminScaledAdd, - G: Clone + SerializeAlias + ArgminDot, + P: Clone + ArgminDot + ArgminScaledAdd, + G: Clone + ArgminDot, F: ArgminFloat, { const NAME: &'static str = "Hager-Zhang line search"; diff --git a/argmin/src/solver/linesearch/morethuente.rs b/argmin/src/solver/linesearch/morethuente.rs index 5351791f3..8fc0a89af 100644 --- a/argmin/src/solver/linesearch/morethuente.rs +++ b/argmin/src/solver/linesearch/morethuente.rs @@ -10,8 +10,8 @@ #![allow(clippy::nonminimal_bool)] use crate::core::{ - ArgminFloat, CostFunction, Error, Gradient, IterState, LineSearch, Problem, SerializeAlias, - Solver, State, TerminationReason, KV, + ArgminFloat, CostFunction, Error, Gradient, IterState, LineSearch, Problem, Solver, State, + TerminationReason, KV, }; use argmin_math::{ArgminDot, ArgminScaledAdd}; #[cfg(feature = "serde1")] @@ -299,8 +299,8 @@ where impl Solver> for MoreThuenteLineSearch where O: CostFunction + Gradient, - P: Clone + SerializeAlias + ArgminDot + ArgminScaledAdd, - G: Clone + SerializeAlias + ArgminDot, + P: Clone + ArgminDot + ArgminScaledAdd, + G: Clone + ArgminDot, F: ArgminFloat, { const NAME: &'static str = "More-Thuente Line search"; diff --git a/argmin/src/solver/neldermead/mod.rs b/argmin/src/solver/neldermead/mod.rs index f530cbd9e..9df22f908 100644 --- a/argmin/src/solver/neldermead/mod.rs +++ b/argmin/src/solver/neldermead/mod.rs @@ -19,8 +19,8 @@ //! use crate::core::{ - ArgminFloat, CostFunction, Error, IterState, Problem, SerializeAlias, Solver, - TerminationReason, TerminationStatus, KV, + ArgminFloat, CostFunction, Error, IterState, Problem, Solver, TerminationReason, + TerminationStatus, KV, }; use argmin_math::{ArgminAdd, ArgminMul, ArgminSub}; #[cfg(feature = "serde1")] @@ -319,7 +319,7 @@ impl fmt::Display for Action { impl Solver> for NelderMead where O: CostFunction, - P: Clone + SerializeAlias + ArgminSub + ArgminAdd + ArgminMul, + P: Clone + ArgminSub + ArgminAdd + ArgminMul, F: ArgminFloat + std::iter::Sum, { const NAME: &'static str = "Nelder-Mead method"; diff --git a/argmin/src/solver/newton/newton_cg.rs b/argmin/src/solver/newton/newton_cg.rs index c747028d1..2edd12523 100644 --- a/argmin/src/solver/newton/newton_cg.rs +++ b/argmin/src/solver/newton/newton_cg.rs @@ -6,9 +6,8 @@ // copied, modified, or distributed except according to those terms. use crate::core::{ - ArgminFloat, DeserializeOwnedAlias, Error, Executor, Gradient, Hessian, IterState, LineSearch, - Operator, OptimizationResult, Problem, SerializeAlias, Solver, State, TerminationReason, - TerminationStatus, KV, + ArgminFloat, Error, Executor, Gradient, Hessian, IterState, LineSearch, Operator, + OptimizationResult, Problem, Solver, State, TerminationReason, TerminationStatus, KV, }; use crate::solver::conjugategradient::ConjugateGradient; use argmin_math::{ @@ -110,16 +109,14 @@ impl Solver> for NewtonCG + Hessian, P: Clone - + SerializeAlias - + DeserializeOwnedAlias + ArgminSub + ArgminDot + ArgminScaledAdd + ArgminMul + ArgminConj + ArgminZeroLike, - G: SerializeAlias + DeserializeOwnedAlias + ArgminL2Norm + ArgminMul, - H: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminDot, + G: ArgminL2Norm + ArgminMul, + H: Clone + ArgminDot, L: Clone + LineSearch + Solver>, F: ArgminFloat + ArgminL2Norm, { diff --git a/argmin/src/solver/particleswarm/mod.rs b/argmin/src/solver/particleswarm/mod.rs index 83a3ef932..40f656096 100644 --- a/argmin/src/solver/particleswarm/mod.rs +++ b/argmin/src/solver/particleswarm/mod.rs @@ -21,8 +21,7 @@ //! \[1\] use crate::core::{ - ArgminFloat, CostFunction, Error, PopulationState, Problem, SerializeAlias, Solver, SyncAlias, - KV, + ArgminFloat, CostFunction, Error, PopulationState, Problem, Solver, SyncAlias, KV, }; use argmin_math::{ArgminAdd, ArgminMinMax, ArgminMul, ArgminRandom, ArgminSub, ArgminZeroLike}; use rand::{Rng, SeedableRng}; @@ -279,8 +278,7 @@ where impl Solver, F>> for ParticleSwarm where O: CostFunction + SyncAlias, - P: SerializeAlias - + Clone + P: Clone + SyncAlias + ArgminAdd + ArgminSub diff --git a/argmin/src/solver/quasinewton/bfgs.rs b/argmin/src/solver/quasinewton/bfgs.rs index c1f549c3f..88ce72798 100644 --- a/argmin/src/solver/quasinewton/bfgs.rs +++ b/argmin/src/solver/quasinewton/bfgs.rs @@ -6,9 +6,8 @@ // copied, modified, or distributed except according to those terms. use crate::core::{ - ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, IterState, - LineSearch, OptimizationResult, Problem, SerializeAlias, Solver, TerminationReason, - TerminationStatus, KV, + ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, LineSearch, + OptimizationResult, Problem, Solver, TerminationReason, TerminationStatus, KV, }; use argmin_math::{ ArgminAdd, ArgminDot, ArgminEye, ArgminL2Norm, ArgminMul, ArgminSub, ArgminTranspose, @@ -134,22 +133,9 @@ where impl Solver> for BFGS where O: CostFunction + Gradient, - P: Clone - + SerializeAlias - + DeserializeOwnedAlias - + ArgminSub - + ArgminDot - + ArgminDot, - G: Clone - + SerializeAlias - + DeserializeOwnedAlias - + ArgminL2Norm - + ArgminMul - + ArgminDot - + ArgminSub, - H: SerializeAlias - + DeserializeOwnedAlias - + ArgminSub + P: Clone + ArgminSub + ArgminDot + ArgminDot, + G: Clone + ArgminL2Norm + ArgminMul + ArgminDot + ArgminSub, + H: ArgminSub + ArgminDot + ArgminDot + ArgminAdd diff --git a/argmin/src/solver/quasinewton/dfp.rs b/argmin/src/solver/quasinewton/dfp.rs index f4d6ffa57..f3e7bb977 100644 --- a/argmin/src/solver/quasinewton/dfp.rs +++ b/argmin/src/solver/quasinewton/dfp.rs @@ -6,9 +6,8 @@ // copied, modified, or distributed except according to those terms. use crate::core::{ - ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, IterState, - LineSearch, OptimizationResult, Problem, SerializeAlias, Solver, TerminationReason, - TerminationStatus, KV, + ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, LineSearch, + OptimizationResult, Problem, Solver, TerminationReason, TerminationStatus, KV, }; use argmin_math::{ArgminAdd, ArgminDot, ArgminL2Norm, ArgminMul, ArgminSub}; #[cfg(feature = "serde1")] @@ -99,26 +98,9 @@ where impl Solver> for DFP where O: CostFunction + Gradient, - P: Clone - + SerializeAlias - + DeserializeOwnedAlias - + ArgminSub - + ArgminDot - + ArgminDot - + ArgminMul, - G: Clone - + SerializeAlias - + DeserializeOwnedAlias - + ArgminSub - + ArgminL2Norm - + ArgminDot, - H: Clone - + SerializeAlias - + DeserializeOwnedAlias - + ArgminSub - + ArgminDot - + ArgminAdd - + ArgminMul, + P: Clone + ArgminSub + ArgminDot + ArgminDot + ArgminMul, + G: Clone + ArgminSub + ArgminL2Norm + ArgminDot, + H: Clone + ArgminSub + ArgminDot + ArgminAdd + ArgminMul, L: Clone + LineSearch + Solver>, F: ArgminFloat, { diff --git a/argmin/src/solver/quasinewton/lbfgs.rs b/argmin/src/solver/quasinewton/lbfgs.rs index 7596db603..402eed7a2 100644 --- a/argmin/src/solver/quasinewton/lbfgs.rs +++ b/argmin/src/solver/quasinewton/lbfgs.rs @@ -6,9 +6,8 @@ // copied, modified, or distributed except according to those terms. use crate::core::{ - ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, IterState, - LineSearch, OptimizationResult, Problem, SerializeAlias, Solver, State, TerminationReason, - TerminationStatus, KV, + ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, LineSearch, + OptimizationResult, Problem, Solver, State, TerminationReason, TerminationStatus, KV, }; use argmin_math::{ ArgminAdd, ArgminDot, ArgminL1Norm, ArgminL2Norm, ArgminMinMax, ArgminMul, ArgminSignum, @@ -306,9 +305,6 @@ impl Solver> for LBFGS + Gradient, P: Clone - + std::fmt::Debug - + SerializeAlias - + DeserializeOwnedAlias + ArgminSub + ArgminSub + ArgminAdd @@ -322,9 +318,6 @@ where + ArgminZeroLike + ArgminMinMax, G: Clone - + std::fmt::Debug - + SerializeAlias - + DeserializeOwnedAlias + ArgminL2Norm + ArgminSub + ArgminAdd diff --git a/argmin/src/solver/quasinewton/sr1.rs b/argmin/src/solver/quasinewton/sr1.rs index 0049d53e3..92ee1a7e4 100644 --- a/argmin/src/solver/quasinewton/sr1.rs +++ b/argmin/src/solver/quasinewton/sr1.rs @@ -6,9 +6,8 @@ // copied, modified, or distributed except according to those terms. use crate::core::{ - ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, IterState, - LineSearch, OptimizationResult, Problem, SerializeAlias, Solver, TerminationReason, - TerminationStatus, KV, + ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, LineSearch, + OptimizationResult, Problem, Solver, TerminationReason, TerminationStatus, KV, }; use argmin_math::{ArgminAdd, ArgminDot, ArgminL2Norm, ArgminMul, ArgminSub}; #[cfg(feature = "serde1")] @@ -149,26 +148,14 @@ impl Solver> for SR1 where O: CostFunction + Gradient, P: Clone - + SerializeAlias - + DeserializeOwnedAlias + ArgminSub + ArgminDot + ArgminDot + ArgminDot + ArgminL2Norm + ArgminMul, - G: Clone - + SerializeAlias - + DeserializeOwnedAlias - + ArgminSub - + ArgminL2Norm - + ArgminSub, - H: SerializeAlias - + DeserializeOwnedAlias - + ArgminDot - + ArgminDot - + ArgminAdd - + ArgminMul, + G: Clone + ArgminSub + ArgminL2Norm + ArgminSub, + H: ArgminDot + ArgminDot + ArgminAdd + ArgminMul, L: Clone + LineSearch + Solver>, F: ArgminFloat, { diff --git a/argmin/src/solver/quasinewton/sr1_trustregion.rs b/argmin/src/solver/quasinewton/sr1_trustregion.rs index 09b92cc70..bb55657e4 100644 --- a/argmin/src/solver/quasinewton/sr1_trustregion.rs +++ b/argmin/src/solver/quasinewton/sr1_trustregion.rs @@ -6,9 +6,8 @@ // copied, modified, or distributed except according to those terms. use crate::core::{ - ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, Hessian, - IterState, OptimizationResult, Problem, SerializeAlias, Solver, TerminationReason, - TerminationStatus, TrustRegionRadius, KV, + ArgminFloat, CostFunction, Error, Executor, Gradient, Hessian, IterState, OptimizationResult, + Problem, Solver, TerminationReason, TerminationStatus, TrustRegionRadius, KV, }; use argmin_math::{ ArgminAdd, ArgminDot, ArgminL2Norm, ArgminMul, ArgminSub, ArgminWeightedDot, ArgminZeroLike, @@ -186,26 +185,14 @@ where + Gradient + Hessian, P: Clone - + SerializeAlias - + DeserializeOwnedAlias + ArgminSub + ArgminAdd + ArgminDot + ArgminDot + ArgminL2Norm + ArgminZeroLike, - G: Clone - + SerializeAlias - + DeserializeOwnedAlias - + ArgminL2Norm - + ArgminDot - + ArgminSub, - B: Clone - + SerializeAlias - + DeserializeOwnedAlias - + ArgminDot - + ArgminAdd - + ArgminMul, + G: Clone + ArgminL2Norm + ArgminDot + ArgminSub, + B: Clone + ArgminDot + ArgminAdd + ArgminMul, R: Clone + TrustRegionRadius + Solver>, F: ArgminFloat + ArgminL2Norm, { diff --git a/argmin/src/solver/simulatedannealing/mod.rs b/argmin/src/solver/simulatedannealing/mod.rs index d677739e5..e720f03c7 100644 --- a/argmin/src/solver/simulatedannealing/mod.rs +++ b/argmin/src/solver/simulatedannealing/mod.rs @@ -19,8 +19,8 @@ //! DOI: 10.1126/science.220.4598.671 use crate::core::{ - ArgminFloat, CostFunction, Error, IterState, Problem, SerializeAlias, Solver, - TerminationReason, TerminationStatus, KV, + ArgminFloat, CostFunction, Error, IterState, Problem, Solver, TerminationReason, + TerminationStatus, KV, }; use rand::prelude::*; use rand_xoshiro::Xoshiro256PlusPlus; @@ -440,7 +440,7 @@ where O: CostFunction + Anneal, P: Clone, F: ArgminFloat, - R: Rng + SerializeAlias, + R: Rng, { const NAME: &'static str = "Simulated Annealing"; fn init( diff --git a/argmin/src/solver/trustregion/steihaug.rs b/argmin/src/solver/trustregion/steihaug.rs index 512d9eaf4..cb6134e35 100644 --- a/argmin/src/solver/trustregion/steihaug.rs +++ b/argmin/src/solver/trustregion/steihaug.rs @@ -6,8 +6,8 @@ // copied, modified, or distributed except according to those terms. use crate::core::{ - ArgminFloat, Error, IterState, Problem, SerializeAlias, Solver, State, TerminationReason, - TerminationStatus, TrustRegionRadius, KV, + ArgminFloat, Error, IterState, Problem, Solver, State, TerminationReason, TerminationStatus, + TrustRegionRadius, KV, }; use argmin_math::{ ArgminAdd, ArgminDot, ArgminL2Norm, ArgminMul, ArgminWeightedDot, ArgminZeroLike, @@ -181,7 +181,6 @@ where impl Solver> for Steihaug where P: Clone - + SerializeAlias + ArgminMul + ArgminL2Norm + ArgminDot diff --git a/argmin/src/solver/trustregion/trustregion_method.rs b/argmin/src/solver/trustregion/trustregion_method.rs index 03e776d47..de1c3a316 100644 --- a/argmin/src/solver/trustregion/trustregion_method.rs +++ b/argmin/src/solver/trustregion/trustregion_method.rs @@ -6,9 +6,8 @@ // copied, modified, or distributed except according to those terms. use crate::core::{ - ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, Hessian, - IterState, OptimizationResult, Problem, SerializeAlias, Solver, TerminationStatus, - TrustRegionRadius, KV, + ArgminFloat, CostFunction, Error, Executor, Gradient, Hessian, IterState, OptimizationResult, + Problem, Solver, TerminationStatus, TrustRegionRadius, KV, }; use crate::solver::trustregion::reduction_ratio; use argmin_math::{ArgminAdd, ArgminDot, ArgminL2Norm, ArgminWeightedDot}; @@ -162,16 +161,9 @@ where O: CostFunction + Gradient + Hessian, - P: Clone - + std::fmt::Debug - + SerializeAlias - + DeserializeOwnedAlias - + ArgminL2Norm - + ArgminDot - + ArgminDot - + ArgminAdd, - G: Clone + SerializeAlias + DeserializeOwnedAlias, - H: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminDot, + P: Clone + ArgminL2Norm + ArgminDot + ArgminDot + ArgminAdd, + G: Clone, + H: Clone + ArgminDot, R: Clone + TrustRegionRadius + Solver>, F: ArgminFloat, { diff --git a/checkpointing/file/.gitignore b/checkpointing/file/.gitignore new file mode 100644 index 000000000..99c56abe3 --- /dev/null +++ b/checkpointing/file/.gitignore @@ -0,0 +1 @@ +.checkpoints diff --git a/checkpointing/file/Cargo.toml b/checkpointing/file/Cargo.toml new file mode 100644 index 000000000..6f5c0d706 --- /dev/null +++ b/checkpointing/file/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "argmin-checkpointing-file" +version = "0.1.0" +authors = ["Stefan Kroboth "] +edition = "2021" +license = "MIT OR Apache-2.0" +description = "Checkpointing to a file for argmin" +documentation = "https://docs.rs/argmin-checkpointing-file/" +homepage = "https://argmin-rs.org" +repository = "https://github.com/argmin-rs/argmin" +readme = "README.md" +keywords = ["optimization", "math", "science"] +categories = ["science"] +exclude = [] + +[dependencies] +argmin = { version = "0.9.0", path = "../../argmin", default-features = false } +bincode = "1.3.3" +serde = "1.0.195" diff --git a/argmin/src/core/checkpointing/file.rs b/checkpointing/file/src/lib.rs similarity index 89% rename from argmin/src/core/checkpointing/file.rs rename to checkpointing/file/src/lib.rs index d5d4cebd3..5218910fd 100644 --- a/argmin/src/core/checkpointing/file.rs +++ b/checkpointing/file/src/lib.rs @@ -5,8 +5,9 @@ // http://opensource.org/licenses/MIT>, at your option. This file may not be // copied, modified, or distributed except according to those terms. -use crate::core::checkpointing::{Checkpoint, CheckpointingFrequency}; -use crate::core::{DeserializeOwnedAlias, Error, SerializeAlias}; +pub use argmin::core::checkpointing::{Checkpoint, CheckpointingFrequency}; +use argmin::core::Error; +use serde::{de::DeserializeOwned, Serialize}; use std::default::Default; use std::fs::File; use std::io::{BufReader, BufWriter}; @@ -55,7 +56,8 @@ impl FileCheckpoint { /// # Example /// /// ``` - /// use argmin::core::checkpointing::{FileCheckpoint, CheckpointingFrequency}; + /// use argmin::core::checkpointing::CheckpointingFrequency; + /// use argmin_checkpointing_file::FileCheckpoint; /// # use std::path::PathBuf; /// /// let directory = "checkpoints"; @@ -79,8 +81,8 @@ impl FileCheckpoint { impl Checkpoint for FileCheckpoint where - S: SerializeAlias + DeserializeOwnedAlias, - I: SerializeAlias + DeserializeOwnedAlias, + S: Serialize + DeserializeOwned, + I: Serialize + DeserializeOwned, { /// Writes checkpoint to disk. /// @@ -91,7 +93,8 @@ where /// # Example /// /// ``` - /// use argmin::core::checkpointing::{FileCheckpoint, CheckpointingFrequency, Checkpoint}; + /// use argmin::core::checkpointing::CheckpointingFrequency; + /// use argmin_checkpointing_file::FileCheckpoint; /// /// # use std::fs::File; /// # use std::io::BufReader; @@ -126,7 +129,8 @@ where /// # Example /// /// ``` - /// use argmin::core::checkpointing::{FileCheckpoint, CheckpointingFrequency, Checkpoint}; + /// use argmin::core::checkpointing::CheckpointingFrequency; + /// use argmin_checkpointing_file::FileCheckpoint /// # use argmin::core::Error; /// /// # use std::fs::File; @@ -164,18 +168,17 @@ where /// Returns the how often a checkpoint is to be saved. /// - /// Used internally by [`save_cond`](`crate::core::checkpointing::Checkpoint::save_cond`). + /// Used internally by [`save_cond`](`argmin::core::checkpointing::Checkpoint::save_cond`). fn frequency(&self) -> CheckpointingFrequency { self.frequency } } #[cfg(test)] -#[cfg(feature = "serde1")] mod tests { use super::*; - use crate::core::test_utils::TestSolver; - use crate::core::{IterState, State}; + use argmin::core::test_utils::TestSolver; + use argmin::core::{IterState, State}; #[test] #[allow(clippy::type_complexity)] diff --git a/examples/checkpoint/.gitignore b/examples/checkpoint/.gitignore new file mode 100644 index 000000000..99c56abe3 --- /dev/null +++ b/examples/checkpoint/.gitignore @@ -0,0 +1 @@ +.checkpoints diff --git a/examples/checkpoint/Cargo.toml b/examples/checkpoint/Cargo.toml index 436f5e42b..baab731c6 100644 --- a/examples/checkpoint/Cargo.toml +++ b/examples/checkpoint/Cargo.toml @@ -7,6 +7,7 @@ publish = false [dependencies] argmin = { version = "*", path = "../../argmin", features = ["serde1"] } +argmin-checkpointing-file = { version = "*", path = "../../checkpointing/file" } argmin-math = { version = "*", features = ["vec"], path = "../../argmin-math" } argmin-observer-slog = { version = "*", path = "../../observers/slog/" } argmin_testfunctions = "*" diff --git a/examples/checkpoint/src/main.rs b/examples/checkpoint/src/main.rs index 9cc014cac..0d1e5e863 100644 --- a/examples/checkpoint/src/main.rs +++ b/examples/checkpoint/src/main.rs @@ -7,12 +7,12 @@ use argmin::{ core::{ - checkpointing::{CheckpointingFrequency, FileCheckpoint}, - observers::ObserverMode, - CostFunction, Error, Executor, Gradient, + checkpointing::CheckpointingFrequency, observers::ObserverMode, CostFunction, Error, + Executor, Gradient, }, solver::landweber::Landweber, }; +use argmin_checkpointing_file::FileCheckpoint; use argmin_observer_slog::SlogLogger; use argmin_testfunctions::{rosenbrock_2d, rosenbrock_2d_derivative}; From 7d987580b8e73fa5bed9ff6ea4feecbcd57e9d7c Mon Sep 17 00:00:00 2001 From: Stefan Kroboth Date: Thu, 18 Jan 2024 12:09:12 +0100 Subject: [PATCH 2/2] Updated actions/cache to v4 --- .github/workflows/book.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/book.yml b/.github/workflows/book.yml index c0631bd3c..8da481fd6 100644 --- a/.github/workflows/book.yml +++ b/.github/workflows/book.yml @@ -18,7 +18,7 @@ jobs: - name: Cache dependencies id: cache-dependencies - uses: actions/cache@v2 + uses: actions/cache@v4 with: path: | ~/.cargo/registry