Skip to content

Commit

Permalink
Copy ground truth traces to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 8, 2023
1 parent 5cd7f2d commit ff4f620
Show file tree
Hide file tree
Showing 4 changed files with 294 additions and 100 deletions.
64 changes: 0 additions & 64 deletions tests/neurax_identical/swc.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8"

import numpy as np
import jax.numpy as jnp
import numpy as np

import neurax as nx
from neurax.channels import HHChannel
Expand All @@ -29,9 +29,25 @@ def test_compartment():

voltages = nx.integrate(comp, delta_t=dt)

voltages_081123 = None
max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123))
tolerance = 0.0
voltages_081123 = jnp.asarray(
[
[
-70.0,
-66.53085703,
-50.08343917,
-20.4930498,
32.59313718,
2.43694172,
-26.28659049,
-50.60083575,
-73.00785374,
-75.70088187,
-75.58850932,
]
]
)
max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123))
tolerance = 1e-8
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"


Expand All @@ -51,9 +67,25 @@ def test_branch():

voltages = nx.integrate(branch, delta_t=dt)

voltages_081123 = None
max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123))
tolerance = 0.0
voltages_081123 = jnp.asarray(
[
[
-70.0,
-66.53085703,
-56.84345977,
-47.34841899,
-31.44840463,
31.42082062,
7.06374191,
-22.55693822,
-46.57781433,
-71.30357569,
-75.67506729,
]
]
)
max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123))
tolerance = 1e-8
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"


Expand All @@ -77,9 +109,25 @@ def test_cell():

voltages = nx.integrate(cell, delta_t=dt)

voltages_081123 = None
max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123))
tolerance = 0.0
voltages_081123 = jnp.asarray(
[
[
-70.0,
-66.53085703,
-60.77514388,
-57.04487857,
-57.47607922,
-56.17604462,
-54.14448985,
-50.28092319,
-37.61368601,
24.47352735,
10.24274831,
]
]
)
max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123))
tolerance = 1e-8
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"


Expand Down Expand Up @@ -113,7 +161,36 @@ def test_net():

voltages = nx.integrate(network, delta_t=dt)

voltages_081123 = None
max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123))
tolerance = 0.0
voltages_081123 = jnp.asarray(
[
[
-70.0,
-66.53085703,
-60.77514388,
-57.04487857,
-57.47607922,
-56.17604462,
-54.14448985,
-50.28092319,
-37.61368601,
24.47352735,
10.24274831,
],
[
-70.0,
-66.12980895,
-59.94208128,
-55.74082517,
-55.34657106,
-52.32113275,
-44.76100591,
-3.65687352,
22.581919,
-8.29885715,
-33.70855517,
],
]
)
max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123))
tolerance = 1e-8
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"
Loading

0 comments on commit ff4f620

Please sign in to comment.