diff --git a/argmin/Cargo.toml b/argmin/Cargo.toml index c5a3058c0..339c31711 100644 --- a/argmin/Cargo.toml +++ b/argmin/Cargo.toml @@ -61,10 +61,6 @@ maintenance = { status = "actively-developed" } targets = ["x86_64-unknown-linux-gnu"] features = ["serde1"] -[[example]] -name = "bfgs" -required-features = ["argmin-math/ndarray_latest-serde"] - [[example]] name = "brentroot" required-features = [] diff --git a/examples/backtracking/Cargo.toml b/examples/backtracking/Cargo.toml index 6967e980c..11958a6dd 100644 --- a/examples/backtracking/Cargo.toml +++ b/examples/backtracking/Cargo.toml @@ -5,13 +5,8 @@ edition = "2021" license = "MIT OR Apache-2.0" publish = false -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] argmin = { version = "*", path = "../../argmin" } argmin-math = { version = "*", features = ["vec"], path = "../../argmin-math" } argmin-observer-slog = { version = "*", path = "../../observers/slog/" } argmin_testfunctions = "*" - -[features] -wasm-bindgen = ["argmin/wasm-bindgen"] diff --git a/examples/bfgs/Cargo.toml b/examples/bfgs/Cargo.toml new file mode 100644 index 000000000..196459cc4 --- /dev/null +++ b/examples/bfgs/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "example-bfgs" +version = "0.1.0" +edition = "2021" +license = "MIT OR Apache-2.0" +publish = false + +[dependencies] +argmin = { version = "*", path = "../../argmin" } +argmin-math = { version = "*", features = ["ndarray_latest-nolinalg-serde"], path = "../../argmin-math" } +argmin-observer-slog = { version = "*", path = "../../observers/slog/" } +argmin_testfunctions = "*" +finitediff = { version = "0.1.4", features = ["ndarray"] } +ndarray = "0.15.6" diff --git a/examples/bfgs/src/main.rs b/examples/bfgs/src/main.rs new file mode 100644 index 000000000..2d0fc5f2b --- /dev/null +++ b/examples/bfgs/src/main.rs @@ -0,0 +1,79 @@ +// Copyright 2018-2022 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +use argmin::{ + core::{observers::ObserverMode, CostFunction, Error, Executor, Gradient}, + solver::{linesearch::MoreThuenteLineSearch, quasinewton::BFGS}, +}; +use argmin_observer_slog::SlogLogger; +use argmin_testfunctions::rosenbrock; +use finitediff::FiniteDiff; +use ndarray::{array, Array1, Array2}; + +struct Rosenbrock { + a: f64, + b: f64, +} + +impl CostFunction for Rosenbrock { + type Param = Array1; + type Output = f64; + + fn cost(&self, p: &Self::Param) -> Result { + Ok(rosenbrock(&p.to_vec(), self.a, self.b)) + } +} +impl Gradient for Rosenbrock { + type Param = Array1; + type Gradient = Array1; + + fn gradient(&self, p: &Self::Param) -> Result { + Ok((*p).forward_diff(&|x| rosenbrock(&x.to_vec(), self.a, self.b))) + } +} + +fn run() -> Result<(), Error> { + // Define cost function + let cost = Rosenbrock { a: 1.0, b: 100.0 }; + + // Define initial parameter vector + // let init_param: Array1 = array![-1.2, 1.0]; + // let init_hessian: Array2 = Array2::eye(2); + let init_param: Array1 = array![-1.2, 1.0, -10.0, 2.0, 3.0, 2.0, 4.0, 10.0]; + let init_hessian: Array2 = Array2::eye(8); + + // set up a line search + let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9)?; + + // Set up solver + let solver = BFGS::new(linesearch); + + // Run solver + let res = Executor::new(cost, solver) + .configure(|state| { + state + .param(init_param) + .inv_hessian(init_hessian) + .max_iters(60) + }) + .add_observer(SlogLogger::term(), ObserverMode::Always) + .run()?; + + // Wait a second (lets the logger flush everything before printing again) + std::thread::sleep(std::time::Duration::from_secs(1)); + + // Print result + println!("{res}"); + Ok(()) +} + +fn main() { + if let Err(ref e) = run() { + println!("{e}"); + std::process::exit(1); + } +}