Skip to content

Commit

Permalink
Copy ground truth traces to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 8, 2023
1 parent 9e57537 commit fd724a0
Show file tree
Hide file tree
Showing 4 changed files with 292 additions and 98 deletions.
64 changes: 0 additions & 64 deletions tests/neurax_identical/swc.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,25 @@ def test_compartment():

voltages = nx.integrate(comp, delta_t=dt)

voltages_081123 = None
max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123))
tolerance = 0.0
voltages_081123 = jnp.asarray(
[
[
-70.0,
-66.53085703,
-50.08343917,
-20.4930498,
32.59313718,
2.43694172,
-26.28659049,
-50.60083575,
-73.00785374,
-75.70088187,
-75.58850932,
]
]
)
max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123))
tolerance = 1e-8
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"


Expand All @@ -51,9 +67,25 @@ def test_branch():

voltages = nx.integrate(branch, delta_t=dt)

voltages_081123 = None
max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123))
tolerance = 0.0
voltages_081123 = jnp.asarray(
[
[
-70.0,
-66.53085703,
-56.84345977,
-47.34841899,
-31.44840463,
31.42082062,
7.06374191,
-22.55693822,
-46.57781433,
-71.30357569,
-75.67506729,
]
]
)
max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123))
tolerance = 1e-8
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"


Expand All @@ -77,9 +109,25 @@ def test_cell():

voltages = nx.integrate(cell, delta_t=dt)

voltages_081123 = None
max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123))
tolerance = 0.0
voltages_081123 = jnp.asarray(
[
[
-70.0,
-66.53085703,
-60.77514388,
-57.04487857,
-57.47607922,
-56.17604462,
-54.14448985,
-50.28092319,
-37.61368601,
24.47352735,
10.24274831,
]
]
)
max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123))
tolerance = 1e-8
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"


Expand Down Expand Up @@ -113,7 +161,36 @@ def test_net():

voltages = nx.integrate(network, delta_t=dt)

voltages_081123 = None
max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123))
tolerance = 0.0
voltages_081123 = jnp.asarray(
[
[
-70.0,
-66.53085703,
-60.77514388,
-57.04487857,
-57.47607922,
-56.17604462,
-54.14448985,
-50.28092319,
-37.61368601,
24.47352735,
10.24274831,
],
[
-70.0,
-66.12980895,
-59.94208128,
-55.74082517,
-55.34657106,
-52.32113275,
-44.76100591,
-3.65687352,
22.581919,
-8.29885715,
-33.70855517,
],
]
)
max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123))
tolerance = 1e-8
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def test_radius_and_length_compartment():
t_max = 5.0 # ms

time_vec = jnp.arange(0.0, t_max + dt, dt)
current = nx.step_current(0.5, 1.0, 0.02, time_vec)

comp = nx.Compartment().initialize()

Expand All @@ -28,16 +29,30 @@ def test_radius_and_length_compartment():
comp.set_params("radius", np.random.rand(1))

comp.insert(HHChannel())

current = nx.step_current(0.5, 1.0, 0.02, time_vec)
comp.record()
comp.stimulate(current)

voltages = nx.integrate(comp, delta_t=dt)

voltages_081123 = None
max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123))
tolerance = 0.0
voltages_081123 = jnp.asarray(
[
[
-70.0,
-66.53085703,
45.83306656,
29.72581199,
-15.44336119,
-39.98282246,
-66.77430474,
-75.56725325,
-75.75072145,
-75.48894182,
-75.15588746,
]
]
)
max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123))
tolerance = 1e-8
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"


Expand All @@ -47,6 +62,7 @@ def test_radius_and_length_branch():
t_max = 5.0 # ms

time_vec = jnp.arange(0.0, t_max + dt, dt)
current = nx.step_current(0.5, 1.0, 0.02, time_vec)

comp = nx.Compartment().initialize()
branch = nx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
Expand All @@ -56,16 +72,30 @@ def test_radius_and_length_branch():
branch.set_params("radius", np.random.rand(2))

branch.insert(HHChannel())

current = nx.step_current(0.5, 1.0, 0.02, time_vec)
branch.comp(0.0).record()
branch.comp(0.0).stimulate(current)

voltages = nx.integrate(branch, delta_t=dt)

voltages_081123 = None
max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123))
tolerance = 0.0
voltages_081123 = jnp.asarray(
[
[
-70.0,
-66.53085703,
57.69711962,
27.50364167,
-20.46106389,
-44.65846514,
-70.81919426,
-75.78039147,
-75.74705134,
-75.46858134,
-75.13087565,
]
]
)
max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123))
tolerance = 1e-8
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"


Expand All @@ -75,6 +105,7 @@ def test_radius_and_length_cell():
t_max = 5.0 # ms

time_vec = jnp.arange(0.0, t_max + dt, dt)
current = nx.step_current(0.5, 1.0, 0.02, time_vec)

depth = 2
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
Expand All @@ -89,16 +120,30 @@ def test_radius_and_length_cell():
cell.set_params("radius", np.random.rand(2 * num_branches))

cell.insert(HHChannel())

current = nx.step_current(0.5, 1.0, 0.02, time_vec)
cell.branch(1).comp(0.0).record()
cell.branch(1).comp(0.0).stimulate(current)

voltage = nx.integrate(cell, delta_t=dt)

voltages_081123 = None
max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123))
tolerance = 0.0
voltages = nx.integrate(cell, delta_t=dt)

voltages_081123 = jnp.asarray(
[
[
-70.0,
-66.53085703,
1.40098029,
41.85849945,
-3.41062602,
-30.40531156,
-54.43060925,
-73.96911419,
-75.74495833,
-75.58187443,
-75.27068799,
]
]
)
max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123))
tolerance = 1e-8
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"


Expand All @@ -108,6 +153,7 @@ def test_radius_and_length_net():
t_max = 5.0 # ms

time_vec = jnp.arange(0.0, t_max + dt, dt)
current = nx.step_current(0.5, 1.0, 0.02, time_vec)

depth = 2
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
Expand All @@ -132,8 +178,6 @@ def test_radius_and_length_net():
network = nx.Network([cell1, cell2], connectivities)
network.insert(HHChannel())

current = nx.step_current(0.5, 1.0, 0.02, time_vec)

for cell_ind in range(2):
network.cell(cell_ind).branch(1).comp(0.0).record()

Expand All @@ -142,7 +186,36 @@ def test_radius_and_length_net():

voltages = nx.integrate(network, delta_t=dt)

voltages_081123 = None
max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123))
tolerance = 0.0
voltages_081123 = jnp.asarray(
[
[
-70.0,
-66.53085703,
1.40098029,
41.85849945,
-3.41062602,
-30.40531156,
-54.43060925,
-73.96911419,
-75.74495833,
-75.58187443,
-75.27068799,
],
[
-70.0,
-66.46899201,
-14.64499375,
43.92453203,
3.25054262,
-25.60587037,
-49.03042036,
-71.89299285,
-75.57211264,
-75.4917933,
-75.16938855,
],
]
)
max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123))
tolerance = 1e-8
assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"
Loading

0 comments on commit fd724a0

Please sign in to comment.