Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bench gmm: correct gradients so long as we skip log_gamma_distrib #1826

Open
wants to merge 30 commits into
base: rust-bench
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
b2a172c
adding gmm
ZuseZ4 Mar 19, 2024
436eafe
working C too
ZuseZ4 Mar 25, 2024
1d0add9
Delete enzyme/benchmarks/ReverseMode/gmmrs/src/main.rs2
ZuseZ4 Mar 26, 2024
360a10d
rust setup
wsmoses Mar 28, 2024
2e80893
add files
wsmoses Mar 28, 2024
b3610a1
improve makefile and fix c ffi
ZuseZ4 Mar 29, 2024
8562585
maybe needed? pthread for cmake
ZuseZ4 Mar 30, 2024
9d34d0e
bench gmm: use path relative to Makefile
jedbrown Mar 30, 2024
3708bf9
Fix byref issue for rust abi
wsmoses Mar 31, 2024
f6cf01f
Add primal bench/test
wsmoses Mar 31, 2024
10805c4
fix math
ZuseZ4 Mar 31, 2024
83bdb30
write into return var
ZuseZ4 Mar 31, 2024
3f7274b
Cleanup gmm config
wsmoses Mar 31, 2024
f77f533
bench gmm: make cmath::lgamma with libm as an optional feature
jedbrown Mar 31, 2024
4a52983
oxidize - more noalias
ZuseZ4 Mar 31, 2024
bb145a0
reduce caching
ZuseZ4 Mar 31, 2024
02c9cf6
bench gmm: makefile dep on Cargo.toml, split targets
jedbrown Mar 31, 2024
fd8edf1
revert cmake pthread since only needed for Rust
ZuseZ4 Mar 31, 2024
3ab2e12
bench gmm: fix primal (sqnorm length matters)
jedbrown Mar 31, 2024
d97ff33
bench gmm: quash rust warnings
jedbrown Mar 31, 2024
2608c93
adding ba benchmark
ZuseZ4 Mar 31, 2024
1c9a124
Benchmark ba
wsmoses Apr 1, 2024
b045a50
bench gmm: correct gradients so long as we skip log_gamma_distrib
jedbrown Apr 2, 2024
a309779
make it fail
wsmoses Apr 2, 2024
d9ffb41
reduce more
wsmoses Apr 2, 2024
48389a5
red
wsmoses Apr 2, 2024
9a586cb
continue to reduce
wsmoses Apr 2, 2024
3ceebba
more minimize
wsmoses Apr 2, 2024
4f43790
reduce more
wsmoses Apr 2, 2024
5e00f84
further reduce
wsmoses Apr 2, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 148 additions & 2 deletions enzyme/benchmarks/ReverseMode/adbench/ba.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,19 @@ extern "C" {
double* reproj_err,
double* w_err
);

// Primal bundle-adjustment objective exported by the Rust build.
// main() times this call under the label "primal rust" and hands it the same
// ten arguments, in the same order, that it hands to ba_objective.
// The caller in this file allocates reproj_err with 2 * p doubles and w_err
// with p doubles before the call.
// NOTE(review): n, m and p are presumably the camera, point and observation
// counts — not visible in this hunk, confirm against ba_objective.
void rust2_ba_objective(
int n,
int m,
int p,
double const* cams,
double const* X,
double const* w,
int const* obs,
double const* feats,
double* reproj_err,
double* w_err
);

void dcompute_reproj_error(
double const* cam,
Expand Down Expand Up @@ -169,6 +182,20 @@ extern "C" {
);

void adept_compute_zach_weight_error(double const* w, double* dw, double* err, double* derr);

// Derivative of the per-observation reprojection error, provided by the Rust
// build. main() passes it as the first template argument of
// calculate_jacobian for the "Enzyme rust combined" measurement.
// NOTE(review): each const input appears to be paired with a writable shadow
// buffer (cam/dcam, X/dX, w/wb, err/derr); buffer lengths are not visible
// here — confirm against the Rust definition.
void rust_dcompute_reproj_error(
double const* cam,
double * dcam,
double const* X,
double * dX,
double const* w,
double * wb,
double const* feat,
double *err,
double *derr
);

// Derivative of the Zach weight error, provided by the Rust build. main()
// passes it as the second template argument of calculate_jacobian for the
// "Enzyme rust combined" measurement. The parameter list is identical to
// adept_compute_zach_weight_error declared above.
void rust_dcompute_zach_weight_error(double const* w, double* dw, double* err, double* derr);
}

