From 5ad4d699d6507efcb69822c4f5e89d7699e2c362 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Thu, 3 Oct 2024 15:39:00 +0200 Subject: [PATCH 1/8] Allow ion diffusion --- jaxley/modules/base.py | 100 ++++++++++++++++++++++++--------------- jaxley/solver_voltage.py | 4 +- 2 files changed, 65 insertions(+), 39 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index c3e6a34c..a5edcac8 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -91,6 +91,9 @@ def __init__(self): self.channels: List[Channel] = [] self.membrane_current_names: List[str] = [] + # List of all states (exluding voltage) that are being diffused. + self.diffusion_states: List[str] = [] + # For trainable parameters. self.indices_set_by_trainables: List[jnp.ndarray] = [] self.trainable_params: List[Dict[str, jnp.ndarray]] = [] @@ -395,6 +398,10 @@ def _data_set( raise KeyError("Key not recognized.") return param_state + def diffuse(self, state: str): + self.diffusion_states.append(state) + self.nodes[f"axial_resistivity_{state}"] = 1.0 + def make_trainable( self, key: str, @@ -548,6 +555,11 @@ def get_all_parameters( for key in ["radius", "length", "axial_resistivity", "capacitance"]: params[key] = self.jaxnodes[key] + for key in self.diffusion_states: + params[f"axial_resistivity_{key}"] = self.jaxnodes[ + f"axial_resistivity_{key}" + ] + for channel in self.channels: for channel_params in channel.channel_params: params[channel_params] = self.jaxnodes[channel_params] @@ -952,25 +964,32 @@ def step( cm = params["capacitance"] # Abbreviation. # Arguments used by all solvers. - solver_kwargs = { - "voltages": voltages, - "voltage_terms": (v_terms + syn_v_terms) / cm, - "constant_terms": (const_terms + i_ext + syn_const_terms) / cm, - "axial_conductances": params["axial_conductances"], - "internal_node_inds": self._internal_node_inds, + state_vals = { + "voltages": jnp.stack([voltages, u["CaCon_i"]]), + "voltage_terms": jnp.stack( + [(v_terms + syn_v_terms) / cm, jnp.zeros_like(v_terms)] + ), + "constant_terms": jnp.stack( + [ + (const_terms + i_ext + syn_const_terms) / cm, + jnp.zeros_like(const_terms), + ] + ), + "axial_conductances": jnp.stack( + [params["axial_conductances"], params["axial_conductances"]] + ), } # Add solver specific arguments. if voltage_solver == "jax.sparse": - solver_kwargs.update( - { - "sinks": np.asarray(self._comp_edges["sink"].to_list()), - "data_inds": self._data_inds, - "indices": self._indices_jax_spsolve, - "indptr": self._indptr_jax_spsolve, - "n_nodes": self._n_nodes, - } - ) + solver_kwargs = { + "data_inds": self._data_inds, + "indices": self._indices_jax_spsolve, + "indptr": self._indptr_jax_spsolve, + "sinks": np.asarray(self._comp_edges["sink"].to_list()), + "n_nodes": self._n_nodes, + "internal_node_inds": self._internal_node_inds, + } # Only for `bwd_euler` and `cranck-nicolson`. step_voltage_implicit = step_voltage_implicit_with_jax_spsolve else: @@ -980,42 +999,49 @@ def step( # Currently, the forward Euler solver also uses this format. However, # this is only for historical reasons and we are planning to change this in # the future. - solver_kwargs.update( - { - "sinks": np.asarray(self._comp_edges["sink"].to_list()), - "sources": np.asarray(self._comp_edges["source"].to_list()), - "types": np.asarray(self._comp_edges["type"].to_list()), - "masked_node_inds": self._remapped_node_indices, - "nseg_per_branch": self.nseg_per_branch, - "nseg": self.nseg, - "par_inds": self.par_inds, - "child_inds": self.child_inds, - "nbranches": self.total_nbranches, - "solver": voltage_solver, - "children_in_level": self.children_in_level, - "parents_in_level": self.parents_in_level, - "root_inds": self.root_inds, - "branchpoint_group_inds": self.branchpoint_group_inds, - "debug_states": self.debug_states, - } - ) + solver_kwargs = { + "internal_node_inds": self._internal_node_inds, + "sinks": np.asarray(self._comp_edges["sink"].to_list()), + "sources": np.asarray(self._comp_edges["source"].to_list()), + "types": np.asarray(self._comp_edges["type"].to_list()), + "masked_node_inds": self._remapped_node_indices, + "nseg_per_branch": self.nseg_per_branch, + "nseg": self.nseg, + "par_inds": self.par_inds, + "child_inds": self.child_inds, + "nbranches": self.total_nbranches, + "solver": voltage_solver, + "children_in_level": self.children_in_level, + "parents_in_level": self.parents_in_level, + "root_inds": self.root_inds, + "branchpoint_group_inds": self.branchpoint_group_inds, + "debug_states": self.debug_states, + } # Only for `bwd_euler` and `cranck-nicolson`. step_voltage_implicit = step_voltage_implicit_with_jaxley_spsolve if solver == "bwd_euler": - u["v"] = step_voltage_implicit(**solver_kwargs, delta_t=delta_t) + nones = [None] * len(solver_kwargs) + vmapped = vmap(step_voltage_implicit, in_axes=(0, 0, 0, 0, *nones, None)) + updated_states = vmapped( + *state_vals.values(), *solver_kwargs.values(), delta_t + ) + u["v"] = updated_states[0] + u["CaCon_i"] = updated_states[1] elif solver == "crank_nicolson": # Crank-Nicolson advances by half a step of backward and half a step of # forward Euler. half_step_delta_t = delta_t / 2 half_step_voltages = step_voltage_implicit( - **solver_kwargs, delta_t=half_step_delta_t + **state_vals, **solver_kwargs, delta_t=half_step_delta_t ) # The forward Euler step in Crank-Nicolson can be performed easily as # `V_{n+1} = 2 * V_{n+1/2} - V_n`. See also NEURON book Chapter 4. u["v"] = 2 * half_step_voltages - voltages elif solver == "fwd_euler": - u["v"] = step_voltage_explicit(**solver_kwargs, delta_t=delta_t) + u["v"] = step_voltage_explicit( + **state_vals, **solver_kwargs, delta_t=delta_t + ) else: raise ValueError( f"You specified `solver={solver}`. The only allowed solvers are " diff --git a/jaxley/solver_voltage.py b/jaxley/solver_voltage.py index 80c1538a..fb014d0a 100644 --- a/jaxley/solver_voltage.py +++ b/jaxley/solver_voltage.py @@ -80,12 +80,12 @@ def step_voltage_implicit_with_jaxley_spsolve( child_inds: jnp.ndarray, nbranches: int, solver: str, - delta_t: float, children_in_level: List[jnp.ndarray], parents_in_level: List[jnp.ndarray], root_inds: jnp.ndarray, branchpoint_group_inds: jnp.ndarray, debug_states, + delta_t: float, ): """Solve one timestep of branched nerve equations with implicit (backward) Euler.""" # Build diagonals. @@ -246,9 +246,9 @@ def step_voltage_implicit_with_jax_spsolve( indices, indptr, sinks, - delta_t, n_nodes, internal_node_inds, + delta_t, ): axial_conductances = delta_t * axial_conductances From 0ecd0ff943d150f277665a52217feac397df705a Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Fri, 4 Oct 2024 14:44:56 +0200 Subject: [PATCH 2/8] more features --- jaxley/modules/base.py | 2 +- jaxley/solver_voltage.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index a5edcac8..d25809e9 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -976,7 +976,7 @@ def step( ] ), "axial_conductances": jnp.stack( - [params["axial_conductances"], params["axial_conductances"]] + [params["axial_conductances"], 0.0 * params["axial_conductances"]] ), } diff --git a/jaxley/solver_voltage.py b/jaxley/solver_voltage.py index fb014d0a..8a4a33aa 100644 --- a/jaxley/solver_voltage.py +++ b/jaxley/solver_voltage.py @@ -149,7 +149,7 @@ def step_voltage_implicit_with_jaxley_spsolve( num_branchpoints = len(branchpoint_conds_parents) branchpoint_diags = -group_and_sum( all_branchpoint_vals, branchpoint_group_inds, num_branchpoints - ) + ) + 1e-14 # For numerical stability if axial_conductances == 0.0 branchpoint_solves = jnp.zeros((num_branchpoints,)) branchpoint_conds_children = -delta_t * branchpoint_conds_children From 76696ae9c8ae4fc067db96c741d4d82af79eb90f Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Fri, 18 Oct 2024 19:38:41 +0200 Subject: [PATCH 3/8] minor --- jaxley/modules/compartment.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jaxley/modules/compartment.py b/jaxley/modules/compartment.py index 6cd24f9f..aac0414f 100644 --- a/jaxley/modules/compartment.py +++ b/jaxley/modules/compartment.py @@ -35,6 +35,7 @@ class Compartment(Module): def __init__(self): super().__init__() + print("Hello =========") self.nseg = 1 self.nseg_per_branch = [1] From 56782bc1737d63c36a385d9539bd8071df0498f5 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Fri, 18 Oct 2024 20:31:31 +0200 Subject: [PATCH 4/8] revise `.step()` --- jaxley/modules/base.py | 52 +++++++++++++++++++++++------------ jaxley/modules/compartment.py | 1 - 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index d25809e9..a8443c60 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -399,8 +399,24 @@ def _data_set( return param_state def diffuse(self, state: str): + """Diffuse a particular state across compartments with Fickian diffusion. + + Args: + state: Name of the state that should be diffused. + """ + self._diffuse(state, self.nodes, self.nodes) + + def _diffuse(self, state: str, table_to_update: pd.DataFrame, view: pd.DataFrame): self.diffusion_states.append(state) - self.nodes[f"axial_resistivity_{state}"] = 1.0 + table_to_update.loc[view.index.values, f"axial_resistivity_{state}"] = 1.0 + + # The diffused state might not exist in all compartments that across which + # we are diffusing (e.g. there are active calcium mechanisms only in the soma, + # but calcium should still diffuse into the dendrites). Here, we ensure that + # the state is not `NaN` in every compartment across which we are diffusing. + state_is_nan = pd.isna(view[state]) + average_state_value = view[state].mean() + table_to_update.loc[state_is_nan, state] = average_state_value def make_trainable( self, @@ -912,7 +928,7 @@ def step( This function is called inside of `integrate` and increments the state of the module by one time step. Calls `_step_channels` and `_step_synapse` to update - the states of the channels and synapses using fwd_euler. + the states of the channels and synapses. Args: u: The state of the module. voltages = u["v"] @@ -964,19 +980,18 @@ def step( cm = params["capacitance"] # Abbreviation. # Arguments used by all solvers. + num_diffused_states = len(self.diffusion_states) + diffused_state_zeros = [jnp.zeros_like(v_terms)] * num_diffused_states state_vals = { - "voltages": jnp.stack([voltages, u["CaCon_i"]]), + "voltages": jnp.stack([voltages] + [u[d] for d in self.diffusion_states]), "voltage_terms": jnp.stack( - [(v_terms + syn_v_terms) / cm, jnp.zeros_like(v_terms)] + [(v_terms + syn_v_terms) / cm] + diffused_state_zeros ), "constant_terms": jnp.stack( - [ - (const_terms + i_ext + syn_const_terms) / cm, - jnp.zeros_like(const_terms), - ] + [(const_terms + i_ext + syn_const_terms) / cm] + diffused_state_zeros ), "axial_conductances": jnp.stack( - [params["axial_conductances"], 0.0 * params["axial_conductances"]] + [params["axial_conductances"]] + [0.0 * params["axial_conductances"] for _ in range(num_diffused_states)] ), } @@ -1027,7 +1042,9 @@ def step( *state_vals.values(), *solver_kwargs.values(), delta_t ) u["v"] = updated_states[0] - u["CaCon_i"] = updated_states[1] + for i, diffusion_state in enumerate(self.diffusion_states): + # +1 because voltage is the zero-eth element. + u[diffusion_state] = updated_states[i + 1] elif solver == "crank_nicolson": # Crank-Nicolson advances by half a step of backward and half a step of # forward Euler. @@ -1713,14 +1730,13 @@ def data_set( """Set parameter of module (or its view) to a new value within `jit`.""" return self.pointer._data_set(key, val, self.view, param_state) - def make_trainable( - self, - key: str, - init_val: Optional[Union[float, list]] = None, - verbose: bool = True, - ): - """Make a parameter trainable.""" - self.pointer._make_trainable(self.view, key, init_val, verbose=verbose) + def diffuse(self, state: str): + """Diffuse a particular state across compartments with Fickian diffusion. + + Args: + state: Name of the state that should be diffused. + """ + self._diffuse(state, self.pointer.nodes, self.view) def add_to_group(self, group_name: str): self.pointer._add_to_group(group_name, self.view) diff --git a/jaxley/modules/compartment.py b/jaxley/modules/compartment.py index aac0414f..6cd24f9f 100644 --- a/jaxley/modules/compartment.py +++ b/jaxley/modules/compartment.py @@ -35,7 +35,6 @@ class Compartment(Module): def __init__(self): super().__init__() - print("Hello =========") self.nseg = 1 self.nseg_per_branch = [1] From b19aea145a3ec48ab84d01b2b4a2d4c4e916b34f Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Fri, 18 Oct 2024 21:07:32 +0200 Subject: [PATCH 5/8] actually compute the axial resistivity --- jaxley/modules/base.py | 16 +++++++++------- jaxley/solver_voltage.py | 7 ++++--- jaxley/utils/cell_utils.py | 26 ++++++++++++++------------ 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index a8443c60..06cd1c0a 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -290,8 +290,12 @@ def _init_morph_jaxley_spsolve(self): raise NotImplementedError def _compute_axial_conductances(self, params: Dict[str, jnp.ndarray]): - """Given radius, length, r_a, compute the axial coupling conductances.""" - return compute_axial_conductances(self._comp_edges, params) + """Given radius, length, r_a, compute the axial coupling conductances. + + If ion diffusion was activated by the user (with `cell.diffuse()`) then this + function also compute the axial conductances for every ion. + """ + return compute_axial_conductances(self._comp_edges, params, self.diffusion_states) def _append_channel_to_nodes(self, view: pd.DataFrame, channel: "jx.Channel"): """Adds channel nodes from constituents to `self.channel_nodes`.""" @@ -400,7 +404,7 @@ def _data_set( def diffuse(self, state: str): """Diffuse a particular state across compartments with Fickian diffusion. - + Args: state: Name of the state that should be diffused. """ @@ -990,9 +994,7 @@ def step( "constant_terms": jnp.stack( [(const_terms + i_ext + syn_const_terms) / cm] + diffused_state_zeros ), - "axial_conductances": jnp.stack( - [params["axial_conductances"]] + [0.0 * params["axial_conductances"] for _ in range(num_diffused_states)] - ), + "axial_conductances": params["axial_conductances"], } # Add solver specific arguments. @@ -1732,7 +1734,7 @@ def data_set( def diffuse(self, state: str): """Diffuse a particular state across compartments with Fickian diffusion. - + Args: state: Name of the state that should be diffused. """ diff --git a/jaxley/solver_voltage.py b/jaxley/solver_voltage.py index 8a4a33aa..1cc184f3 100644 --- a/jaxley/solver_voltage.py +++ b/jaxley/solver_voltage.py @@ -147,9 +147,10 @@ def step_voltage_implicit_with_jaxley_spsolve( ) # Find unique group identifiers num_branchpoints = len(branchpoint_conds_parents) - branchpoint_diags = -group_and_sum( - all_branchpoint_vals, branchpoint_group_inds, num_branchpoints - ) + 1e-14 # For numerical stability if axial_conductances == 0.0 + branchpoint_diags = ( + -group_and_sum(all_branchpoint_vals, branchpoint_group_inds, num_branchpoints) + + 1e-14 + ) # For numerical stability if axial_conductances == 0.0 branchpoint_solves = jnp.zeros((num_branchpoints,)) branchpoint_conds_children = -delta_t * branchpoint_conds_children diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index ffa8c0e3..8f0bec48 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -410,7 +410,7 @@ def query_channel_states_and_params(d, keys, idcs): def compute_axial_conductances( - comp_edges: pd.DataFrame, params: Dict[str, jnp.ndarray] + comp_edges: pd.DataFrame, params: Dict[str, jnp.ndarray], diffusion_states: List[str] ) -> jnp.ndarray: """Given `comp_edges`, radius, length, r_a, cm, compute the axial conductances. @@ -422,13 +422,15 @@ def compute_axial_conductances( source_comp_inds = np.asarray(comp_edges[condition]["source"].to_list()) sink_comp_inds = np.asarray(comp_edges[condition]["sink"].to_list()) + resistivities = jnp.stack([params["axial_resistivity"]] + [params[f"axial_resistivity_{d}"] for d in diffusion_states]) + if len(sink_comp_inds) > 0: conds_c2c = ( - vmap(compute_coupling_cond, in_axes=(0, 0, 0, 0, 0, 0))( + vmap(vmap(compute_coupling_cond, in_axes=(0, 0, 0, 0, 0, 0)), in_axes=(None, None, 0, 0, None, None))( params["radius"][sink_comp_inds], params["radius"][source_comp_inds], - params["axial_resistivity"][sink_comp_inds], - params["axial_resistivity"][source_comp_inds], + resistivities[sink_comp_inds], + resistivities[source_comp_inds], params["length"][sink_comp_inds], params["length"][source_comp_inds], ) @@ -443,34 +445,34 @@ def compute_axial_conductances( if len(sink_comp_inds) > 0: conds_bp2c = ( - vmap(compute_coupling_cond_branchpoint, in_axes=(0, 0, 0))( + vmap(vmap(compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)), in_axes=(None, 0, None))( params["radius"][sink_comp_inds], - params["axial_resistivity"][sink_comp_inds], + resistivities[sink_comp_inds], params["length"][sink_comp_inds], ) - / params["capacitance"][sink_comp_inds] + / params["capacitance"][sink_comp_inds] # TODO only v should divide by capacitance. ) else: - conds_bp2c = jnp.asarray([]) + conds_bp2c = jnp.asarray([[]] * (len(diffusion_states) + 1)) # `compartment-to-branchpoint` (c2bp) axial coupling conductances. condition = comp_edges["type"].isin([3, 4]) source_comp_inds = np.asarray(comp_edges[condition]["source"].to_list()) if len(source_comp_inds) > 0: - conds_c2bp = vmap(compute_impact_on_node, in_axes=(0, 0, 0))( + conds_c2bp = vmap(vmap(compute_impact_on_node, in_axes=(0, 0, 0)), in_axes=(0, None, 0))( params["radius"][source_comp_inds], - params["axial_resistivity"][source_comp_inds], + resistivities[source_comp_inds], params["length"][source_comp_inds], ) # For numerical stability. These values are very small, but their scale # does not matter. conds_c2bp *= 1_000 else: - conds_c2bp = jnp.asarray([]) + conds_c2bp = jnp.asarray([[]] * (len(diffusion_states) + 1)) # All axial coupling conductances. - return jnp.concatenate([conds_c2c, conds_bp2c, conds_c2bp]) + return jnp.concatenate([conds_c2c, conds_bp2c, conds_c2bp], axis=1) def compute_children_and_parents( From b6c963f2f097a27159bdcc523c2c852e56488711 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Fri, 18 Oct 2024 21:15:33 +0200 Subject: [PATCH 6/8] bugfix --- jaxley/utils/cell_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index 8f0bec48..8b77118d 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -423,21 +423,22 @@ def compute_axial_conductances( sink_comp_inds = np.asarray(comp_edges[condition]["sink"].to_list()) resistivities = jnp.stack([params["axial_resistivity"]] + [params[f"axial_resistivity_{d}"] for d in diffusion_states]) + print("resistivities", resistivities.shape) if len(sink_comp_inds) > 0: conds_c2c = ( vmap(vmap(compute_coupling_cond, in_axes=(0, 0, 0, 0, 0, 0)), in_axes=(None, None, 0, 0, None, None))( params["radius"][sink_comp_inds], params["radius"][source_comp_inds], - resistivities[sink_comp_inds], - resistivities[source_comp_inds], + resistivities[:, sink_comp_inds], + resistivities[:, source_comp_inds], params["length"][sink_comp_inds], params["length"][source_comp_inds], ) / params["capacitance"][sink_comp_inds] ) else: - conds_c2c = jnp.asarray([]) + conds_c2c = jnp.asarray([[]] * (len(diffusion_states) + 1)) # `branchpoint-to-compartment` (bp2c) axial coupling conductances. condition = comp_edges["type"].isin([1, 2]) @@ -447,7 +448,7 @@ def compute_axial_conductances( conds_bp2c = ( vmap(vmap(compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)), in_axes=(None, 0, None))( params["radius"][sink_comp_inds], - resistivities[sink_comp_inds], + resistivities[:, sink_comp_inds], params["length"][sink_comp_inds], ) / params["capacitance"][sink_comp_inds] # TODO only v should divide by capacitance. @@ -462,7 +463,7 @@ def compute_axial_conductances( if len(source_comp_inds) > 0: conds_c2bp = vmap(vmap(compute_impact_on_node, in_axes=(0, 0, 0)), in_axes=(0, None, 0))( params["radius"][source_comp_inds], - resistivities[source_comp_inds], + resistivities[:, source_comp_inds], params["length"][source_comp_inds], ) # For numerical stability. These values are very small, but their scale From a272370668deedffa2fae596bdec4dff932f32e9 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Sun, 20 Oct 2024 16:05:42 +0200 Subject: [PATCH 7/8] add pumps --- jaxley/pumps/__init__.py | 0 jaxley/pumps/ca_pump.py | 39 ++++++++++++++++++++ jaxley/pumps/pump.py | 80 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 119 insertions(+) create mode 100644 jaxley/pumps/__init__.py create mode 100644 jaxley/pumps/ca_pump.py create mode 100644 jaxley/pumps/pump.py diff --git a/jaxley/pumps/__init__.py b/jaxley/pumps/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/jaxley/pumps/ca_pump.py b/jaxley/pumps/ca_pump.py new file mode 100644 index 00000000..73ebcc12 --- /dev/null +++ b/jaxley/pumps/ca_pump.py @@ -0,0 +1,39 @@ +from typing import Optional + +from jaxley.pumps.pump import Pump + + +class CaPump(Pump): + """Calcium dynamics tracking inside calcium concentration + + Modeled after Destexhe et al. 1994. + """ + + def __init__(self, name: Optional[str] = None): + super().__init__(name) + self.pump_params = { + f"{self._name}_gamma": 0.05, # Fraction of free calcium (not buffered). + f"{self._name}_decay": 80, # Buffering time constant in ms. + f"{self._name}_depth": 0.1, # Depth of shell in um. + f"{self._name}_minCai": 1e-4, # Minimum intracell. ca concentration in mM. + } + self.pump_states = {"CaCon_i": 5e-05} + self.META = { + "reference": "Modified from Destexhe et al., 1994", + "mechanism": "Calcium dynamics", + } + + def compute_current(self, u, dt, voltages, params): + """Return change of calcium concentration based on calcium current and decay.""" + prefix = self._name + ica = u["i_Ca"] / 1_000.0 + gamma = params[f"{prefix}_gamma"] + decay = params[f"{prefix}_decay"] + depth = params[f"{prefix}_depth"] + minCai = params[f"{prefix}_minCai"] + + FARADAY = 96485 # Coulombs per mole. + + # Calculate the contribution of calcium currents to cai change. + drive_channel = -10_000.0 * ica * gamma / (2 * FARADAY * depth) + return drive_channel - (u["CaCon_i"] + minCai) / decay diff --git a/jaxley/pumps/pump.py b/jaxley/pumps/pump.py new file mode 100644 index 00000000..23176f99 --- /dev/null +++ b/jaxley/pumps/pump.py @@ -0,0 +1,80 @@ +# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is +# licensed under the Apache License Version 2.0, see + +from abc import ABC, abstractmethod +from typing import Dict, Optional, Tuple + +import jax.numpy as jnp + + +class Pump: + """Pump base class. All pumps inherit from this class. + + A pump in Jaxley is everything that modifies the intracellular ion concentrations. + """ + + _name = None + pump_params = None + pump_states = None + current_name = None + + def __init__(self, name: Optional[str] = None): + self._name = name if name else self.__class__.__name__ + + @property + def name(self) -> Optional[str]: + """The name of the channel (by default, this is the class name).""" + return self._name + + def change_name(self, new_name: str): + """Change the pump name. + + Args: + new_name: The new name of the pump. + + Returns: + Renamed pump, such that this function is chainable. + """ + old_prefix = self._name + "_" + new_prefix = new_name + "_" + + self._name = new_name + self.pump_params = { + ( + new_prefix + key[len(old_prefix) :] + if key.startswith(old_prefix) + else key + ): value + for key, value in self.pump_params.items() + } + + self.pump_states = { + ( + new_prefix + key[len(old_prefix) :] + if key.startswith(old_prefix) + else key + ): value + for key, value in self.pump_states.items() + } + return self + + def update_states( + self, states, dt, v, params + ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: + """Return the updated states.""" + raise NotImplementedError + + def compute_current( + self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] + ): + """Given channel states and voltage, return the current through the channel. + + Args: + states: All states of the compartment. + v: Voltage of the compartment in mV. + params: Parameters of the channel (conductances in `S/cm2`). + + Returns: + Current in `uA/cm2`. + """ + raise NotImplementedError From c2b2985bc1906d0c6b8d00bf1594a1c546e3b2ff Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Mon, 21 Oct 2024 13:59:25 +0200 Subject: [PATCH 8/8] more work on pumps --- jaxley/modules/base.py | 39 +++++++++++++++++++++++++++++++++++++++ jaxley/pumps/__init__.py | 2 ++ jaxley/pumps/ca_pump.py | 7 ++++++- jaxley/pumps/pump.py | 10 ++-------- 4 files changed, 49 insertions(+), 9 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 06cd1c0a..716f139f 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -14,6 +14,7 @@ from matplotlib.axes import Axes from jaxley.channels import Channel +from jaxley.pumps import Pump from jaxley.solver_voltage import ( step_voltage_explicit, step_voltage_implicit_with_jax_spsolve, @@ -45,6 +46,9 @@ class Module(ABC): This base class defines the scaffold for all jaxley modules (compartments, branches, cells, networks). + + Note that the `__init__()` method is not abstract. This is because each module + type has a different initialization procedure. """ def __init__(self): @@ -320,6 +324,26 @@ def _append_channel_to_nodes(self, view: pd.DataFrame, channel: "jx.Channel"): for key in channel.channel_states: self.nodes.loc[view.index.values, key] = channel.channel_states[key] + def _append_pump_to_nodes(self, view: pd.DataFrame, pump: "jx.Pump"): + """Adds pump nodes from constituents to `self.pump_nodes`.""" + name = pump._name + + # Pump does not yet exist in the `jx.Module` at all. + if name not in [c._name for c in self.pumps]: + self.pumps.append(pump) + self.nodes[name] = False # Previous columns do not have the new pump. + + # Add a binary column that indicates if the pump is present. + self.nodes.loc[view.index.values, name] = True + + # Loop over all new parameters. + for key in pump.pump_params: + self.nodes.loc[view.index.values, key] = pump.pump_params[key] + + # Loop over all new states. + for key in pump.pump_states: + self.nodes.loc[view.index.values, key] = pump.pump_states[key] + def set(self, key: str, val: Union[float, jnp.ndarray]): """Set parameter of module (or its view) to a new value. @@ -915,6 +939,16 @@ def insert(self, channel: Channel): def _insert(self, channel, view): self._append_channel_to_nodes(view, channel) + def pump(self, pump: Pump): + """Insert a pump into the module. + + Args: + pump: The pump to insert.""" + self._pump(pump, self.nodes) + + def _pump(self, pump, view): + self._append_pump_to_nodes(view, pump) + def init_syns(self): self.initialized_syns = True @@ -966,6 +1000,11 @@ def step( u, delta_t, self.channels, self.nodes, params ) + # # Step of the Pumps. + # u, (v_terms, const_terms) = self._step_pumps( + # u, delta_t, self.pumps, self.nodes, params + # ) + # Step of the synapse. u, (syn_v_terms, syn_const_terms) = self._step_synapse( u, diff --git a/jaxley/pumps/__init__.py b/jaxley/pumps/__init__.py index e69de29b..ab0d7593 100644 --- a/jaxley/pumps/__init__.py +++ b/jaxley/pumps/__init__.py @@ -0,0 +1,2 @@ +from jaxley.pumps.pump import Pump +from jaxley.pumps.ca_pump import CaPump diff --git a/jaxley/pumps/ca_pump.py b/jaxley/pumps/ca_pump.py index 73ebcc12..46a876f2 100644 --- a/jaxley/pumps/ca_pump.py +++ b/jaxley/pumps/ca_pump.py @@ -17,12 +17,17 @@ def __init__(self, name: Optional[str] = None): f"{self._name}_depth": 0.1, # Depth of shell in um. f"{self._name}_minCai": 1e-4, # Minimum intracell. ca concentration in mM. } - self.pump_states = {"CaCon_i": 5e-05} + self.pump_states = {} + self.ion_name = "CaCon_i" self.META = { "reference": "Modified from Destexhe et al., 1994", "mechanism": "Calcium dynamics", } + def update_states(self, u, dt, voltages, params): + """Update states if necessary (but this pump has no states to update).""" + return {"CaCon_i": u["CaCon_i"]} + def compute_current(self, u, dt, voltages, params): """Return change of calcium concentration based on calcium current and decay.""" prefix = self._name diff --git a/jaxley/pumps/pump.py b/jaxley/pumps/pump.py index 23176f99..72ebb5eb 100644 --- a/jaxley/pumps/pump.py +++ b/jaxley/pumps/pump.py @@ -58,16 +58,10 @@ def change_name(self, new_name: str): } return self - def update_states( - self, states, dt, v, params - ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]: - """Return the updated states.""" - raise NotImplementedError - def compute_current( self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] ): - """Given channel states and voltage, return the current through the channel. + """Given channel states and voltage, return the change in ion concentration. Args: states: All states of the compartment. @@ -75,6 +69,6 @@ def compute_current( params: Parameters of the channel (conductances in `S/cm2`). Returns: - Current in `uA/cm2`. + Ion concentration change in `mM`. """ raise NotImplementedError