diff --git a/.github/workflows/regression_tests.yml b/.github/workflows/regression_tests.yml
new file mode 100644
index 00000000..3509be48
--- /dev/null
+++ b/.github/workflows/regression_tests.yml
@@ -0,0 +1,39 @@
+# .github/workflows/regression_tests.yml
+name: Regression Tests
+
+on:
+ pull_request:
+ branches:
+ - main
+
+jobs:
+ regression_tests:
+ name: regression_tests
+ runs-on: ubuntu-20.04
+
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ lfs: true
+ fetch-depth: 0 # This ensures we can checkout main branch too
+
+ - uses: actions/setup-python@v4
+ with:
+ python-version: '3.10'
+ architecture: 'x64'
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -e ".[dev]"
+
+ - name: Run benchmarks and compare to baseline
+ if: github.event.pull_request.base.ref == 'main'
+ run: |
+ # Check if regression test baselines exist in the main branch
+ if git cat-file -e main:tests/regression_test_baselines.json 2>/dev/null; then
+ git checkout main tests/regression_test_baselines.json
+ else
+ echo "No regression test results found in main branch"
+ fi
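+ # Run the regression-marked tests; compare_to_baseline fails a test if it runs more than the allowed tolerance slower than the baseline.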
+ pytest -m regression
\ No newline at end of file
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 3eb90b0a..1750f290 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -39,4 +39,4 @@ jobs:
- name: Test with pytest
run: |
pip install pytest pytest-cov
- pytest tests/ --cov=jaxley --cov-report=xml
+ pytest tests/ -m "not regression" --cov=jaxley --cov-report=xml
diff --git a/.github/workflows/update_regression_baseline.yml b/.github/workflows/update_regression_baseline.yml
new file mode 100644
index 00000000..c2d6bf65
--- /dev/null
+++ b/.github/workflows/update_regression_baseline.yml
@@ -0,0 +1,141 @@
+# .github/workflows/update_regression_baseline.yml
+
+# for details on triggering a workflow from a comment, see:
+# https://dev.to/zirkelc/trigger-github-workflow-for-comment-on-pull-request-45l2
+name: Update Regression Baseline
+
+on:
+ issue_comment: # trigger from comment; event runs on the default branch
+ types: [created]
+
+jobs:
+ update_regression_tests:
+ name: update_regression_tests
+ runs-on: ubuntu-20.04
+ # Trigger from a comment that contains '/update_regression_baselines'
+ if: github.event.issue.pull_request && contains(github.event.comment.body, '/update_regression_baselines')
+ # workflow needs permissions to write to the PR
+ permissions:
+ contents: write
+ pull-requests: write
+ issues: read
+
+ steps:
+ - name: Create initial status comment
+ uses: actions/github-script@v7
+ id: initial-comment
+ with:
+ github-token: ${{ secrets.GITHUB_TOKEN }}
+ script: |
+ const response = await github.rest.issues.createComment({
+ issue_number: context.issue.number,
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ body: '## Updating Regression Baselines\n⏳ Workflow is currently running...'
+ });
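+ // Return the comment id so later steps can update this same comment via steps.initial-comment.outputs.result.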
+ return response.data.id;
+
+ - name: Check if PR is from fork
+ id: check-fork
+ uses: actions/github-script@v7
+ with:
+ github-token: ${{ secrets.GITHUB_TOKEN }}
+ script: |
+ const pr = await github.rest.pulls.get({
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ pull_number: context.issue.number
+ });
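+ // true if the PR head repo is a fork; later steps use this to decide between pushing directly and opening a follow-up PR.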
+ return pr.data.head.repo.fork;
+
+ - name: Get PR branch
+ uses: xt0rted/pull-request-comment-branch@v3
+ id: comment-branch
+ with:
+ repo_token: ${{ secrets.GITHUB_TOKEN }}
+
+ - name: Checkout PR branch
+ uses: actions/checkout@v3
+ with:
+ ref: ${{ steps.comment-branch.outputs.head_sha }} # using head_sha vs. head_ref makes this work for forks
+ lfs: true
+ fetch-depth: 0 # This ensures we can checkout main branch too
+
+ - uses: actions/setup-python@v4
+ with:
+ python-version: '3.10'
+ architecture: 'x64'
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -e ".[dev]"
+
+ - name: Update baseline
+ id: update-baseline
+ run: |
+ git config --global user.name '${{ github.event.comment.user.login }}'
+ git config --global user.email '${{ github.event.comment.user.login }}@users.noreply.github.com'
+ # Check if regression test baselines exist in the main branch
+ if git cat-file -e main:tests/regression_test_baselines.json 2>/dev/null; then
+ git checkout main tests/regression_test_baselines.json
+ else
+ echo "No regression test results found in main branch"
+ fi
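+ # NEW_BASELINE=1 makes the regression tests record fresh baselines (averaged over several iterations) instead of asserting against the existing ones.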
+ NEW_BASELINE=1 pytest -m regression
+
+ # Pushing to the PR branch does not work if the PR is initiated from a fork. This is
+ # because the GITHUB_TOKEN has read-only access by default for workflows triggered by
+ # fork PRs. Hence we have to create a new PR to update the baseline (see below).
+ - name: Commit and push to PR branch (non-fork)
+ # Only run if baseline generation succeeded
+ if: success() && steps.update-baseline.outcome == 'success' && !fromJson(steps.check-fork.outputs.result)
+ run: |
+ git add -f tests/regression_test_baselines.json # since it's in .gitignore
+ git commit -m "Update regression test baselines"
+ git push origin HEAD:${{ steps.comment-branch.outputs.head_ref }} # head_ref will probably not work for forks!
+
+ - name: Create PR with updates (fork)
+ if: success() && steps.update-baseline.outcome == 'success' && fromJson(steps.check-fork.outputs.result)
+ uses: peter-evans/create-pull-request@v5
+ id: create-pr
+ with:
+ token: ${{ secrets.GITHUB_TOKEN }}
+ commit-message: Update regression test baselines
+ title: 'Update regression test baselines'
+ branch: regression-baseline-update-${{ github.event.issue.number }}
+ base: ${{ steps.comment-branch.outputs.head_ref }}
+
+ - name: Update comment with results
+ uses: actions/github-script@v7
+ if: always() # Run this step even if previous steps fail
+ with:
+ github-token: ${{ secrets.GITHUB_TOKEN }}
+ script: |
+ const fs = require('fs');
+ let status = '${{ steps.update-baseline.outcome }}' === 'success' ? '✅' : '❌';
+ let message = '## Regression Baseline Update\n' + status + ' Process completed\n\n';
+
+ try {
+ const TestReport = fs.readFileSync('tests/regression_test_report.txt', 'utf8');
+ message += '```\n' + TestReport + '\n```\n\n';
+
+ // Add information about where the changes were pushed
+ if ('${{ steps.update-baseline.outcome }}' === 'success') {
+ if (!${{ fromJson(steps.check-fork.outputs.result) }}) {
+ message += '✨ Changes have been pushed directly to this PR\n';
+ } else {
+ const prNumber = '${{ steps.create-pr.outputs.pull-request-number }}';
+ message += `✨ Changes have been pushed to a new PR #${prNumber} because this PR is from a fork\n`;
+ }
+ }
+ } catch (error) {
+ message += '⚠️ No test report was generated\n';
+ }
+
+ await github.rest.issues.updateComment({
+ comment_id: ${{ steps.initial-comment.outputs.result }},
+ owner: context.repo.owner,
+ repo: context.repo.repo,
+ body: message
+ });
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 6162a95b..d5638eb6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -55,6 +55,8 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
+tests/regression_test_results.json
+tests/regression_test_baselines.json
# Translations
*.mo
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 21020bed..c678ab75 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,7 +6,11 @@
```python
net.record("i_IonotropicSynapse")
```
-
+- Add regression tests and supporting workflows for maintaining baselines (#475, @jnsbck).
+ - PRs now trigger both tests and regression tests.
+ - Baselines are maintained in the main branch.
+ - Regression tests can be run locally by executing `NEW_BASELINE=1 pytest -m regression` on `main` (to record baselines) and then `pytest -m regression` on the feature branch; this produces a test report that is printed to the console and saved to a `.txt` file (see the example below).
+ - If a PR introduces new baseline tests or reduces runtimes, then a new baseline can be created by commenting "/update_regression_baselines" on the PR.
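+
+  For example, to run the comparison locally (the feature branch name is illustrative):
+  ```bash
+  git checkout main
+  NEW_BASELINE=1 pytest -m regression  # record baselines on main
+  git checkout my-feature-branch       # your feature branch
+  pytest -m regression                 # compare against the recorded baselines
+  ```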
# 0.5.0
diff --git a/pyproject.toml b/pyproject.toml
index 35eb11cd..4a81bbe6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,6 +69,7 @@ dev = [
[tool.pytest.ini_options]
markers = [
"slow: marks tests as slow (T > 10s)",
+ "regression: marks regression tests",
]
[tool.isort]
diff --git a/tests/conftest.py b/tests/conftest.py
index dad1c4a5..afe41e30 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,7 @@
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see
+import json
import os
from copy import deepcopy
from typing import Optional
@@ -9,6 +10,7 @@
import jaxley as jx
from jaxley.synapses import IonotropicSynapse
+from tests.test_regression import generate_regression_report, load_json
@pytest.fixture(scope="session")
@@ -202,3 +204,44 @@ def get_or_compute_swc2jaxley_params(
yield get_or_compute_swc2jaxley_params
params = {}
+
+
+@pytest.fixture(scope="session", autouse=True)
+def print_session_report(request, pytestconfig):
+ """Cleanup a testing directory once we are finished."""
+ NEW_BASELINE = os.environ.get("NEW_BASELINE", 0)
+
+ dirname = os.path.dirname(__file__)
+ baseline_fname = os.path.join(dirname, "regression_test_baselines.json")
+ results_fname = os.path.join(dirname, "regression_test_results.json")
+
+ collected_regression_tests = [
+ item for item in request.session.items if item.get_closest_marker("regression")
+ ]
+
+ def update_baseline():
+ if NEW_BASELINE:
+ results = load_json(results_fname)
+ with open(baseline_fname, "w") as f:
+ json.dump(results, f, indent=2)
+ os.remove(results_fname)
+
+ def print_regression_report():
+ baselines = load_json(baseline_fname)
+ results = load_json(results_fname)
+
+ report = generate_regression_report(baselines, results)
+ # "No baselines found. Run `git checkout main;UPDATE_BASELINE=1 pytest -m regression; git checkout -`"
+ with open(dirname + "/regression_test_report.txt", "w") as f:
+ f.write(report)
+
+ # The following prints the report to the console even though pytest captures
+ # output, without requiring the "-s" flag.
+ capmanager = request.config.pluginmanager.getplugin("capturemanager")
+ with capmanager.global_and_fixture_disabled():
+ print("\n\n\nRegression Test Report\n----------------------\n")
+ print(report)
+
+ if len(collected_regression_tests) > 0:
+ request.addfinalizer(update_baseline)
+ request.addfinalizer(print_regression_report)
diff --git a/tests/test_regression.py b/tests/test_regression.py
new file mode 100644
index 00000000..a1ccafe1
--- /dev/null
+++ b/tests/test_regression.py
@@ -0,0 +1,254 @@
+# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
+# licensed under the Apache License Version 2.0, see
+
+import hashlib
+import json
+import os
+import time
+from functools import wraps
+
+import numpy as np
+import pytest
+from jax import jit
+
+import jaxley as jx
+from jaxley.channels import HH
+from jaxley.connect import sparse_connect
+from jaxley.synapses import IonotropicSynapse
+
+# Every runtime test needs to have the following structure:
+#
+# @compare_to_baseline()
+# def test_runtime_of_x(**kwargs) -> Dict:
+# t1 = time.time()
+# time.sleep(0.1)
+# # do something
+# t2 = time.time()
+# # do something else
+# t3 = time.time()
+# return {"sth": t2-t1, sth_else: t3-t2}
+
+# The test should return a dictionary with the runtime of each part of the test.
+# This way the runtimes of different parts of the code can be compared to each other.
+
+# The @compare_to_baseline decorator will compare the runtime of the test to the baseline
+# and raise an assertion error if the runtime is significantly slower than the baseline.
+# The decorator also saves the runtimes to a JSON results file, which is used to
+# generate a report comparing the runtimes of the tests to the baseline. Each entry
+# is keyed by the test name and its input_kwargs, with the runtimes of each part.
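+#
+# An illustrative baseline/results entry (the key is a sha256 hash of the test
+# name and its input_kwargs; the values below are made up):
+#
+# {
+#   "<sha256 of header>": {
+#     "test_name": "test_runtime",
+#     "input_kwargs": {"num_cells": 1, ...},
+#     "runtimes": {"build_time": 0.1, "compile_time": 1.2, "run_time": 0.3}
+#   }
+# }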
+
+
+def load_json(fpath):
+ dct = {}
+ if os.path.exists(fpath):
+ with open(fpath, "r") as f:
+ dct = json.load(f)
+ return dct
+
+
+pytestmark = pytest.mark.regression  # mark all tests in this file as regression tests
+NEW_BASELINE = os.environ.get("NEW_BASELINE", 0)
+dirname = os.path.dirname(__file__)
+fpath_baselines = os.path.join(dirname, "regression_test_baselines.json")
+fpath_results = os.path.join(dirname, "regression_test_results.json")
+
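+# Allowed relative slowdown vs. the baseline before a regression test fails.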
+tolerance = 0.2
+
+baselines = load_json(fpath_baselines)
+with open(fpath_results, "w") as f: # clear previous results
+ f.write("{}")
+
+
+def generate_regression_report(base_results, new_results):
+ """Compare two sets of benchmark results and generate a diff report."""
+ report = []
+ for key in new_results:
+ new_data = new_results[key]
+ base_data = base_results.get(key)
+ kwargs = ", ".join([f"{k}={v}" for k, v in new_data["input_kwargs"].items()])
+ func_name = new_data["test_name"]
+ func_signature = f"{func_name}({kwargs})"
+
+ new_runtimes = new_data["runtimes"]
+ base_runtimes = (
+ {k: None for k in new_runtimes}
+ if base_data is None
+ else base_data["runtimes"]
+ )
+
+ report.append(func_signature)
+ for key, new_time in new_runtimes.items():
+ base_time = base_runtimes.get(key)
+ diff = None if base_time is None else ((new_time - base_time) / base_time)
+
+ status = ""
+ if diff is None:
+ status = "🆕"
+ elif diff > tolerance:
+ status = "🔴"
+ elif diff < 0:
+ status = "🟢"
+ else:
+ status = "⚪"
+
+ time_str = (
+ f"({new_time:.3f}s)"
+ if diff is None
+ else f"({diff:+.2%} vs {base_time:.3f}s)"
+ )
+ report.append(f"{status} {key}: {time_str}.")
+ report.append("")
+
+ return "\n".join(report)
+
+
+def generate_unique_key(d):
+ # Generate a unique key for each test case, so that runs of the same test with
+ # different input_kwargs are tracked as separate baseline entries.
+ hash_obj = hashlib.sha256(bytes(json.dumps(d, sort_keys=True), encoding="utf-8"))
+ hash = hash_obj.hexdigest()
+ return str(hash)
+
+
+def append_to_json(fpath, test_name, input_kwargs, runtimes):
+ header = {"test_name": test_name, "input_kwargs": input_kwargs}
+ data = {generate_unique_key(header): {**header, "runtimes": runtimes}}
+
+ # Save data to a JSON file
+ result_data = load_json(fpath)
+ result_data.update(data)
+
+ with open(fpath, "w") as f:
+ json.dump(result_data, f, indent=2)
+
+
+class compare_to_baseline:
+ def __init__(self, baseline_iters=3, test_iters=1):
+ self.baseline_iters = baseline_iters
+ self.test_iters = test_iters
+
+ def __call__(self, func):
+ @wraps(func) # ensures kwargs exposed to pytest
+ def test_wrapper(**kwargs):
+ header = {"test_name": func.__name__, "input_kwargs": kwargs}
+ key = generate_unique_key(header)
+
+ runs = []
+ num_iters = self.baseline_iters if NEW_BASELINE else self.test_iters
+ for _ in range(num_iters):
+ runtimes = func(**kwargs)
+ runs.append(runtimes)
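+ # Average each runtime entry across iterations (baselines use more iterations to reduce noise).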
+ runtimes = {k: np.mean([d[k] for d in runs]) for k in runs[0]}
+
+ append_to_json(
+ fpath_results, header["test_name"], header["input_kwargs"], runtimes
+ )
+
+ if not NEW_BASELINE:
+ assert key in baselines, f"No basline found for {header}"
+ func_baselines = baselines[key]["runtimes"]
+ for key, baseline in func_baselines.items():
+ diff = (
+ float("nan")
+ if np.isclose(baseline, 0)
+ else (runtimes[key] - baseline) / baseline
+ )
+ assert runtimes[key] <= baseline * (
+ 1 + tolerance
+ ), f"{key} is {diff:.2%} slower than the baseline."
+
+ return test_wrapper
+
+
+def build_net(num_cells, artificial=True, connect=True, connection_prob=0.0):
+ np.random.seed(1)  # For sparse connectivity matrix.
+
+ if artificial:
+ comp = jx.Compartment()
+ branch = jx.Branch(comp, 2)
+ depth = 3
+ parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
+ cell = jx.Cell(branch, parents=parents)
+ else:
+ dirname = os.path.dirname(__file__)
+ fname = os.path.join(dirname, "swc_files", "morph.swc")
+ cell = jx.read_swc(fname, nseg=4)
+ net = jx.Network([cell for _ in range(num_cells)])
+
+ # Channels.
+ net.insert(HH())
+
+ # Synapses.
+ if connect:
+ sparse_connect(
+ net.cell("all"), net.cell("all"), IonotropicSynapse(), connection_prob
+ )
+
+ # Recordings.
+ net[0, 1, 0].record(verbose=False)
+
+ # Trainables.
+ net.make_trainable("radius", verbose=False)
+ params = net.get_parameters()
+
+ net.to_jax()
+ return net, params
+
+
+@pytest.mark.parametrize(
+ "num_cells, artificial, connect, connection_prob, voltage_solver",
+ (
+ # Test a single SWC cell with both solvers.
+ pytest.param(1, False, False, 0.0, "jaxley.stone"),
+ pytest.param(1, False, False, 0.0, "jax.sparse"),
+ # Test a network of SWC cells with both solvers.
+ pytest.param(10, False, True, 0.1, "jaxley.stone"),
+ pytest.param(10, False, True, 0.1, "jax.sparse"),
+ # Test a larger network of smaller neurons with both solvers.
+ pytest.param(1000, True, True, 0.001, "jaxley.stone"),
+ pytest.param(1000, True, True, 0.001, "jax.sparse"),
+ ),
+)
+@compare_to_baseline(baseline_iters=3)
+def test_runtime(
+ num_cells: int,
+ artificial: bool,
+ connect: bool,
+ connection_prob: float,
+ voltage_solver: str,
+):
+ delta_t = 0.025
+ t_max = 100.0
+
+ def simulate(params):
+ return jx.integrate(
+ net,
+ params=params,
+ t_max=t_max,
+ delta_t=delta_t,
+ voltage_solver=voltage_solver,
+ )
+
+ runtimes = {}
+
+ start_time = time.time()
+ net, params = build_net(
+ num_cells,
+ artificial=artificial,
+ connect=connect,
+ connection_prob=connection_prob,
+ )
+ runtimes["build_time"] = time.time() - start_time
+
+ jitted_simulate = jit(simulate)
+
+ start_time = time.time()
+ _ = jitted_simulate(params).block_until_ready()
+ runtimes["compile_time"] = time.time() - start_time
+ params[0]["radius"] = params[0]["radius"].at[0].set(0.5)
+
+ start_time = time.time()
+ _ = jitted_simulate(params).block_until_ready()
+ runtimes["run_time"] = time.time() - start_time
+ return runtimes # @compare_to_baseline decorator will compare this to the baseline