void read_ba_instance(const string& fn,
Expand Down Expand Up @@ -486,9 +513,9 @@ int main(const int argc, const char* argv[]) {
gettimeofday(&start, NULL);
calculate_jacobian<dcompute_reproj_error, dcompute_zach_weight_error>(input, result);
gettimeofday(&end, NULL);
printf("Enzyme combined %0.6f\n", tdiff(&start, &end));
printf("Enzyme c++ combined %0.6f\n", tdiff(&start, &end));
json enzyme;
enzyme["name"] = "Enzyme combined";
enzyme["name"] = "Enzyme c++ combined";
enzyme["runtime"] = tdiff(&start, &end);
for(unsigned i=0; i<5; i++) {
printf("%f ", result.J.vals[i]);
Expand All @@ -499,6 +526,125 @@ int main(const int argc, const char* argv[]) {
}

}

{
    // Baseline measurement: load the instance afresh, time one call of the
    // plain C++ objective, and record the leading entries of both outputs.
    BAInput input;
    read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams,
                     input.X, input.w, input.obs, input.feats);

    BAOutput result = {
        std::vector<double>(2 * input.p),
        std::vector<double>(input.p),
        BASparseMat(input.n, input.m, input.p)
    };

    {
        struct timeval t_begin, t_end;
        gettimeofday(&t_begin, NULL);
        ba_objective(input.n, input.m, input.p,
                     input.cams.data(), input.X.data(), input.w.data(),
                     input.obs.data(), input.feats.data(),
                     result.reproj_err.data(), result.w_err.data());
        gettimeofday(&t_end, NULL);
        printf("primal c++ t=%0.6f\n", tdiff(&t_begin, &t_end));

        json record;
        record["name"] = "primal c++";
        record["runtime"] = tdiff(&t_begin, &t_end);

        // Echo the first five entries of a result vector and append them to
        // the JSON record, in that order.
        const auto report_head = [&record](const auto &values) {
            for (unsigned idx = 0; idx != 5; ++idx) {
                printf("%f ", values[idx]);
                record["result"].push_back(values[idx]);
            }
        };
        report_head(result.reproj_err);
        report_head(result.w_err);
        printf("\n");

        test_suite["tools"].push_back(record);
    }
}


{
    // Same measurement as the C++ baseline, but the timed call is the
    // Rust-built objective rust2_ba_objective.
    BAInput input;
    read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams,
                     input.X, input.w, input.obs, input.feats);

    BAOutput result = {
        std::vector<double>(2 * input.p),
        std::vector<double>(input.p),
        BASparseMat(input.n, input.m, input.p)
    };

    {
        struct timeval t_begin, t_end;
        gettimeofday(&t_begin, NULL);
        rust2_ba_objective(input.n, input.m, input.p,
                           input.cams.data(), input.X.data(), input.w.data(),
                           input.obs.data(), input.feats.data(),
                           result.reproj_err.data(), result.w_err.data());
        gettimeofday(&t_end, NULL);
        printf("primal rust t=%0.6f\n", tdiff(&t_begin, &t_end));

        json record;
        record["name"] = "primal rust";
        record["runtime"] = tdiff(&t_begin, &t_end);

        // Echo the first five entries of a result vector and append them to
        // the JSON record, in that order.
        const auto report_head = [&record](const auto &values) {
            for (unsigned idx = 0; idx != 5; ++idx) {
                printf("%f ", values[idx]);
                record["result"].push_back(values[idx]);
            }
        };
        report_head(result.reproj_err);
        report_head(result.w_err);
        printf("\n");

        test_suite["tools"].push_back(record);
    }
}

{
    // Times the Jacobian assembly driven by the Rust-built derivative
    // routines and records the first five stored Jacobian values.
    BAInput input;
    read_ba_instance("data/" + path, input.n, input.m, input.p, input.cams,
                     input.X, input.w, input.obs, input.feats);

    BAOutput result = {
        std::vector<double>(2 * input.p),
        std::vector<double>(input.p),
        BASparseMat(input.n, input.m, input.p)
    };

    {
        struct timeval t_begin, t_end;
        gettimeofday(&t_begin, NULL);
        calculate_jacobian<rust_dcompute_reproj_error,
                           rust_dcompute_zach_weight_error>(input, result);
        gettimeofday(&t_end, NULL);
        printf("Enzyme rust combined %0.6f\n", tdiff(&t_begin, &t_end));

        json record;
        record["name"] = "Enzyme rust combined";
        record["runtime"] = tdiff(&t_begin, &t_end);

        const auto &jac_vals = result.J.vals;
        for (unsigned idx = 0; idx != 5; ++idx) {
            printf("%f ", jac_vals[idx]);
            record["result"].push_back(jac_vals[idx]);
        }
        printf("\n");

        test_suite["tools"].push_back(record);
    }
}

