Skip to content

Commit

Permalink
Turn examples into crates
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-k committed Jan 15, 2024
1 parent 86a8eaa commit 53b027e
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 9 deletions.
4 changes: 0 additions & 4 deletions argmin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,6 @@ maintenance = { status = "actively-developed" }
targets = ["x86_64-unknown-linux-gnu"]
features = ["serde1"]

[[example]]
name = "bfgs"
required-features = ["argmin-math/ndarray_latest-serde"]

[[example]]
name = "brentroot"
required-features = []
Expand Down
5 changes: 0 additions & 5 deletions examples/backtracking/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,8 @@ edition = "2021"
license = "MIT OR Apache-2.0"
publish = false

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
argmin = { version = "*", path = "../../argmin" }
argmin-math = { version = "*", features = ["vec"], path = "../../argmin-math" }
argmin-observer-slog = { version = "*", path = "../../observers/slog/" }
argmin_testfunctions = "*"

[features]
wasm-bindgen = ["argmin/wasm-bindgen"]
14 changes: 14 additions & 0 deletions examples/bfgs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "example-bfgs"
version = "0.1.0"
edition = "2021"
license = "MIT OR Apache-2.0"
publish = false

[dependencies]
argmin = { version = "*", path = "../../argmin" }
argmin-math = { version = "*", features = ["ndarray_latest-nolinalg-serde"], path = "../../argmin-math" }
argmin-observer-slog = { version = "*", path = "../../observers/slog/" }
argmin_testfunctions = "*"
finitediff = { version = "0.1.4", features = ["ndarray"] }
ndarray = "0.15.6"
79 changes: 79 additions & 0 deletions examples/bfgs/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright 2018-2022 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use argmin::{
core::{observers::ObserverMode, CostFunction, Error, Executor, Gradient},
solver::{linesearch::MoreThuenteLineSearch, quasinewton::BFGS},
};
use argmin_observer_slog::SlogLogger;
use argmin_testfunctions::rosenbrock;
use finitediff::FiniteDiff;
use ndarray::{array, Array1, Array2};

struct Rosenbrock {
a: f64,
b: f64,
}

impl CostFunction for Rosenbrock {
type Param = Array1<f64>;
type Output = f64;

fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
Ok(rosenbrock(&p.to_vec(), self.a, self.b))
}
}
impl Gradient for Rosenbrock {
type Param = Array1<f64>;
type Gradient = Array1<f64>;

fn gradient(&self, p: &Self::Param) -> Result<Self::Gradient, Error> {
Ok((*p).forward_diff(&|x| rosenbrock(&x.to_vec(), self.a, self.b)))
}
}

fn run() -> Result<(), Error> {
// Define cost function
let cost = Rosenbrock { a: 1.0, b: 100.0 };

// Define initial parameter vector
// let init_param: Array1<f64> = array![-1.2, 1.0];
// let init_hessian: Array2<f64> = Array2::eye(2);
let init_param: Array1<f64> = array![-1.2, 1.0, -10.0, 2.0, 3.0, 2.0, 4.0, 10.0];
let init_hessian: Array2<f64> = Array2::eye(8);

// set up a line search
let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9)?;

// Set up solver
let solver = BFGS::new(linesearch);

// Run solver
let res = Executor::new(cost, solver)
.configure(|state| {
state
.param(init_param)
.inv_hessian(init_hessian)
.max_iters(60)
})
.add_observer(SlogLogger::term(), ObserverMode::Always)
.run()?;

// Wait a second (lets the logger flush everything before printing again)
std::thread::sleep(std::time::Duration::from_secs(1));

// Print result
println!("{res}");
Ok(())
}

fn main() {
if let Err(ref e) = run() {
println!("{e}");
std::process::exit(1);
}
}

0 comments on commit 53b027e

Please sign in to comment.