From ff4f62042e5caae75d1002fd0d696f9f232c26b3 Mon Sep 17 00:00:00 2001 From: Michael Deistler Date: Wed, 8 Nov 2023 15:19:49 +0100 Subject: [PATCH] Copy ground truth traces to tests --- tests/neurax_identical/swc.py | 64 ---------- ...basic_modules.py => test_basic_modules.py} | 103 +++++++++++++-- ...nd_length.py => test_radius_and_length.py} | 119 ++++++++++++++---- tests/neurax_identical/test_swc.py | 108 ++++++++++++++++ 4 files changed, 294 insertions(+), 100 deletions(-) delete mode 100644 tests/neurax_identical/swc.py rename tests/neurax_identical/{basic_modules.py => test_basic_modules.py} (56%) rename tests/neurax_identical/{radius_and_length.py => test_radius_and_length.py} (61%) create mode 100644 tests/neurax_identical/test_swc.py diff --git a/tests/neurax_identical/swc.py b/tests/neurax_identical/swc.py deleted file mode 100644 index 1f32516d..00000000 --- a/tests/neurax_identical/swc.py +++ /dev/null @@ -1,64 +0,0 @@ -from jax import config - -config.update("jax_enable_x64", True) -config.update("jax_platform_name", "cpu") - -import os - -os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8" - -import numpy as np -import jax.numpy as jnp - -import neurax as nx -from neurax.channels import HHChannel -from neurax.synapses import GlutamateSynapse - - -def test_swc_cell(): - dt = 0.025 # ms - t_max = 5.0 # ms - current = nx.step_current(0.5, 1.0, 0.02, time_vec) - - time_vec = jnp.arange(0.0, t_max + dt, dt) - - cell = nx.read_swc() - cell.insert(HHChannel()) - cell.branch(1).comp(0.0).record() - cell.branch(1).comp(0.0).stimulate(current) - - voltages = nx.integrate(cell, delta_t=dt) - - voltages_081123 = None - max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123)) - tolerance = 0.0 - assert max_error <= tolerance, f"Error is {max_error} > {tolerance}" - - -def test_swc_net(): - dt = 0.025 # ms - t_max = 5.0 # ms - time_vec = jnp.arange(0.0, t_max + dt, dt) - current = nx.step_current(0.5, 1.0, 0.02, time_vec) - - cell1 = nx.read_swc() - cell2 = nx.read_swc() - - connectivities = [ - nx.Connectivity(GlutamateSynapse(), [nx.Connection(0, 0, 0.0, 1, 0, 0.0)]) - ] - network = nx.Network([cell1, cell2], connectivities) - network.insert(HHChannel()) - - for cell_ind in range(2): - network.cell(cell_ind).branch(1).comp(0.0).record() - - for stim_ind in range(2): - network.cell(stim_ind).branch(1).comp(0.0).stimulate(current) - - voltages = nx.integrate(network, delta_t=dt) - - voltages_081123 = None - max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123)) - tolerance = 0.0 - assert max_error <= tolerance, f"Error is {max_error} > {tolerance}" diff --git a/tests/neurax_identical/basic_modules.py b/tests/neurax_identical/test_basic_modules.py similarity index 56% rename from tests/neurax_identical/basic_modules.py rename to tests/neurax_identical/test_basic_modules.py index 123c5efd..3664c513 100644 --- a/tests/neurax_identical/basic_modules.py +++ b/tests/neurax_identical/test_basic_modules.py @@ -7,8 +7,8 @@ os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8" -import numpy as np import jax.numpy as jnp +import numpy as np import neurax as nx from neurax.channels import HHChannel @@ -29,9 +29,25 @@ def test_compartment(): voltages = nx.integrate(comp, delta_t=dt) - voltages_081123 = None - max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123)) - tolerance = 0.0 + voltages_081123 = jnp.asarray( + [ + [ + -70.0, + -66.53085703, + -50.08343917, + -20.4930498, + 32.59313718, + 2.43694172, + -26.28659049, + -50.60083575, + -73.00785374, + -75.70088187, + -75.58850932, + ] + ] + ) + max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123)) + tolerance = 1e-8 assert max_error <= tolerance, f"Error is {max_error} > {tolerance}" @@ -51,9 +67,25 @@ def test_branch(): voltages = nx.integrate(branch, delta_t=dt) - voltages_081123 = None - max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123)) - tolerance = 0.0 + voltages_081123 = jnp.asarray( + [ + [ + -70.0, + -66.53085703, + -56.84345977, + -47.34841899, + -31.44840463, + 31.42082062, + 7.06374191, + -22.55693822, + -46.57781433, + -71.30357569, + -75.67506729, + ] + ] + ) + max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123)) + tolerance = 1e-8 assert max_error <= tolerance, f"Error is {max_error} > {tolerance}" @@ -77,9 +109,25 @@ def test_cell(): voltages = nx.integrate(cell, delta_t=dt) - voltages_081123 = None - max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123)) - tolerance = 0.0 + voltages_081123 = jnp.asarray( + [ + [ + -70.0, + -66.53085703, + -60.77514388, + -57.04487857, + -57.47607922, + -56.17604462, + -54.14448985, + -50.28092319, + -37.61368601, + 24.47352735, + 10.24274831, + ] + ] + ) + max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123)) + tolerance = 1e-8 assert max_error <= tolerance, f"Error is {max_error} > {tolerance}" @@ -113,7 +161,36 @@ def test_net(): voltages = nx.integrate(network, delta_t=dt) - voltages_081123 = None - max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123)) - tolerance = 0.0 + voltages_081123 = jnp.asarray( + [ + [ + -70.0, + -66.53085703, + -60.77514388, + -57.04487857, + -57.47607922, + -56.17604462, + -54.14448985, + -50.28092319, + -37.61368601, + 24.47352735, + 10.24274831, + ], + [ + -70.0, + -66.12980895, + -59.94208128, + -55.74082517, + -55.34657106, + -52.32113275, + -44.76100591, + -3.65687352, + 22.581919, + -8.29885715, + -33.70855517, + ], + ] + ) + max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123)) + tolerance = 1e-8 assert max_error <= tolerance, f"Error is {max_error} > {tolerance}" diff --git a/tests/neurax_identical/radius_and_length.py b/tests/neurax_identical/test_radius_and_length.py similarity index 61% rename from tests/neurax_identical/radius_and_length.py rename to tests/neurax_identical/test_radius_and_length.py index 1de26c8d..748787d7 100644 --- a/tests/neurax_identical/radius_and_length.py +++ b/tests/neurax_identical/test_radius_and_length.py @@ -7,8 +7,8 @@ os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8" -import numpy as np import jax.numpy as jnp +import numpy as np import neurax as nx from neurax.channels import HHChannel @@ -20,6 +20,7 @@ def test_radius_and_length_compartment(): t_max = 5.0 # ms time_vec = jnp.arange(0.0, t_max + dt, dt) + current = nx.step_current(0.5, 1.0, 0.02, time_vec) comp = nx.Compartment().initialize() @@ -28,16 +29,30 @@ def test_radius_and_length_compartment(): comp.set_params("radius", np.random.rand(1)) comp.insert(HHChannel()) - - current = nx.step_current(0.5, 1.0, 0.02, time_vec) comp.record() comp.stimulate(current) voltages = nx.integrate(comp, delta_t=dt) - voltages_081123 = None - max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123)) - tolerance = 0.0 + voltages_081123 = jnp.asarray( + [ + [ + -70.0, + -66.53085703, + 45.83306656, + 29.72581199, + -15.44336119, + -39.98282246, + -66.77430474, + -75.56725325, + -75.75072145, + -75.48894182, + -75.15588746, + ] + ] + ) + max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123)) + tolerance = 1e-8 assert max_error <= tolerance, f"Error is {max_error} > {tolerance}" @@ -47,6 +62,7 @@ def test_radius_and_length_branch(): t_max = 5.0 # ms time_vec = jnp.arange(0.0, t_max + dt, dt) + current = nx.step_current(0.5, 1.0, 0.02, time_vec) comp = nx.Compartment().initialize() branch = nx.Branch([comp for _ in range(nseg_per_branch)]).initialize() @@ -56,16 +72,30 @@ def test_radius_and_length_branch(): branch.set_params("radius", np.random.rand(2)) branch.insert(HHChannel()) - - current = nx.step_current(0.5, 1.0, 0.02, time_vec) branch.comp(0.0).record() branch.comp(0.0).stimulate(current) voltages = nx.integrate(branch, delta_t=dt) - voltages_081123 = None - max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123)) - tolerance = 0.0 + voltages_081123 = jnp.asarray( + [ + [ + -70.0, + -66.53085703, + 57.69711962, + 27.50364167, + -20.46106389, + -44.65846514, + -70.81919426, + -75.78039147, + -75.74705134, + -75.46858134, + -75.13087565, + ] + ] + ) + max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123)) + tolerance = 1e-8 assert max_error <= tolerance, f"Error is {max_error} > {tolerance}" @@ -75,6 +105,7 @@ def test_radius_and_length_cell(): t_max = 5.0 # ms time_vec = jnp.arange(0.0, t_max + dt, dt) + current = nx.step_current(0.5, 1.0, 0.02, time_vec) depth = 2 parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] @@ -89,16 +120,30 @@ def test_radius_and_length_cell(): cell.set_params("radius", np.random.rand(2 * num_branches)) cell.insert(HHChannel()) - - current = nx.step_current(0.5, 1.0, 0.02, time_vec) cell.branch(1).comp(0.0).record() cell.branch(1).comp(0.0).stimulate(current) - voltage = nx.integrate(cell, delta_t=dt) - - voltages_081123 = None - max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123)) - tolerance = 0.0 + voltages = nx.integrate(cell, delta_t=dt) + + voltages_081123 = jnp.asarray( + [ + [ + -70.0, + -66.53085703, + 1.40098029, + 41.85849945, + -3.41062602, + -30.40531156, + -54.43060925, + -73.96911419, + -75.74495833, + -75.58187443, + -75.27068799, + ] + ] + ) + max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123)) + tolerance = 1e-8 assert max_error <= tolerance, f"Error is {max_error} > {tolerance}" @@ -108,6 +153,7 @@ def test_radius_and_length_net(): t_max = 5.0 # ms time_vec = jnp.arange(0.0, t_max + dt, dt) + current = nx.step_current(0.5, 1.0, 0.02, time_vec) depth = 2 parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] @@ -132,8 +178,6 @@ def test_radius_and_length_net(): network = nx.Network([cell1, cell2], connectivities) network.insert(HHChannel()) - current = nx.step_current(0.5, 1.0, 0.02, time_vec) - for cell_ind in range(2): network.cell(cell_ind).branch(1).comp(0.0).record() @@ -142,7 +186,36 @@ def test_radius_and_length_net(): voltages = nx.integrate(network, delta_t=dt) - voltages_081123 = None - max_error = np.max(np.abs(voltages[:, ::10] - voltages_081123)) - tolerance = 0.0 + voltages_081123 = jnp.asarray( + [ + [ + -70.0, + -66.53085703, + 1.40098029, + 41.85849945, + -3.41062602, + -30.40531156, + -54.43060925, + -73.96911419, + -75.74495833, + -75.58187443, + -75.27068799, + ], + [ + -70.0, + -66.46899201, + -14.64499375, + 43.92453203, + 3.25054262, + -25.60587037, + -49.03042036, + -71.89299285, + -75.57211264, + -75.4917933, + -75.16938855, + ], + ] + ) + max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123)) + tolerance = 1e-8 assert max_error <= tolerance, f"Error is {max_error} > {tolerance}" diff --git a/tests/neurax_identical/test_swc.py b/tests/neurax_identical/test_swc.py new file mode 100644 index 00000000..f49c1405 --- /dev/null +++ b/tests/neurax_identical/test_swc.py @@ -0,0 +1,108 @@ +from jax import config + +config.update("jax_enable_x64", True) +config.update("jax_platform_name", "cpu") + +import os + +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".8" + +import jax.numpy as jnp +import numpy as np + +import neurax as nx +from neurax.channels import HHChannel +from neurax.synapses import GlutamateSynapse + + +def test_swc_cell(): + dt = 0.025 # ms + t_max = 5.0 # ms + time_vec = jnp.arange(0.0, t_max + dt, dt) + current = nx.step_current(0.5, 1.0, 0.2, time_vec) + + cell = nx.read_swc("../morph.swc", nseg=2, max_branch_len=300.0) + cell.insert(HHChannel()) + cell.branch(1).comp(0.0).record() + cell.branch(1).comp(0.0).stimulate(current) + + voltages = nx.integrate(cell, delta_t=dt) + + voltages_081123 = jnp.asarray( + [ + [ + -70.0, + -66.53085703, + -57.02065054, + -49.74541341, + -46.15576812, + -23.71760359, + 25.19649297, + 1.99881676, + -25.42530891, + -48.23078669, + -69.2479302, + ] + ] + ) + max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123)) + tolerance = 1e-8 + assert max_error <= tolerance, f"Error is {max_error} > {tolerance}" + + +def test_swc_net(): + dt = 0.025 # ms + t_max = 5.0 # ms + time_vec = jnp.arange(0.0, t_max + dt, dt) + current = nx.step_current(0.5, 1.0, 0.2, time_vec) + + cell1 = nx.read_swc("../morph.swc", nseg=2, max_branch_len=300.0) + cell2 = nx.read_swc("../morph.swc", nseg=2, max_branch_len=300.0) + + connectivities = [ + nx.Connectivity(GlutamateSynapse(), [nx.Connection(0, 0, 0.0, 1, 0, 0.0)]) + ] + network = nx.Network([cell1, cell2], connectivities) + network.insert(HHChannel()) + + for cell_ind in range(2): + network.cell(cell_ind).branch(1).comp(0.0).record() + + for stim_ind in range(2): + network.cell(stim_ind).branch(1).comp(0.0).stimulate(current) + + voltages = nx.integrate(network, delta_t=dt) + + voltages_081123 = jnp.asarray( + [ + [ + -70.0, + -66.53085703, + -57.02065054, + -49.74541341, + -46.15576812, + -23.71760359, + 25.19649297, + 1.99881676, + -25.42530891, + -48.23078669, + -69.2479302, + ], + [ + -70.0, + -66.52400879, + -57.01032453, + -49.72868896, + -46.10615669, + -23.43965885, + 25.17580247, + 1.83985193, + -25.55061179, + -48.34838475, + -69.29381344, + ], + ] + ) + max_error = np.max(np.abs(voltages[:, ::20] - voltages_081123)) + tolerance = 1e-8 + assert max_error <= tolerance, f"Error is {max_error} > {tolerance}"