From d30d790a8bbc2dc2641982190590b4d9fdb48f40 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 22 Nov 2024 12:31:29 +0100 Subject: [PATCH] wip: get new baselines --- .github/workflows/regression_tests.yml | 21 ++++++---- tests/test_regression.py | 54 ++++++++++++-------------- 2 files changed, 38 insertions(+), 37 deletions(-) diff --git a/.github/workflows/regression_tests.yml b/.github/workflows/regression_tests.yml index 78fabd09..5a98e73b 100644 --- a/.github/workflows/regression_tests.yml +++ b/.github/workflows/regression_tests.yml @@ -27,14 +27,19 @@ jobs: python -m pip install --upgrade pip pip install -e ".[dev]" + # - name: Run benchmarks and compare to baseline + # if: github.event.pull_request.base.ref == 'main' + # run: | + # # Check if regression test results exist in main branch + # if [ -f 'git cat-file -e main:tests/regression_test_baselines.json' ]; then + # git checkout main tests/regression_test_baselines.json + # else + # echo "No regression test results found in main branch" + # fi + # pytest -m regression + # git checkout + - name: Run benchmarks and compare to baseline if: github.event.pull_request.base.ref == 'main' run: | - # # Check if regression test results exist in main branch - # if [ -f 'git cat-file -e main:tests/regression_test_baselines.json' ]; then - # git checkout main tests/regression_test_baselines.json - # else - # echo "No regression test results found in main branch" - # fi - pytest -m regression - # git checkout \ No newline at end of file + pytest -m regression \ No newline at end of file diff --git a/tests/test_regression.py b/tests/test_regression.py index 1b14a5f3..b7684f47 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -202,13 +202,13 @@ def build_net(num_cells, artificial=True, connect=True, connection_prob=0.0): ( # Test a single SWC cell with both solvers. pytest.param(1, False, False, 0.0, "jaxley.stone"), - # pytest.param(1, False, False, 0.0, "jax.sparse"), - # # Test a network of SWC cells with both solvers. - # pytest.param(10, False, True, 0.1, "jaxley.stone"), - # pytest.param(10, False, True, 0.1, "jax.sparse"), - # # Test a larger network of smaller neurons with both solvers. - # pytest.param(1000, True, True, 0.001, "jaxley.stone"), - # pytest.param(1000, True, True, 0.001, "jax.sparse"), + pytest.param(1, False, False, 0.0, "jax.sparse"), + # Test a network of SWC cells with both solvers. + pytest.param(10, False, True, 0.1, "jaxley.stone"), + pytest.param(10, False, True, 0.1, "jax.sparse"), + # Test a larger network of smaller neurons with both solvers. + pytest.param(1000, True, True, 0.001, "jaxley.stone"), + pytest.param(1000, True, True, 0.001, "jax.sparse"), ), ) @compare_to_baseline(baseline_iters=3) @@ -219,41 +219,37 @@ def test_runtime( connection_prob: float, voltage_solver: str, ): - import time delta_t = 0.025 t_max = 100.0 - # def simulate(params): - # return jx.integrate( - # net, - # params=params, - # t_max=t_max, - # delta_t=delta_t, - # voltage_solver=voltage_solver, - # ) + def simulate(params): + return jx.integrate( + net, + params=params, + t_max=t_max, + delta_t=delta_t, + voltage_solver=voltage_solver, + ) runtimes = {} start_time = time.time() - # net, params = build_net( - # num_cells, - # artificial=artificial, - # connect=connect, - # connection_prob=connection_prob, - # ) - time.sleep(0.1) + net, params = build_net( + num_cells, + artificial=artificial, + connect=connect, + connection_prob=connection_prob, + ) runtimes["build_time"] = time.time() - start_time - # jitted_simulate = jit(simulate) + jitted_simulate = jit(simulate) start_time = time.time() - time.sleep(0.31) - # _ = jitted_simulate(params).block_until_ready() + _ = jitted_simulate(params).block_until_ready() runtimes["compile_time"] = time.time() - start_time - # params[0]["radius"] = params[0]["radius"].at[0].set(0.5) + params[0]["radius"] = params[0]["radius"].at[0].set(0.5) start_time = time.time() - # _ = jitted_simulate(params).block_until_ready() - time.sleep(0.21) + _ = jitted_simulate(params).block_until_ready() runtimes["run_time"] = time.time() - start_time return runtimes # @compare_to_baseline decorator will compare this to the baseline