diff --git a/enzyme/benchmarks/ReverseMode/adbench/ba.h b/enzyme/benchmarks/ReverseMode/adbench/ba.h index 3ade86a0b7b2..5d9178120e76 100644 --- a/enzyme/benchmarks/ReverseMode/adbench/ba.h +++ b/enzyme/benchmarks/ReverseMode/adbench/ba.h @@ -127,6 +127,19 @@ extern "C" { double* reproj_err, double* w_err ); + + void rust2_ba_objective( + int n, + int m, + int p, + double const* cams, + double const* X, + double const* w, + int const* obs, + double const* feats, + double* reproj_err, + double* w_err + ); void dcompute_reproj_error( double const* cam, @@ -169,6 +182,20 @@ extern "C" { ); void adept_compute_zach_weight_error(double const* w, double* dw, double* err, double* derr); + + void rust_dcompute_reproj_error( + double const* cam, + double * dcam, + double const* X, + double * dX, + double const* w, + double * wb, + double const* feat, + double *err, + double *derr + ); + + void rust_dcompute_zach_weight_error(double const* w, double* dw, double* err, double* derr); } void read_ba_instance(const string& fn, @@ -486,9 +513,9 @@ int main(const int argc, const char* argv[]) { gettimeofday(&start, NULL); calculate_jacobian(input, result); gettimeofday(&end, NULL); - printf("Enzyme combined %0.6f\n", tdiff(&start, &end)); + printf("Enzyme c++ combined %0.6f\n", tdiff(&start, &end)); json enzyme; - enzyme["name"] = "Enzyme combined"; + enzyme["name"] = "Enzyme c++ combined"; enzyme["runtime"] = tdiff(&start, &end); for(unsigned i=0; i<5; i++) { printf("%f ", result.J.vals[i]); @@ -499,6 +526,125 @@ int main(const int argc, const char* argv[]) { } } + + { + struct BAInput input; + read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams, input.X, input.w, input.obs, input.feats); + + struct BAOutput result = { + std::vector(2 * input.p), + std::vector(input.p), + BASparseMat(input.n, input.m, input.p) + }; + + + { + struct timeval start, end; + gettimeofday(&start, NULL); + ba_objective( + input.n, + input.m, + input.p, + input.cams.data(), + input.X.data(), + input.w.data(), + input.obs.data(), + input.feats.data(), + result.reproj_err.data(), + result.w_err.data() + ); + gettimeofday(&end, NULL); + printf("primal c++ t=%0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "primal c++"; + enzyme["runtime"] = tdiff(&start, &end); + for(unsigned i=0; i<5; i++) { + printf("%f ", result.reproj_err[i]); + enzyme["result"].push_back(result.reproj_err[i]); + } + for(unsigned i=0; i<5; i++) { + printf("%f ", result.w_err[i]); + enzyme["result"].push_back(result.w_err[i]); + } + printf("\n"); + test_suite["tools"].push_back(enzyme); + } + } + + + { + struct BAInput input; + read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams, input.X, input.w, input.obs, input.feats); + + struct BAOutput result = { + std::vector(2 * input.p), + std::vector(input.p), + BASparseMat(input.n, input.m, input.p) + }; + { + + struct timeval start, end; + gettimeofday(&start, NULL); + rust2_ba_objective( + input.n, + input.m, + input.p, + input.cams.data(), + input.X.data(), + input.w.data(), + input.obs.data(), + input.feats.data(), + result.reproj_err.data(), + result.w_err.data() + ); + gettimeofday(&end, NULL); + printf("primal rust t=%0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "primal rust"; + enzyme["runtime"] = tdiff(&start, &end); + for(unsigned i=0; i<5; i++) { + printf("%f ", result.reproj_err[i]); + enzyme["result"].push_back(result.reproj_err[i]); + } + for(unsigned i=0; i<5; i++) { + printf("%f ", result.w_err[i]); + enzyme["result"].push_back(result.w_err[i]); + } + printf("\n"); + test_suite["tools"].push_back(enzyme); + } + } + + { + + struct BAInput input; + read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams, input.X, input.w, input.obs, input.feats); + + struct BAOutput result = { + std::vector(2 * input.p), + std::vector(input.p), + BASparseMat(input.n, input.m, input.p) + }; + + { + struct timeval start, end; + gettimeofday(&start, NULL); + calculate_jacobian(input, result); + gettimeofday(&end, NULL); + printf("Enzyme rust combined %0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "Enzyme rust combined"; + enzyme["runtime"] = tdiff(&start, &end); + for(unsigned i=0; i<5; i++) { + printf("%f ", result.J.vals[i]); + enzyme["result"].push_back(result.J.vals[i]); + } + printf("\n"); + test_suite["tools"].push_back(enzyme); + } + + } + test_suite["llvm-version"] = __clang_version__; test_suite["mode"] = "ReverseMode"; test_suite["batch-size"] = 1; diff --git a/enzyme/benchmarks/ReverseMode/adbench/gmm.h b/enzyme/benchmarks/ReverseMode/adbench/gmm.h index 00f4302b9f99..e0a277b19567 100644 --- a/enzyme/benchmarks/ReverseMode/adbench/gmm.h +++ b/enzyme/benchmarks/ReverseMode/adbench/gmm.h @@ -33,6 +33,17 @@ struct GMMParameters { }; extern "C" { +void gmm_objective( + int d, + int k, + int n, + double const* alphas, + double const* means, + double const* icf, + double const* x, + Wishart wishart, + double* err +); void dgmm_objective(int d, int k, int n, const double *alphas, double * alphasb, const double *means, double *meansb, const double *icf, double *icfb, const double *x, Wishart wishart, double *err, double * @@ -47,6 +58,15 @@ extern "C" { alphasb, const double *means, double *meansb, const double *icf, double *icfb, const double *x, Wishart wishart, double *err, double * errb); + + void rust_dgmm_objective(int d, int k, int n, const double *alphas, double * + alphasb, const double *means, double *meansb, const double *icf, + double *icfb, const double *x, Wishart &wishart, double *err, double * + errb); + + void rust_gmm_objective(int d, int k, int n, const double *alphas, + const double *means, const double *icf, + const double *x, Wishart &wishart, double *err); } void read_gmm_instance(const string& fn, @@ -123,10 +143,7 @@ void read_gmm_instance(const string& fn, fclose(fid); } -typedef void(*deriv_t)(int d, int k, int n, const double *alphas, double *alphasb, const double *means, double *meansb, const double *icf, - double *icfb, const double *x, Wishart wishart, double *err, double *errb); - -template +template void calculate_jacobian(struct GMMInput &input, struct GMMOutput &result) { double* alphas_gradient_part = result.gradient.data(); @@ -159,6 +176,25 @@ void calculate_jacobian(struct GMMInput &input, struct GMMOutput &result) ); } +template +double primal(struct GMMInput &input) +{ + double tmp = 0.0; // stores fictive result + // (Tapenade doesn't calculate an original function in reverse mode) + deriv( + input.d, + input.k, + input.n, + input.alphas.data(), + input.means.data(), + input.icf.data(), + input.x.data(), + input.wishart, + &tmp + ); + return tmp; +} + int main(const int argc, const char* argv[]) { printf("starting main\n"); @@ -167,9 +203,11 @@ int main(const int argc, const char* argv[]) { std::vector paths;// = { "1k/gmm_d10_K100.txt" }; - getTests(paths, "data/1k", "1k/"); - getTests(paths, "data/2.5k", "2.5k/"); - getTests(paths, "data/10k", "10k/"); + //getTests(paths, "data/1k", "1k/"); + //getTests(paths, "data/2.5k", "2.5k/"); + //getTests(paths, "data/10k", "10k/"); + //paths.push_back("1k/gmm_d128_K100.txt"); + paths.push_back("1k/gmm_d2_K5.txt"); std::ofstream jsonfile("results.json", std::ofstream::trunc); json test_results; @@ -257,6 +295,7 @@ int main(const int argc, const char* argv[]) { gettimeofday(&start, NULL); calculate_jacobian(input, result); gettimeofday(&end, NULL); + printf("Enzyme c++ combined %0.6f\n", tdiff(&start, &end)); json enzyme; enzyme["name"] = "Enzyme combined"; enzyme["runtime"] = tdiff(&start, &end); @@ -269,6 +308,61 @@ int main(const int argc, const char* argv[]) { test_suite["tools"].push_back(enzyme); } + } + + { + + struct GMMInput input; + read_gmm_instance("data/" + path, &input.d, &input.k, &input.n, + input.alphas, input.means, input.icf, input.x, input.wishart, params.replicate_point); + + int Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2; + + struct GMMOutput result = { 0, std::vector(Jcols) }; + + { + struct timeval start, end; + gettimeofday(&start, NULL); + auto res = primal(input); + gettimeofday(&end, NULL); + printf("c++ primal combined t=%0.6f, err=%f\n", tdiff(&start, &end), res); + + json primal; + primal["name"] = "C++ primal"; + primal["runtime"] = tdiff(&start, &end); + primal["result"].push_back(res); + test_suite["tools"].push_back(primal); + } + { + struct timeval start, end; + gettimeofday(&start, NULL); + auto res = primal(input); + gettimeofday(&end, NULL); + printf("rust primal combined t=%0.6f, err=%f\n", tdiff(&start, &end), res); + json primal; + primal["name"] = "Rust primal"; + primal["runtime"] = tdiff(&start, &end); + primal["result"].push_back(res); + test_suite["tools"].push_back(primal); + } + { + struct timeval start, end; + gettimeofday(&start, NULL); + calculate_jacobian(input, result); + gettimeofday(&end, NULL); + printf("Enzyme rust combined %0.6f\n", tdiff(&start, &end)); + json enzyme; + enzyme["name"] = "Rust Enzyme combined"; + enzyme["runtime"] = tdiff(&start, &end); + for (unsigned i = result.gradient.size() - 5; + i < result.gradient.size(); i++) { + printf("%f ", result.gradient[i]); + enzyme["result"].push_back(result.gradient[i]); + } + printf("\n"); + test_suite["tools"].push_back(enzyme); + } + } test_suite["llvm-version"] = __clang_version__; test_suite["mode"] = "ReverseMode"; diff --git a/enzyme/benchmarks/ReverseMode/ba/Cargo.lock b/enzyme/benchmarks/ReverseMode/ba/Cargo.lock new file mode 100644 index 000000000000..74e2768e7cd4 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ba/Cargo.lock @@ -0,0 +1,16 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "bars" +version = "0.1.0" +dependencies = [ + "libm", +] + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" diff --git a/enzyme/benchmarks/ReverseMode/ba/Cargo.toml b/enzyme/benchmarks/ReverseMode/ba/Cargo.toml new file mode 100644 index 000000000000..160c7716f3d8 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ba/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "bars" +version = "0.1.0" +edition = "2021" + + +[lib] +crate-type = ["cdylib"] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[profile.release] +lto = "fat" + +[profile.dev] +lto = "fat" + +[dependencies] +libm = { version = "0.2.8", optional = true } diff --git a/enzyme/benchmarks/ReverseMode/ba/Makefile.make b/enzyme/benchmarks/ReverseMode/ba/Makefile.make index 6f0f2cc18242..8a13a0e524fb 100644 --- a/enzyme/benchmarks/ReverseMode/ba/Makefile.make +++ b/enzyme/benchmarks/ReverseMode/ba/Makefile.make @@ -1,23 +1,17 @@ -# RUN: cd %S && LD_LIBRARY_PATH="%bldpath:$LD_LIBRARY_PATH" BENCH="%bench" BENCHLINK="%blink" LOAD="%loadEnzyme" make -B ba-unopt.ll ba-raw.ll results.json -f %s +# RUN: cd %S && LD_LIBRARY_PATH="%bldpath:$LD_LIBRARY_PATH" BENCH="%bench" BENCHLINK="%blink" LOAD="%newLoadClangEnzyme" make -B ba.o results.json -f %s .PHONY: clean +dir := $(abspath $(lastword $(MAKEFILE_LIST))/../../../..) + clean: rm -f *.ll *.o results.txt results.json -%-unopt.ll: %.cpp - clang++ $(BENCH) $^ -O2 -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -Xclang -new-struct-path-tbaa -o $@ -S -emit-llvm - #clang++ $(BENCH) $^ -O1 -Xclang -disable-llvm-passes -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -Xclang -new-struct-path-tbaa -o $@ -S -emit-llvm - -%-raw.ll: %-unopt.ll - opt $^ $(LOAD) -enzyme -o $@ -S - -%-opt.ll: %-raw.ll - opt $^ -o $@ -S - #opt $^ -O2 -o $@ -S +$(dir)/benchmarks/ReverseMode/ba/target/release/libbars.a: src/lib.rs Cargo.toml + ENZYME_LOOSE_TYPES=1 cargo +enzyme rustc --release --lib --crate-type=staticlib --features=libm -ba.o: ba-opt.ll - clang++ -O2 $^ -o $@ $(BENCHLINK) +ba.o: ba.cpp $(dir)/benchmarks/ReverseMode/ba/target/release/libbars.a + clang++ $(LOAD) $(BENCH) ba.cpp -I /usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11 -O2 -o ba.o -lpthread $(BENCHLINK) -lm $(dir)/benchmarks/ReverseMode/ba/target/release/libbars.a -L /usr/lib/gcc/x86_64-linux-gnu/11 results.json: ba.o ./$^ diff --git a/enzyme/benchmarks/ReverseMode/ba/src/lib.rs b/enzyme/benchmarks/ReverseMode/ba/src/lib.rs new file mode 100644 index 000000000000..82318144f63f --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ba/src/lib.rs @@ -0,0 +1,221 @@ +#![feature(autodiff)] +#![feature(slice_first_last_chunk)] +#![allow(non_snake_case)] + +//#define BA_NCAMPARAMS 11 +static BA_NCAMPARAMS: usize = 11; + +fn sqsum(x: &[f64]) -> f64 { + x.iter().map(|&v| v * v).sum() +} + +#[inline] +fn cross(a: &[f64; 3], b: &[f64; 3]) -> [f64; 3] { + [ + a[1] * b[2] - a[2] * b[1], + a[2] * b[0] - a[0] * b[2], + a[0] * b[1] - a[1] * b[0], + ] +} + +fn radial_distort(rad_params: &[f64], proj: &mut [f64]) { + let rsq = sqsum(proj); + let l = 1. + rad_params[0] * rsq + rad_params[1] * rsq * rsq; + proj[0] = proj[0] * l; + proj[1] = proj[1] * l; +} + +fn rodrigues_rotate_point(rot: &[f64; 3], pt: &[f64; 3], rotated_pt: &mut [f64; 3]) { + let sqtheta = sqsum(rot); + if sqtheta != 0. { + let theta = sqtheta.sqrt(); + let costheta = theta.cos(); + let sintheta = theta.sin(); + let theta_inverse = 1. / theta; + let w = rot.map(|v| v * theta_inverse); + let w_cross_pt = cross(&w, &pt); + let tmp = (w[0] * pt[0] + w[1] * pt[1] + w[2] * pt[2]) * (1. - costheta); + for i in 0..3 { + rotated_pt[i] = pt[i] * costheta + w_cross_pt[i] * sintheta + w[i] * tmp; + } + } else { + let rot_cross_pt = cross(&rot, &pt); + for i in 0..3 { + rotated_pt[i] = pt[i] + rot_cross_pt[i]; + } + } +} + +fn project(cam: &[f64; 11], X: &[f64; 3], proj: &mut [f64; 2]) { + let C = &cam[3..6]; + let mut Xo = [0.; 3]; + let mut Xcam = [0.; 3]; + + Xo[0] = X[0] - C[0]; + Xo[1] = X[1] - C[1]; + Xo[2] = X[2] - C[2]; + + rodrigues_rotate_point(cam.first_chunk::<3>().unwrap(), &Xo, &mut Xcam); + + proj[0] = Xcam[0] / Xcam[2]; + proj[1] = Xcam[1] / Xcam[2]; + + radial_distort(&cam[9..], proj); + + proj[0] = proj[0] * cam[6] + cam[7]; + proj[1] = proj[1] * cam[6] + cam[8]; +} + +#[no_mangle] +pub extern "C" fn rust_dcompute_reproj_error( + cam: *const [f64; 11], + dcam: *mut [f64; 11], + x: *const [f64; 3], + dx: *mut [f64; 3], + w: *const [f64; 1], + wb: *mut [f64; 1], + feat: *const [f64; 2], + err: *mut [f64; 2], + derr: *mut [f64; 2], +) { + dcompute_reproj_error(cam, dcam, x, dx, w, wb, feat, err, derr); +} + +#[no_mangle] +pub extern "C" fn rust_dcompute_zach_weight_error( + w: *const f64, + dw: *mut f64, + err: *mut f64, + derr: *mut f64, +) { + dcompute_zach_weight_error(w, dw, err, derr); +} + +#[autodiff( + dcompute_reproj_error, + Reverse, + Duplicated, + Duplicated, + Duplicated, + Const, + Duplicated +)] +pub fn compute_reproj_error( + cam: *const [f64; 11], + x: *const [f64; 3], + w: *const [f64; 1], + feat: *const [f64; 2], + err: *mut [f64; 2], +) { + let cam = unsafe { &*cam }; + let w = unsafe { *(*w).get_unchecked(0) }; + let x = unsafe { &*x }; + let feat = unsafe { &*feat }; + let mut err = unsafe { &mut *err }; + let mut proj = [0.; 2]; + project(cam, x, &mut proj); + err[0] = w * (proj[0] - feat[0]); + err[1] = w * (proj[1] - feat[1]); +} + +#[autodiff(dcompute_zach_weight_error, Reverse, Duplicated, Duplicated)] +pub fn compute_zach_weight_error(w: *const f64, err: *mut f64) { + let w = unsafe { *w }; + let mut err = unsafe { *err }; + err = 1. - w * w; +} + +// n number of cameras +// m number of points +// p number of observations +// cams: 11*n cameras in format [r1 r2 r3 C1 C2 C3 f u0 v0 k1 k2] +// r1, r2, r3 are angle - axis rotation parameters(Rodrigues) +// [C1 C2 C3]' is the camera center +// f is the focal length in pixels +// [u0 v0]' is the principal point +// k1, k2 are radial distortion parameters +// X: 3*m points +// obs: 2*p observations (pairs cameraIdx, pointIdx) +// feats: 2*p features (x,y coordinates corresponding to observations) +// reproj_err: 2*p errors of observations +// w_err: p weight "error" terms +fn rust_ba_objective( + n: usize, + m: usize, + p: usize, + cams: &[f64], + x: &[f64], + w: &[f64], + obs: &[i32], + feats: &[f64], + reproj_err: &mut [f64], + w_err: &mut [f64], +) { + assert_eq!(cams.len(), n * 11); + assert_eq!(x.len(), m * 3); + assert_eq!(w.len(), p); + assert_eq!(obs.len(), p * 2); + assert_eq!(feats.len(), p * 2); + assert_eq!(reproj_err.len(), p * 2); + assert_eq!(w_err.len(), p); + + for i in 0..p { + let cam_idx = obs[i * 2 + 0] as usize; + let pt_idx = obs[i * 2 + 1] as usize; + let start = cam_idx * BA_NCAMPARAMS; + let cam: &[f64; 11] = unsafe { + cams[start..] + .get_unchecked(..11) + .try_into() + .unwrap_unchecked() + }; + let x: &[f64; 3] = unsafe { + x[pt_idx * 3..] + .get_unchecked(..3) + .try_into() + .unwrap_unchecked() + }; + let w: &[f64; 1] = unsafe { w[i..].get_unchecked(..1).try_into().unwrap_unchecked() }; + let feat: &[f64; 2] = unsafe { + feats[i * 2..] + .get_unchecked(..2) + .try_into() + .unwrap_unchecked() + }; + let reproj_err: &mut [f64; 2] = unsafe { + reproj_err[i * 2..] + .get_unchecked_mut(..2) + .try_into() + .unwrap_unchecked() + }; + compute_reproj_error(cam, x, w, feat, reproj_err); + } + + for i in 0..p { + let w_err: &mut f64 = unsafe { w_err.get_unchecked_mut(i) }; + compute_zach_weight_error(w[i..].as_ptr(), w_err as *mut f64); + } +} + +#[no_mangle] +extern "C" fn rust2_ba_objective( + n: usize, + m: usize, + p: usize, + cams: *const f64, + x: *const f64, + w: *const f64, + obs: *const i32, + feats: *const f64, + reproj_err: *mut f64, + w_err: *mut f64, +) { + let cams = unsafe { std::slice::from_raw_parts(cams, n * 11) }; + let x = unsafe { std::slice::from_raw_parts(x, m * 3) }; + let w = unsafe { std::slice::from_raw_parts(w, p) }; + let obs = unsafe { std::slice::from_raw_parts(obs, p * 2) }; + let feats = unsafe { std::slice::from_raw_parts(feats, p * 2) }; + let reproj_err = unsafe { std::slice::from_raw_parts_mut(reproj_err, p * 2) }; + let w_err = unsafe { std::slice::from_raw_parts_mut(w_err, p) }; + rust_ba_objective(n, m, p, cams, x, w, obs, feats, reproj_err, w_err); +} diff --git a/enzyme/benchmarks/ReverseMode/ba/src/main.rs b/enzyme/benchmarks/ReverseMode/ba/src/main.rs new file mode 100644 index 000000000000..13f221be69c1 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/ba/src/main.rs @@ -0,0 +1,26 @@ +use bars::{dcompute_reproj_error, dcompute_zach_weight_error}; +fn main() { + let cam = [0.0; 11]; + let mut dcam = [0.0; 11]; + let x = [0.0; 3]; + let mut dx = [0.0; 3]; + let w = [0.0; 1]; + let mut dw = [0.0; 1]; + let feat = [0.0; 2]; + let mut err = [0.0; 2]; + let mut derr = [0.0; 2]; + dcompute_reproj_error( + &cam as *const [f64;11], + &mut dcam as *mut [f64;11], + &x as *const [f64;3], + &mut dx as *mut [f64;3], + &w as *const [f64;1], + &mut dw as *mut [f64;1], + &feat as *const [f64;2], + &mut err as *mut [f64;2], + &mut derr as *mut [f64;2], + ); + + let mut wb = 0.0; + dcompute_zach_weight_error(&w as *const f64, &mut dw as *mut f64, &mut err as *mut f64, &mut derr as *mut f64); +} diff --git a/enzyme/benchmarks/ReverseMode/gmm/Cargo.lock b/enzyme/benchmarks/ReverseMode/gmm/Cargo.lock new file mode 100644 index 000000000000..cfdab95b3d9c --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/gmm/Cargo.lock @@ -0,0 +1,16 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "gmmrs" +version = "0.1.0" +dependencies = [ + "libm", +] + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" diff --git a/enzyme/benchmarks/ReverseMode/gmm/Cargo.toml b/enzyme/benchmarks/ReverseMode/gmm/Cargo.toml new file mode 100644 index 000000000000..818440d14aba --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/gmm/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "gmmrs" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +crate-type = ["lib"] + +[features] +libm = ["dep:libm"] + +[profile.release] +lto = "fat" +opt-level = 2 +# opt-level = 3 +#debug = true +#strip = "none" + +[profile.dev] +lto = "fat" + +[dependencies] +libm = { version = "0.2.8", optional = true } diff --git a/enzyme/benchmarks/ReverseMode/gmm/Makefile.make b/enzyme/benchmarks/ReverseMode/gmm/Makefile.make index 5072679eeb0e..ece7323f92ed 100644 --- a/enzyme/benchmarks/ReverseMode/gmm/Makefile.make +++ b/enzyme/benchmarks/ReverseMode/gmm/Makefile.make @@ -1,23 +1,18 @@ -# RUN: if [ %llvmver -ge 12 ] || [ %llvmver -le 9 ]; then cd %S && LD_LIBRARY_PATH="%bldpath:$LD_LIBRARY_PATH" BENCH="%bench" BENCHLINK="%blink" LOAD="%loadEnzyme" make -B gmm-unopt.ll gmm-raw.ll results.json -f %s; fi +# RUN: if [ %llvmver -ge 12 ] || [ %llvmver -le 9 ]; then cd %S && LD_LIBRARY_PATH="%bldpath:$LD_LIBRARY_PATH" BENCH="%bench" BENCHLINK="%blink" LOAD="%newLoadClangEnzyme" make -B gmm.o results.json -f %s; fi .PHONY: clean +dir := $(abspath $(lastword $(MAKEFILE_LIST))/../../../..) + clean: rm -f *.ll *.o results.txt results.json -%-unopt.ll: %.cpp - clang++ $(BENCH) $^ -O2 -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm - #clang++ $(BENCH) $^ -O1 -Xclang -disable-llvm-passes -fno-vectorize -fno-slp-vectorize -ffast-math -fno-unroll-loops -o $@ -S -emit-llvm - -%-raw.ll: %-unopt.ll - opt $^ $(LOAD) -enzyme -o $@ -S - -%-opt.ll: %-raw.ll - opt $^ -o $@ -S - #opt $^ -O2 -o $@ -S - -gmm.o: gmm-opt.ll - clang++ -O2 $^ -o $@ $(BENCHLINK) -lm +gmm.o: gmm.cpp src/lib.rs Cargo.toml + cargo clean + ENZYME_PRINT_PERF=1 ENZYME_LOOSE_TYPES=1 cargo +enzyme rustc --release --lib --crate-type=staticlib --features=libm + # ENZYME_PRINT_MOD_AFTER=1 ENZYME_PRINT=1 ENZYME_LOOSE_TYPES=1 cargo +enzyme rustc --release --lib --crate-type=staticlib --features=libm &> comp.log + # ENZYME_PRINT_AA=1 ENZYME_PRINT=1 ENZYME_LOOSE_TYPES=1 cargo +enzyme rustc --release --lib --crate-type=staticlib --features=libm &> comp.log + clang++ $(LOAD) $(BENCH) gmm.cpp -I /usr/include/c++/11 -I/usr/include/x86_64-linux-gnu/c++/11 -O2 -o gmm.o -lpthread $(BENCHLINK) -lm $(dir)/benchmarks/ReverseMode/gmm/target/release/libgmmrs.a -L /usr/lib/gcc/x86_64-linux-gnu/11 results.json: gmm.o ./$^ diff --git a/enzyme/benchmarks/ReverseMode/gmm/gmm.cpp b/enzyme/benchmarks/ReverseMode/gmm/gmm.cpp index 866059217b96..531ec18eb837 100644 --- a/enzyme/benchmarks/ReverseMode/gmm/gmm.cpp +++ b/enzyme/benchmarks/ReverseMode/gmm/gmm.cpp @@ -42,8 +42,8 @@ extern "C" { double arr_max(int n, double const* x) { int i; - double m = x[0]; - for (i = 1; i < n; i++) + double m = x[2]; + for (i = 2; i < n; i++) { if (m < x[i]) { @@ -90,15 +90,16 @@ void subtract( double log_sum_exp(int n, double const* x) { int i; - double mx = arr_max(n, x); + double mx = arr_max(5, x); double semx = 0.0; - for (i = 0; i < n; i++) + for (i = 0; i < 5; i++) { - semx = semx + exp(x[i] - mx); + semx = semx + x[i]; } - return log(semx) + mx; + return (semx) + log(mx); + //return mx; } @@ -229,27 +230,10 @@ void gmm_objective( preprocess_qs(d, k, icf, &sum_qs[0], &Qdiags[0]); - double slse = 0.; - for (ix = 0; ix < n; ix++) - { - for (ik = 0; ik < k; ik++) - { - subtract(d, &x[ix * d], &means[ik * d], &xcentered[0]); - Qtimesx(d, &Qdiags[ik * d], &icf[ik * icf_sz + d], &xcentered[0], &Qxcentered[0]); - // two caches for qxcentered at idx 0 and at arbitrary index - main_term[ik] = alphas[ik] + sum_qs[ik] - 0.5 * sqnorm(d, &Qxcentered[0]); - } - - // storing cmp for max of main_term - // 2 x (0 and arbitrary) storing sub to exp - // storing sum for use in log - slse = slse + log_sum_exp(k, &main_term[0]); - } - //storing cmp of alphas double lse_alphas = log_sum_exp(k, alphas); - *err = CONSTANT + slse - n * lse_alphas + log_wishart_prior(d, k, wishart, &sum_qs[0], &Qdiags[0], icf); + *err = lse_alphas ;//+ log_wishart_prior(d, k, wishart, &sum_qs[0], &Qdiags[0], icf); free(Qdiags); free(sum_qs); diff --git a/enzyme/benchmarks/ReverseMode/gmm/src/lib.rs b/enzyme/benchmarks/ReverseMode/gmm/src/lib.rs new file mode 100644 index 000000000000..f0f37c401d4e --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/gmm/src/lib.rs @@ -0,0 +1,188 @@ +#![feature(autodiff)] +use std::f64::consts::PI; + +#[cfg(feature = "libm")] +use libm::lgamma; + +#[cfg(not(feature = "libm"))] +mod cmath { + extern "C" { + pub fn lgamma(x: f64) -> f64; + } +} +#[cfg(not(feature = "libm"))] +#[inline] +fn lgamma(x: f64) -> f64 { + unsafe { cmath::lgamma(x) } +} + +#[no_mangle] +pub extern "C" fn rust_dgmm_objective(d: i32, k: i32, n: i32, alphas: *const f64, dalphas: *mut f64, means: *const f64, dmeans: *mut f64, icf: *const f64, dicf: *mut f64, x: *const f64, wishart: *const Wishart, err: *mut f64, derr: *mut f64) { + let k = k as usize; + let n = n as usize; + let d = d as usize; + let alphas = unsafe { std::slice::from_raw_parts(alphas, k) }; + let means = unsafe { std::slice::from_raw_parts(means, k * d) }; + let icf = unsafe { std::slice::from_raw_parts(icf, k * d * (d + 1) / 2) }; + let x = unsafe { std::slice::from_raw_parts(x, n * d) }; + let wishart: Wishart = unsafe { *wishart }; + let mut my_err = unsafe { *err }; + + let d_alphas = unsafe { std::slice::from_raw_parts_mut(dalphas, k) }; + let d_means = unsafe { std::slice::from_raw_parts_mut(dmeans, k * d) }; + let d_icf = unsafe { std::slice::from_raw_parts_mut(dicf, k * d * (d + 1) / 2) }; + let mut my_derr = unsafe { *derr }; + + dgmm_objective(d, k, n, alphas, d_alphas, means, d_means, icf, d_icf, x, wishart.gamma, wishart.m, &mut my_err, &mut my_derr); + + unsafe { *err = my_err }; + unsafe { *derr = my_derr }; +} + +#[no_mangle] +pub extern "C" fn rust_gmm_objective(d: i32, k: i32, n: i32, alphas: *const f64, means: *const f64, icf: *const f64, x: *const f64, wishart: *const Wishart, err: *mut f64) { + let k = k as usize; + let n = n as usize; + let d = d as usize; + let alphas = unsafe { std::slice::from_raw_parts(alphas, k) }; + let means = unsafe { std::slice::from_raw_parts(means, k * d) }; + let icf = unsafe { std::slice::from_raw_parts(icf, k * d * (d + 1) / 2) }; + let x = unsafe { std::slice::from_raw_parts(x, n * d) }; + let wishart: Wishart = unsafe { *wishart }; + let mut my_err = unsafe { *err }; + gmm_objective(d, k, n, alphas, means, icf, x, wishart.gamma, wishart.m, &mut my_err); + unsafe { *err = my_err }; +} + + +//#[autodiff(dgmm_objective, Reverse, Const, Const, Const, Duplicated, Duplicated, Duplicated, Const, Const, Duplicated)] +//pub fn gmm_objective_c(d: usize, k: usize, n: usize, alphas: *const f64, means: *const f64, icf: *const f64, x: *const f64, wishart: *const Wishart, err: *mut f64) { +// gmm_objective(d, k, n, alphas, means, icf, x, wishart, &mut my_err); +//} + +#[autodiff(dgmm_objective, Reverse, Const, Const, Const, Duplicated, Duplicated, Duplicated, Const, Const, Const, Duplicated)] +pub fn gmm_objective(d: usize, k: usize, n: usize, alphas: &[f64], means: &[f64], icf: &[f64], x: &[f64], gamma: f64, m: i32, err: &mut f64) { + let wishart: Wishart = Wishart { gamma, m }; + //let wishart: Wishart = unsafe { *wishart }; + // let constant = -(n as f64) * d as f64 * 0.5 * (2.0 * PI).ln(); + let icf_sz = d * (d + 1) / 2; + let mut qdiags = vec![0.; d * k]; + let mut sum_qs = vec![0.; k]; + let mut xcentered = vec![0.; d]; + let mut qxcentered = vec![0.; d]; + let mut main_term = vec![0.; k]; + + preprocess_qs(d, k, icf, &mut sum_qs, &mut qdiags); + + for ix in 0..n { + for ik in 2..5 { + subtract(d, &x[ix as usize * d as usize..], &means[ik as usize * d as usize..], &mut xcentered); + qtimesx(d, &qdiags[ik as usize * d as usize..], &icf[ik as usize * icf_sz as usize + d as usize..], &xcentered, &mut qxcentered); + main_term[ik as usize] = alphas[ik as usize]; + } + + } + + let lse_alphas = log_sum_exp(k, alphas); + + let _lwp = { + let p = d; + let n = p + wishart.m as usize + 1; + let icf_sz = p * (p + 1) / 2; + + let c = n as f64 * p as f64 * (wishart.gamma.ln() - 0.5 * 2f64.ln()) - log_gamma_distrib(0.5 * n as f64, p as f64); + + let out = (0..k).map(|ik| { + let frobenius = sqnorm(&qdiags[ik * p as usize..][..p]) + sqnorm(&icf[ik * icf_sz as usize + p as usize..][..icf_sz -p]); + 0.5 * wishart.gamma * wishart.gamma * (frobenius) - (wishart.m as f64) * sum_qs[ik as usize] + }).sum::(); + + k as f64 * c + }; + //let lwp = log_wishart_prior(d, k, wishart, &sum_qs, &qdiags, icf); + + *err = lse_alphas; // + lwp; +} + +fn arr_max(n: usize, x: &[f64]) -> f64 { + let mut max = f64::NEG_INFINITY; + for i in 2..5 { + if max < x[i] { + max = x[i]; + } + } + max +} + + +fn preprocess_qs(d: usize, k: usize, icf: &[f64], sum_qs: &mut [f64], qdiags: &mut [f64]) { + let icf_sz = d * (d + 1) / 2; + let q = icf[13]; + for ik in 0..k { + sum_qs[ik as usize] = 2.7; + for id in 0..d { + sum_qs[ik as usize] = sum_qs[ik as usize] + q; + break; + } + } +} +fn subtract(d: usize, x: &[f64], y: &[f64], out: &mut [f64]) { + assert!(x.len() >= d); + assert!(y.len() >= d); + assert!(out.len() >= d); + for i in 0..d { + out[i] = 3.1; + } +} + +fn qtimesx(d: usize, q_diag: &[f64], ltri: &[f64], x: &[f64], out: &mut [f64]) { + assert!(out.len() >= d); + assert!(q_diag.len() >= d); + assert!(x.len() >= d); + for i in 0..d { + out[i] = 2.7; + } + + for i in 0..d { + let mut lparamsidx = i*(2*d-i-1)/2; + for j in i + 1..d { + out[j] = out[j] + ltri[lparamsidx] * 2.0; + lparamsidx += 1; + } + } +} + +fn log_sum_exp(n: usize, x: &[f64]) -> f64 { + let mx = arr_max(n, x); + let semx: f64 = x.iter().sum(); + semx + mx.ln() +} +#[inline(always)] +fn log_gamma_distrib(a: f64, p: f64) -> f64 { + 0.25 * p * (p - 1.) * PI.ln() + (1..=p as usize).map(|j| lgamma(a + 0.5 * (1. - j as f64))).sum::() +} + +#[derive(Clone, Copy)] +#[repr(C)] +pub struct Wishart { + pub gamma: f64, + pub m: i32, +} +#[cfg(we_inlined_it)] +fn log_wishart_prior(p: usize, k: usize, wishart: Wishart, sum_qs: &[f64], qdiags: &[f64], icf: &[f64]) -> f64 { + let n = p + wishart.m as usize + 1; + let icf_sz = p * (p + 1) / 2; + + let c = n as f64 * p as f64 * (wishart.gamma.ln() - 0.5 * 2f64.ln()) - log_gamma_distrib(0.5 * n as f64, p as f64); + + let out = (0..k).map(|ik| { + let frobenius = sqnorm(&qdiags[ik * p as usize..][..p]) + sqnorm(&icf[ik * icf_sz as usize + p as usize..][..icf_sz -p]); + 0.5 * wishart.gamma * wishart.gamma * (frobenius) - (wishart.m as f64) * sum_qs[ik as usize] + }).sum::(); + + out - k as f64 * c +} + +fn sqnorm(x: &[f64]) -> f64 { + x.iter().map(|x| x * x).sum() +} diff --git a/enzyme/benchmarks/ReverseMode/gmm/src/main.rs b/enzyme/benchmarks/ReverseMode/gmm/src/main.rs new file mode 100644 index 000000000000..8f4357588ab8 --- /dev/null +++ b/enzyme/benchmarks/ReverseMode/gmm/src/main.rs @@ -0,0 +1,24 @@ +#![feature(autodiff)] +use gmmrs::{Wishart, dgmm_objective}; + +fn main() { + let d = 2; + let k = 2; + let n = 2; + let alphas = vec![0.5, 0.5]; + let means = vec![0., 0., 1., 1.]; + let icf = vec![1., 0., 1.]; + let x = vec![0., 0., 1., 1.]; + let wishart = Wishart { gamma: 1., m: 1 }; + let mut err = 0.; + let mut d_alphas = vec![0.; alphas.len()]; + let mut d_means = vec![0.; means.len()]; + let mut d_icf = vec![0.; icf.len()]; + let mut d_x = vec![0.; x.len()]; + let mut d_err = 0.; + let mut err2 = &mut err; + let mut d_err2 = &mut d_err; + let wishart2 = &wishart; + // pass as raw ptr: + dgmm_objective(d, k, n, alphas.as_ptr(), d_alphas.as_mut_ptr(), means.as_ptr(), d_means.as_mut_ptr(), icf.as_ptr(), d_icf.as_mut_ptr(), x.as_ptr(), wishart2 as *const Wishart, err2 as *mut f64, d_err2 as *mut f64); +} diff --git a/enzyme/benchmarks/lit.site.cfg.py.in b/enzyme/benchmarks/lit.site.cfg.py.in index 93937f9c62d3..2ef3c28b0ca9 100644 --- a/enzyme/benchmarks/lit.site.cfg.py.in +++ b/enzyme/benchmarks/lit.site.cfg.py.in @@ -49,21 +49,68 @@ config.substitutions.append(('%lli', config.llvm_tools_dir + "/lli" + (" --jit-k config.substitutions.append(('%opt', config.llvm_tools_dir + "/opt")) config.substitutions.append(('%llvmver', config.llvm_ver)) config.substitutions.append(('%FileCheck', config.llvm_tools_dir + "/FileCheck")) -config.substitutions.append(('%clang', config.llvm_tools_dir + "/clang")) -config.substitutions.append(('%loadEnzyme', '' - + (" --enable-new-pm=0" if int(config.llvm_ver) >= 13 else "") + +emopt = config.enzyme_obj_root + "/Enzyme/MLIR/enzymemlir-opt" +if len("@ENZYME_BINARY_DIR@") == 0: + emopt = os.path.dirname(os.path.abspath(__file__)) + "/../enzymemlir-opt" + +eclang = config.llvm_tools_dir + "/clang" +if len("@ENZYME_BINARY_DIR@") == 0: + eclang = os.path.dirname(os.path.abspath(__file__)) + "/../enzyme-clang" + resource = config.llvm_tools_dir + "/../clang/staging" + eclang += " -resource-dir " + resource + " " + eclang += "-I " + os.path.dirname(os.path.abspath(__file__)) + "/Integration" + +config.substitutions.append(('%eopt', emopt)) +config.substitutions.append(('%llvmver', config.llvm_ver)) +config.substitutions.append(('%FileCheck', config.llvm_tools_dir + "/FileCheck")) +config.substitutions.append(('%clang', eclang)) +config.substitutions.append(('%O0TBAA', "-O1 -Xclang -disable-llvm-passes")) + +oldPM = ((" --enable-new-pm=0" if int(config.llvm_ver) >= 13 else "") + ' -load=@ENZYME_BINARY_DIR@/Enzyme/LLVMEnzyme-' + config.llvm_ver + config.llvm_shlib_ext - + (" --enzyme-attributor=0" if int(config.llvm_ver) >= 13 else "") - + ' -enzyme-preopt=0' - )) + + (" --enzyme-attributor=0" if int(config.llvm_ver) >= 13 else "")) +newPM = ((" --enable-new-pm=1" if int(config.llvm_ver) in (12,13) else "") + + ' -load-pass-plugin=@ENZYME_BINARY_DIR@/Enzyme/LLVMEnzyme-' + config.llvm_ver + config.llvm_shlib_ext + + ' -load=@ENZYME_BINARY_DIR@/Enzyme/LLVMEnzyme-' + config.llvm_ver + config.llvm_shlib_ext + + (" --enzyme-attributor=0" if int(config.llvm_ver) >= 13 else "")) +if len("@ENZYME_BINARY_DIR@") == 0: + oldPM = ((" --enable-new-pm=0" if int(config.llvm_ver) >= 13 else "") + + (" --enzyme-attributor=0" if int(config.llvm_ver) >= 13 else "")) + newPM = ((" --enable-new-pm=1" if int(config.llvm_ver) in (12,13) else "") + + (" --enzyme-attributor=0" if int(config.llvm_ver) >= 13 else "")) + +oldPMOP = oldPM +newPMOP = newPM +if int(config.llvm_ver) == 16: + newPM += " -opaque-pointers=0" + oldPM += " -opaque-pointers=0" + +config.substitutions.append(('%loadEnzyme', oldPM if int(config.llvm_ver) < 16 else newPM)) +config.substitutions.append(('%newLoadEnzyme', newPM)) +config.substitutions.append(('%OPloadEnzyme', oldPMOP if int(config.llvm_ver) < 16 else newPMOP)) +config.substitutions.append(('%OPnewLoadEnzyme', newPMOP)) +config.substitutions.append(('%enzyme', ('-enzyme' if int(config.llvm_ver) < 16 else '-passes="enzyme"'))) +config.substitutions.append(('%simplifycfg', ("simplify-cfg" if int(config.llvm_ver) < 11 else "simplifycfg"))) +config.substitutions.append(('%loopmssa', ("loop" if int(config.llvm_ver) < 11 else "loop-mssa"))) + config.substitutions.append(('%loadBC', '' + ' @ENZYME_BINARY_DIR@/BCLoad/BCPass-' + config.llvm_ver + config.llvm_shlib_ext )) config.substitutions.append(('%BClibdir', '@ENZYME_SOURCE_DIR@/bclib/')) -config.substitutions.append(('%loadClangEnzyme', '' - + (" -fno-experimental-new-pass-manager" if int(config.llvm_ver) >= 13 else "") - + ' -Xclang -load -Xclang @ENZYME_BINARY_DIR@/Enzyme/ClangEnzyme-' + config.llvm_ver + config.llvm_shlib_ext - )) + +oldPM = (((" -fno-experimental-new-pass-manager" if int(config.llvm_ver) < 14 else "-flegacy-pass-manager") if int(config.llvm_ver) >= 13 else "") + + ' -Xclang -load -Xclang @ENZYME_BINARY_DIR@/Enzyme/ClangEnzyme-' + config.llvm_ver + config.llvm_shlib_ext) +newPM = ((" -fexperimental-new-pass-manager" if int(config.llvm_ver) < 13 else "") + + ' -fpass-plugin=@ENZYME_BINARY_DIR@/Enzyme/ClangEnzyme-' + config.llvm_ver + config.llvm_shlib_ext + + ' -Xclang -load -Xclang @ENZYME_BINARY_DIR@/Enzyme/ClangEnzyme-' + config.llvm_ver + config.llvm_shlib_ext) + +if len("@ENZYME_BINARY_DIR@") == 0: + oldPM = ((" -fno-experimental-new-pass-manager" if int(config.llvm_ver) < 14 else "-flegacy-pass-manager") if int(config.llvm_ver) >= 13 else "") + newPM = (" -fexperimental-new-pass-manager" if int(config.llvm_ver) < 13 else "") + +config.substitutions.append(('%loadClangEnzyme', oldPM if int(config.llvm_ver) < 15 else newPM)) +config.substitutions.append(('%newLoadClangEnzyme', newPM)) # Let the main config do the real work. lit_config.load_config(config, "@ENZYME_SOURCE_DIR@/benchmarks/lit.cfg.py")