Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dedicated crate for FileCheckpoint, removed SerializeAlias and DeserializeOwnedAlias #395

Merged
merged 2 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/book.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:

- name: Cache dependencies
id: cache-dependencies
uses: actions/cache@v2
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ resolver = "2"
members = [
"argmin",
"argmin-math",
"checkpointing/*",
"observers/*",
"tools/spectator",

Expand Down
4 changes: 2 additions & 2 deletions argmin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ rand_xoshiro = "0.6.0"
thiserror = "1.0"
argmin-math = { path = "../argmin-math", version = "0.3", default-features = false, features = ["primitives"] }
# optional
bincode = { version = "1.3.3", optional = true }
ctrlc = { version = "3.2.4", optional = true }
getrandom = { version = "0.2", optional = true }
rayon = { version = "1.6.0", optional = true }
Expand All @@ -39,11 +38,12 @@ ndarray-linalg = { version = "0.16", features = ["intel-mkl-static"] }
argmin-math = { path = "../argmin-math", version = "0.3", features = ["vec"] }
argmin-observer-slog = { path = "../observers/slog" }
argmin-observer-paramwriter = { path = "../observers/paramwriter" }
argmin-checkpointing-file = { path = "../checkpointing/file" }

[features]
default = []
wasm-bindgen = ["instant/wasm-bindgen", "getrandom/js"]
serde1 = ["serde", "bincode", "rand_xoshiro/serde1"]
serde1 = ["serde", "rand_xoshiro/serde1"]
_ndarrayl = ["argmin-math/ndarray_latest"]
# When adding new features, please consider adding them to either `full` (for users)
# or `_full_dev` (only for local development, testing and computing test coverage).
Expand Down
14 changes: 6 additions & 8 deletions argmin/src/core/checkpointing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
//! with a user-chosen frequency. Optimizations can then be resumed from a given checkpoint after a
//! crash.
//!
//! For saving checkpoints to disk, `FileCheckpoint` is provided.
//! For saving checkpoints to disk, `FileCheckpoint` is provided in the `argmin-checkpointing-file`
//! crate.
//! Via the `Checkpoint` trait other checkpointing approaches can be implemented.
//!
//! The `CheckpointingFrequency` defines how often checkpoints are saved and can be chosen to be
Expand All @@ -29,7 +30,9 @@
//! # extern crate argmin_testfunctions;
//! # use argmin::core::{CostFunction, Error, Executor, Gradient, observers::ObserverMode};
//! # #[cfg(feature = "serde1")]
//! # use argmin::core::checkpointing::{FileCheckpoint, CheckpointingFrequency};
//! use argmin::core::checkpointing::CheckpointingFrequency;
//! # #[cfg(feature = "serde1")]
//! use argmin_checkpointing_file::FileCheckpoint;
//! # use argmin_observer_slog::SlogLogger;
//! # use argmin::solver::landweber::Landweber;
//! # use argmin_testfunctions::{rosenbrock_2d, rosenbrock_2d_derivative};
Expand Down Expand Up @@ -70,6 +73,7 @@
//! #
//! # let iters = 35;
//! # let solver = Landweber::new(0.001);
//!
//! // [...]
//!
//! # #[cfg(feature = "serde1")]
Expand Down Expand Up @@ -98,12 +102,6 @@
//! # }
//! ```

#[cfg(feature = "serde1")]
mod file;

#[cfg(feature = "serde1")]
pub use crate::core::checkpointing::file::FileCheckpoint;

use crate::core::Error;
use std::default::Default;
use std::fmt::Display;
Expand Down
74 changes: 60 additions & 14 deletions argmin/src/core/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
use crate::core::checkpointing::Checkpoint;
use crate::core::observers::{Observe, ObserverMode, Observers};
use crate::core::{
DeserializeOwnedAlias, Error, OptimizationResult, Problem, SerializeAlias, Solver, State,
TerminationReason, TerminationStatus, KV,
Error, OptimizationResult, Problem, Solver, State, TerminationReason, TerminationStatus, KV,
};
use instant;
use std::sync::atomic::{AtomicBool, Ordering};
Expand All @@ -36,7 +35,7 @@ pub struct Executor<O, S, I> {
impl<O, S, I> Executor<O, S, I>
where
S: Solver<O, I>,
I: State + SerializeAlias + DeserializeOwnedAlias,
I: State,
{
/// Constructs an `Executor` from a user defined problem and a solver.
///
Expand Down Expand Up @@ -313,7 +312,8 @@ where
/// ```
/// # use argmin::core::{Error, Executor};
/// # #[cfg(feature = "serde1")]
/// # use argmin::core::checkpointing::{FileCheckpoint, CheckpointingFrequency};
/// # use argmin::core::checkpointing::CheckpointingFrequency;
/// # use argmin_checkpointing_file::FileCheckpoint;
/// # use argmin::core::test_utils::{TestSolver, TestProblem};
/// #
/// # fn main() -> Result<(), Error> {
Expand Down Expand Up @@ -570,17 +570,63 @@ mod tests {

/// The solver's `init` should not be called when started from a checkpoint.
/// See https://github.com/argmin-rs/argmin/issues/199.
// #[cfg(feature = "serde1")]
#[test]
#[cfg(feature = "serde1")]
fn test_checkpointing_solver_initialization() {
use crate::core::checkpointing::{CheckpointingFrequency, FileCheckpoint};
use crate::core::test_utils::TestProblem;
use crate::core::{ArgminFloat, CostFunction};
use std::cell::RefCell;

use crate::core::{
checkpointing::CheckpointingFrequency, test_utils::TestProblem, ArgminFloat,
CostFunction,
};
use serde::{Deserialize, Serialize};

// Fake optimization algorithm which holds internal state which changes over time
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct FakeCheckpoint {
pub frequency: CheckpointingFrequency,
pub solver: RefCell<Option<OptimizationAlgorithm>>,
pub state: RefCell<Option<IterState<Vec<f64>, (), (), (), (), f64>>>,
}

impl Checkpoint<OptimizationAlgorithm, IterState<Vec<f64>, (), (), (), (), f64>>
for FakeCheckpoint
{
fn save(
&self,
solver: &OptimizationAlgorithm,
state: &IterState<Vec<f64>, (), (), (), (), f64>,
) -> Result<(), Error> {
*self.solver.borrow_mut() = Some(solver.clone());
*self.state.borrow_mut() = Some(state.clone());
Ok(())
}

fn load(
&self,
) -> Result<
Option<(
OptimizationAlgorithm,
IterState<Vec<f64>, (), (), (), (), f64>,
)>,
Error,
> {
if self.solver.borrow().is_none() {
return Ok(None);
}
Ok(Some((
self.solver.borrow().clone().unwrap(),
self.state.borrow().clone().unwrap(),
)))
}

fn frequency(&self) -> CheckpointingFrequency {
self.frequency
}
}

// Fake optimization algorithm which holds internal state which changes over time
#[derive(Clone, Serialize, Deserialize)]
struct OptimizationAlgorithm {
pub internal_state: u64,
}
Expand Down Expand Up @@ -638,12 +684,12 @@ mod tests {
// solver instance
let solver = OptimizationAlgorithm { internal_state: 0 };

// Delete old checkpointing file
let _ = std::fs::remove_file(".checkpoints/init_test.arg");

// Create a checkpoint
let checkpoint =
FileCheckpoint::new(".checkpoints", "init_test", CheckpointingFrequency::Always);
let checkpoint = FakeCheckpoint {
frequency: CheckpointingFrequency::Always,
solver: RefCell::new(None),
state: RefCell::new(None),
};

// Create and run executor
let executor = Executor::new(problem, solver)
Expand Down
6 changes: 1 addition & 5 deletions argmin/src/core/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use crate::core::{kv::KvValue, DeserializeOwnedAlias, SendAlias, SerializeAlias};
use crate::core::{kv::KvValue, SendAlias};
use num_traits::{Float, FloatConst, FromPrimitive, ToPrimitive};
use std::fmt::{Debug, Display};

Expand All @@ -21,8 +21,6 @@ pub trait ArgminFloat:
+ ToPrimitive
+ Debug
+ Display
+ SerializeAlias
+ DeserializeOwnedAlias
+ SendAlias
+ Into<KvValue>
{
Expand All @@ -37,8 +35,6 @@ impl<I> ArgminFloat for I where
+ ToPrimitive
+ Debug
+ Display
+ SerializeAlias
+ DeserializeOwnedAlias
+ SendAlias
+ Into<KvValue>
{
Expand Down
3 changes: 0 additions & 3 deletions argmin/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ mod parallelization;
mod problem;
/// Definition of the return type of the solvers
mod result;
/// Trait alias for `serde`s `Serialize` and `DeserializeOwned`
mod serialization;
/// `Solver` trait
mod solver;
/// iteration state
Expand All @@ -51,7 +49,6 @@ pub use kv::{KvValue, KV};
pub use parallelization::{SendAlias, SyncAlias};
pub use problem::{CostFunction, Gradient, Hessian, Jacobian, LinearProgram, Operator, Problem};
pub use result::OptimizationResult;
pub use serialization::{DeserializeOwnedAlias, SerializeAlias};
pub use solver::Solver;
pub use state::{IterState, LinearProgramState, PopulationState, State};
pub use termination::{TerminationReason, TerminationStatus};
53 changes: 0 additions & 53 deletions argmin/src/core/serialization.rs

This file was deleted.

2 changes: 1 addition & 1 deletion argmin/src/core/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K
/// ([`terminate`](`Solver::terminate`) and [`terminate_internal`](`Solver::terminate_internal`)).
/// Only `next_iter` is mandatory to implement, all others provide default implementations.
///
/// A `Solver` needs to be serializable.
/// A `Solver` should be (de)serializable in order to work with checkpointing.
///
/// # Example
///
Expand Down
4 changes: 2 additions & 2 deletions argmin/src/solver/conjugategradient/beta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
//! \[0\] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
//! Springer. ISBN 0-387-30303-0.

use crate::core::{ArgminFloat, SerializeAlias};
use crate::core::ArgminFloat;
use argmin_math::{ArgminDot, ArgminL2Norm, ArgminSub};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
Expand All @@ -44,7 +44,7 @@ use serde::{Deserialize, Serialize};
/// }
/// }
/// ```
pub trait NLCGBetaUpdate<G, P, F>: SerializeAlias {
pub trait NLCGBetaUpdate<G, P, F> {
/// Update beta.
///
/// # Parameters
Expand Down
5 changes: 1 addition & 4 deletions argmin/src/solver/conjugategradient/cg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use crate::core::{
ArgminFloat, Error, IterState, Operator, Problem, SerializeAlias, Solver, State, KV,
};
use crate::core::{ArgminFloat, Error, IterState, Operator, Problem, Solver, State, KV};
use argmin_math::{ArgminConj, ArgminDot, ArgminL2Norm, ArgminMul, ArgminScaledAdd, ArgminSub};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -93,7 +91,6 @@ impl<P, O, F> Solver<O, IterState<P, (), (), (), (), F>> for ConjugateGradient<P
where
O: Operator<Param = P, Output = P>,
P: Clone
+ SerializeAlias
+ ArgminDot<P, F>
+ ArgminSub<P, P>
+ ArgminScaledAdd<P, F, P>
Expand Down
13 changes: 4 additions & 9 deletions argmin/src/solver/conjugategradient/nonlinear_cg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
// copied, modified, or distributed except according to those terms.

use crate::core::{
ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, IterState,
LineSearch, NLCGBetaUpdate, OptimizationResult, Problem, SerializeAlias, Solver, State, KV,
ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, LineSearch, NLCGBetaUpdate,
OptimizationResult, Problem, Solver, State, KV,
};
use argmin_math::{ArgminAdd, ArgminDot, ArgminL2Norm, ArgminMul};
#[cfg(feature = "serde1")]
Expand Down Expand Up @@ -122,13 +122,8 @@ impl<O, P, G, L, B, F> Solver<O, IterState<P, G, (), (), (), F>>
for NonlinearConjugateGradient<P, L, B, F>
where
O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
P: Clone + SerializeAlias + DeserializeOwnedAlias + ArgminAdd<P, P> + ArgminMul<F, P>,
G: Clone
+ SerializeAlias
+ DeserializeOwnedAlias
+ ArgminMul<F, P>
+ ArgminDot<G, F>
+ ArgminL2Norm<F>,
P: Clone + ArgminAdd<P, P> + ArgminMul<F, P>,
G: Clone + ArgminMul<F, P> + ArgminDot<G, F> + ArgminL2Norm<F>,
L: Clone + LineSearch<P, F> + Solver<O, IterState<P, G, (), (), (), F>>,
B: NLCGBetaUpdate<G, P, F>,
F: ArgminFloat,
Expand Down
Loading
Loading