test_suite["llvm-version"] = __clang_version__;
test_suite["mode"] = "ReverseMode";
test_suite["batch-size"] = 1;
Expand Down
108 changes: 101 additions & 7 deletions enzyme/benchmarks/ReverseMode/adbench/gmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ struct GMMParameters {
};

extern "C" {
// Primal GMM objective. In this file it is instantiated as
// primal<gmm_objective> and timed under the label "C++ primal".
// The Wishart prior is taken by value; the result is written through err
// (primal() passes the address of a single zero-initialised double).
// NOTE(review): d, k and n are presumably the data dimension, component count
// and point count — confirm against the implementation.
void gmm_objective(
int d,
int k,
int n,
double const* alphas,
double const* means,
double const* icf,
double const* x,
Wishart wishart,
double* err
);
void dgmm_objective(int d, int k, int n, const double *alphas, double *
alphasb, const double *means, double *meansb, const double *icf,
double *icfb, const double *x, Wishart wishart, double *err, double *
Expand All @@ -47,6 +58,15 @@ extern "C" {
alphasb, const double *means, double *meansb, const double *icf,
double *icfb, const double *x, Wishart wishart, double *err, double *
errb);

// Rust-built counterpart of dgmm_objective, used in main() through
// calculate_jacobian<rust_dgmm_objective>. The parameter order matches
// dgmm_objective, except that wishart is taken by reference here instead of
// by value.
// NOTE(review): the reference presumably maps to a pointer argument on the
// Rust side — confirm the ABI against the Rust signature.
void rust_dgmm_objective(int d, int k, int n, const double *alphas, double *
alphasb, const double *means, double *meansb, const double *icf,
double *icfb, const double *x, Wishart &wishart, double *err, double *
errb);

// Rust-built primal GMM objective, instantiated as primal<rust_gmm_objective>
// and timed under the label "Rust primal". Same parameter order as
// gmm_objective, except that wishart is taken by reference here instead of
// by value.
// NOTE(review): the reference presumably maps to a pointer argument on the
// Rust side — confirm the ABI against the Rust signature.
void rust_gmm_objective(int d, int k, int n, const double *alphas,
const double *means, const double *icf,
const double *x, Wishart &wishart, double *err);
}

void read_gmm_instance(const string& fn,
Expand Down Expand Up @@ -123,10 +143,7 @@ void read_gmm_instance(const string& fn,
fclose(fid);
}

typedef void(*deriv_t)(int d, int k, int n, const double *alphas, double *alphasb, const double *means, double *meansb, const double *icf,
double *icfb, const double *x, Wishart wishart, double *err, double *errb);

template<deriv_t deriv>
template<auto deriv>
void calculate_jacobian(struct GMMInput &input, struct GMMOutput &result)
{
double* alphas_gradient_part = result.gradient.data();
Expand Down Expand Up @@ -159,6 +176,25 @@ void calculate_jacobian(struct GMMInput &input, struct GMMOutput &result)
);
}

// Runs a GMM objective once on `input` and returns the scalar it produces.
//
// The template argument is a compile-time callable invoked as
//   f(d, k, n, alphas, means, icf, x, wishart, &out)
// In this file it is instantiated with the primal objectives gmm_objective
// and rust_gmm_objective. The output slot is zero-initialised before the
// call and its final value is returned.
template<auto objective_fn>
double primal(struct GMMInput &input)
{
    double objective_value = 0.0;
    objective_fn(input.d, input.k, input.n,
                 input.alphas.data(), input.means.data(), input.icf.data(),
                 input.x.data(), input.wishart, &objective_value);
    return objective_value;
}

int main(const int argc, const char* argv[]) {
printf("starting main\n");

Expand All @@ -167,9 +203,11 @@ int main(const int argc, const char* argv[]) {

std::vector<std::string> paths;// = { "1k/gmm_d10_K100.txt" };

getTests(paths, "data/1k", "1k/");
getTests(paths, "data/2.5k", "2.5k/");
getTests(paths, "data/10k", "10k/");
// NOTE(review): the directory sweeps below are commented out and a single
// instance is pushed by hand, so only "1k/gmm_d2_K5.txt" is benchmarked.
// This looks like a temporary reduction for debugging — confirm whether the
// getTests calls should be restored before merging.
//getTests(paths, "data/1k", "1k/");
//getTests(paths, "data/2.5k", "2.5k/");
//getTests(paths, "data/10k", "10k/");
//paths.push_back("1k/gmm_d128_K100.txt");
paths.push_back("1k/gmm_d2_K5.txt");

std::ofstream jsonfile("results.json", std::ofstream::trunc);
json test_results;
Expand Down Expand Up @@ -257,6 +295,7 @@ int main(const int argc, const char* argv[]) {
gettimeofday(&start, NULL);
calculate_jacobian<dgmm_objective>(input, result);
gettimeofday(&end, NULL);
printf("Enzyme c++ combined %0.6f\n", tdiff(&start, &end));
json enzyme;
enzyme["name"] = "Enzyme combined";
enzyme["runtime"] = tdiff(&start, &end);
Expand All @@ -269,6 +308,61 @@ int main(const int argc, const char* argv[]) {
test_suite["tools"].push_back(enzyme);
}

}

{
    // Loads the instance once, then performs three timed measurements on it:
    // the C++ primal, the Rust primal, and the Rust-built gradient.
    GMMInput input;
    read_gmm_instance("data/" + path, &input.d, &input.k, &input.n,
                      input.alphas, input.means, input.icf, input.x,
                      input.wishart, params.replicate_point);

    const int Jcols = (input.k * (input.d + 1) * (input.d + 2)) / 2;

    GMMOutput result = { 0, std::vector<double>(Jcols) };

    {
        // C++ primal objective.
        struct timeval t_begin, t_end;
        gettimeofday(&t_begin, NULL);
        const double objective_value = primal<gmm_objective>(input);
        gettimeofday(&t_end, NULL);
        printf("c++ primal combined t=%0.6f, err=%f\n",
               tdiff(&t_begin, &t_end), objective_value);

        json record;
        record["name"] = "C++ primal";
        record["runtime"] = tdiff(&t_begin, &t_end);
        record["result"].push_back(objective_value);
        test_suite["tools"].push_back(record);
    }
    {
        // Rust primal objective.
        struct timeval t_begin, t_end;
        gettimeofday(&t_begin, NULL);
        const double objective_value = primal<rust_gmm_objective>(input);
        gettimeofday(&t_end, NULL);
        printf("rust primal combined t=%0.6f, err=%f\n",
               tdiff(&t_begin, &t_end), objective_value);

        json record;
        record["name"] = "Rust primal";
        record["runtime"] = tdiff(&t_begin, &t_end);
        record["result"].push_back(objective_value);
        test_suite["tools"].push_back(record);
    }
    {
        // Gradient via the Rust-built derivative; the trailing five gradient
        // entries are echoed and stored.
        struct timeval t_begin, t_end;
        gettimeofday(&t_begin, NULL);
        calculate_jacobian<rust_dgmm_objective>(input, result);
        gettimeofday(&t_end, NULL);
        printf("Enzyme rust combined %0.6f\n", tdiff(&t_begin, &t_end));

        json record;
        record["name"] = "Rust Enzyme combined";
        record["runtime"] = tdiff(&t_begin, &t_end);

        const auto total = result.gradient.size();
        for (unsigned idx = total - 5; idx < total; ++idx) {
            printf("%f ", result.gradient[idx]);
            record["result"].push_back(result.gradient[idx]);
        }
        printf("\n");
        test_suite["tools"].push_back(record);
    }
}
test_suite["llvm-version"] = __clang_version__;
test_suite["mode"] = "ReverseMode";
Expand Down
16 changes: 16 additions & 0 deletions enzyme/benchmarks/ReverseMode/ba/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions enzyme/benchmarks/ReverseMode/ba/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Manifest for the Rust side of the bundle-adjustment (ba) benchmark.
[package]
name = "bars"
version = "0.1.0"
edition = "2021"


# Built as a C-compatible dynamic library (cdylib).
# NOTE(review): presumably so the C++ benchmark driver can link against it — confirm in the Makefile.
[lib]
crate-type = ["cdylib"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

# Full ("fat") link-time optimisation is requested for release builds...
[profile.release]
lto = "fat"

# ...and for dev builds as well.
# NOTE(review): fat LTO in the dev profile is unusual; presumably required by the Enzyme toolchain — confirm.
[profile.dev]
lto = "fat"

# libm is optional and therefore only compiled when its feature is enabled.
# NOTE(review): no [features] table is visible here, so this relies on the implicit `libm` feature — confirm that ba actually needs it.
[dependencies]
libm = { version = "0.2.8", optional = true }
Loading