Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow ion diffusion #438

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 130 additions & 47 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from matplotlib.axes import Axes

from jaxley.channels import Channel
from jaxley.pumps import Pump
from jaxley.solver_voltage import (
step_voltage_explicit,
step_voltage_implicit_with_jax_spsolve,
Expand Down Expand Up @@ -45,6 +46,9 @@ class Module(ABC):

This base class defines the scaffold for all jaxley modules (compartments,
branches, cells, networks).

Note that the `__init__()` method is not abstract. This is because each module
type has a different initialization procedure.
"""

def __init__(self):
Expand Down Expand Up @@ -91,6 +95,9 @@ def __init__(self):
self.channels: List[Channel] = []
self.membrane_current_names: List[str] = []

# List of all states (exluding voltage) that are being diffused.
self.diffusion_states: List[str] = []

# For trainable parameters.
self.indices_set_by_trainables: List[jnp.ndarray] = []
self.trainable_params: List[Dict[str, jnp.ndarray]] = []
Expand Down Expand Up @@ -287,8 +294,12 @@ def _init_morph_jaxley_spsolve(self):
raise NotImplementedError

def _compute_axial_conductances(self, params: Dict[str, jnp.ndarray]):
"""Given radius, length, r_a, compute the axial coupling conductances."""
return compute_axial_conductances(self._comp_edges, params)
"""Given radius, length, r_a, compute the axial coupling conductances.

If ion diffusion was activated by the user (with `cell.diffuse()`) then this
function also compute the axial conductances for every ion.
"""
return compute_axial_conductances(self._comp_edges, params, self.diffusion_states)

def _append_channel_to_nodes(self, view: pd.DataFrame, channel: "jx.Channel"):
"""Adds channel nodes from constituents to `self.channel_nodes`."""
Expand All @@ -313,6 +324,26 @@ def _append_channel_to_nodes(self, view: pd.DataFrame, channel: "jx.Channel"):
for key in channel.channel_states:
self.nodes.loc[view.index.values, key] = channel.channel_states[key]

def _append_pump_to_nodes(self, view: pd.DataFrame, pump: "jx.Pump"):
"""Adds pump nodes from constituents to `self.pump_nodes`."""
name = pump._name

# Pump does not yet exist in the `jx.Module` at all.
if name not in [c._name for c in self.pumps]:
self.pumps.append(pump)
self.nodes[name] = False # Previous columns do not have the new pump.

# Add a binary column that indicates if the pump is present.
self.nodes.loc[view.index.values, name] = True

# Loop over all new parameters.
for key in pump.pump_params:
self.nodes.loc[view.index.values, key] = pump.pump_params[key]

# Loop over all new states.
for key in pump.pump_states:
self.nodes.loc[view.index.values, key] = pump.pump_states[key]

def set(self, key: str, val: Union[float, jnp.ndarray]):
"""Set parameter of module (or its view) to a new value.

Expand Down Expand Up @@ -395,6 +426,26 @@ def _data_set(
raise KeyError("Key not recognized.")
return param_state

def diffuse(self, state: str):
"""Diffuse a particular state across compartments with Fickian diffusion.

Args:
state: Name of the state that should be diffused.
"""
self._diffuse(state, self.nodes, self.nodes)

def _diffuse(self, state: str, table_to_update: pd.DataFrame, view: pd.DataFrame):
self.diffusion_states.append(state)
table_to_update.loc[view.index.values, f"axial_resistivity_{state}"] = 1.0

# The diffused state might not exist in all compartments that across which
# we are diffusing (e.g. there are active calcium mechanisms only in the soma,
# but calcium should still diffuse into the dendrites). Here, we ensure that
# the state is not `NaN` in every compartment across which we are diffusing.
state_is_nan = pd.isna(view[state])
average_state_value = view[state].mean()
table_to_update.loc[state_is_nan, state] = average_state_value

def make_trainable(
self,
key: str,
Expand Down Expand Up @@ -548,6 +599,11 @@ def get_all_parameters(
for key in ["radius", "length", "axial_resistivity", "capacitance"]:
params[key] = self.jaxnodes[key]

for key in self.diffusion_states:
params[f"axial_resistivity_{key}"] = self.jaxnodes[
f"axial_resistivity_{key}"
]

for channel in self.channels:
for channel_params in channel.channel_params:
params[channel_params] = self.jaxnodes[channel_params]
Expand Down Expand Up @@ -883,6 +939,16 @@ def insert(self, channel: Channel):
def _insert(self, channel, view):
self._append_channel_to_nodes(view, channel)

def pump(self, pump: Pump):
"""Insert a pump into the module.

Args:
pump: The pump to insert."""
self._pump(pump, self.nodes)

def _pump(self, pump, view):
self._append_pump_to_nodes(view, pump)

def init_syns(self):
self.initialized_syns = True

Expand All @@ -900,7 +966,7 @@ def step(

This function is called inside of `integrate` and increments the state of the
module by one time step. Calls `_step_channels` and `_step_synapse` to update
the states of the channels and synapses using fwd_euler.
the states of the channels and synapses.

Args:
u: The state of the module. voltages = u["v"]
Expand Down Expand Up @@ -934,6 +1000,11 @@ def step(
u, delta_t, self.channels, self.nodes, params
)

# # Step of the Pumps.
# u, (v_terms, const_terms) = self._step_pumps(
# u, delta_t, self.pumps, self.nodes, params
# )

# Step of the synapse.
u, (syn_v_terms, syn_const_terms) = self._step_synapse(
u,
Expand All @@ -952,25 +1023,29 @@ def step(
cm = params["capacitance"] # Abbreviation.

# Arguments used by all solvers.
solver_kwargs = {
"voltages": voltages,
"voltage_terms": (v_terms + syn_v_terms) / cm,
"constant_terms": (const_terms + i_ext + syn_const_terms) / cm,
num_diffused_states = len(self.diffusion_states)
diffused_state_zeros = [jnp.zeros_like(v_terms)] * num_diffused_states
state_vals = {
"voltages": jnp.stack([voltages] + [u[d] for d in self.diffusion_states]),
"voltage_terms": jnp.stack(
[(v_terms + syn_v_terms) / cm] + diffused_state_zeros
),
"constant_terms": jnp.stack(
[(const_terms + i_ext + syn_const_terms) / cm] + diffused_state_zeros
),
"axial_conductances": params["axial_conductances"],
"internal_node_inds": self._internal_node_inds,
}

# Add solver specific arguments.
if voltage_solver == "jax.sparse":
solver_kwargs.update(
{
"sinks": np.asarray(self._comp_edges["sink"].to_list()),
"data_inds": self._data_inds,
"indices": self._indices_jax_spsolve,
"indptr": self._indptr_jax_spsolve,
"n_nodes": self._n_nodes,
}
)
solver_kwargs = {
"data_inds": self._data_inds,
"indices": self._indices_jax_spsolve,
"indptr": self._indptr_jax_spsolve,
"sinks": np.asarray(self._comp_edges["sink"].to_list()),
"n_nodes": self._n_nodes,
"internal_node_inds": self._internal_node_inds,
}
# Only for `bwd_euler` and `cranck-nicolson`.
step_voltage_implicit = step_voltage_implicit_with_jax_spsolve
else:
Expand All @@ -980,42 +1055,51 @@ def step(
# Currently, the forward Euler solver also uses this format. However,
# this is only for historical reasons and we are planning to change this in
# the future.
solver_kwargs.update(
{
"sinks": np.asarray(self._comp_edges["sink"].to_list()),
"sources": np.asarray(self._comp_edges["source"].to_list()),
"types": np.asarray(self._comp_edges["type"].to_list()),
"masked_node_inds": self._remapped_node_indices,
"nseg_per_branch": self.nseg_per_branch,
"nseg": self.nseg,
"par_inds": self.par_inds,
"child_inds": self.child_inds,
"nbranches": self.total_nbranches,
"solver": voltage_solver,
"children_in_level": self.children_in_level,
"parents_in_level": self.parents_in_level,
"root_inds": self.root_inds,
"branchpoint_group_inds": self.branchpoint_group_inds,
"debug_states": self.debug_states,
}
)
solver_kwargs = {
"internal_node_inds": self._internal_node_inds,
"sinks": np.asarray(self._comp_edges["sink"].to_list()),
"sources": np.asarray(self._comp_edges["source"].to_list()),
"types": np.asarray(self._comp_edges["type"].to_list()),
"masked_node_inds": self._remapped_node_indices,
"nseg_per_branch": self.nseg_per_branch,
"nseg": self.nseg,
"par_inds": self.par_inds,
"child_inds": self.child_inds,
"nbranches": self.total_nbranches,
"solver": voltage_solver,
"children_in_level": self.children_in_level,
"parents_in_level": self.parents_in_level,
"root_inds": self.root_inds,
"branchpoint_group_inds": self.branchpoint_group_inds,
"debug_states": self.debug_states,
}
# Only for `bwd_euler` and `cranck-nicolson`.
step_voltage_implicit = step_voltage_implicit_with_jaxley_spsolve

if solver == "bwd_euler":
u["v"] = step_voltage_implicit(**solver_kwargs, delta_t=delta_t)
nones = [None] * len(solver_kwargs)
vmapped = vmap(step_voltage_implicit, in_axes=(0, 0, 0, 0, *nones, None))
updated_states = vmapped(
*state_vals.values(), *solver_kwargs.values(), delta_t
)
u["v"] = updated_states[0]
for i, diffusion_state in enumerate(self.diffusion_states):
# +1 because voltage is the zero-eth element.
u[diffusion_state] = updated_states[i + 1]
elif solver == "crank_nicolson":
# Crank-Nicolson advances by half a step of backward and half a step of
# forward Euler.
half_step_delta_t = delta_t / 2
half_step_voltages = step_voltage_implicit(
**solver_kwargs, delta_t=half_step_delta_t
**state_vals, **solver_kwargs, delta_t=half_step_delta_t
)
# The forward Euler step in Crank-Nicolson can be performed easily as
# `V_{n+1} = 2 * V_{n+1/2} - V_n`. See also NEURON book Chapter 4.
u["v"] = 2 * half_step_voltages - voltages
elif solver == "fwd_euler":
u["v"] = step_voltage_explicit(**solver_kwargs, delta_t=delta_t)
u["v"] = step_voltage_explicit(
**state_vals, **solver_kwargs, delta_t=delta_t
)
else:
raise ValueError(
f"You specified `solver={solver}`. The only allowed solvers are "
Expand Down Expand Up @@ -1687,14 +1771,13 @@ def data_set(
"""Set parameter of module (or its view) to a new value within `jit`."""
return self.pointer._data_set(key, val, self.view, param_state)

def make_trainable(
self,
key: str,
init_val: Optional[Union[float, list]] = None,
verbose: bool = True,
):
"""Make a parameter trainable."""
self.pointer._make_trainable(self.view, key, init_val, verbose=verbose)
def diffuse(self, state: str):
"""Diffuse a particular state across compartments with Fickian diffusion.

Args:
state: Name of the state that should be diffused.
"""
self._diffuse(state, self.pointer.nodes, self.view)

def add_to_group(self, group_name: str):
self.pointer._add_to_group(group_name, self.view)
Expand Down
2 changes: 2 additions & 0 deletions jaxley/pumps/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from jaxley.pumps.pump import Pump
from jaxley.pumps.ca_pump import CaPump
44 changes: 44 additions & 0 deletions jaxley/pumps/ca_pump.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Optional

from jaxley.pumps.pump import Pump


class CaPump(Pump):
"""Calcium dynamics tracking inside calcium concentration

Modeled after Destexhe et al. 1994.
"""

def __init__(self, name: Optional[str] = None):
super().__init__(name)
self.pump_params = {
f"{self._name}_gamma": 0.05, # Fraction of free calcium (not buffered).
f"{self._name}_decay": 80, # Buffering time constant in ms.
f"{self._name}_depth": 0.1, # Depth of shell in um.
f"{self._name}_minCai": 1e-4, # Minimum intracell. ca concentration in mM.
}
self.pump_states = {}
self.ion_name = "CaCon_i"
self.META = {
"reference": "Modified from Destexhe et al., 1994",
"mechanism": "Calcium dynamics",
}

def update_states(self, u, dt, voltages, params):
"""Update states if necessary (but this pump has no states to update)."""
return {"CaCon_i": u["CaCon_i"]}

def compute_current(self, u, dt, voltages, params):
"""Return change of calcium concentration based on calcium current and decay."""
prefix = self._name
ica = u["i_Ca"] / 1_000.0
gamma = params[f"{prefix}_gamma"]
decay = params[f"{prefix}_decay"]
depth = params[f"{prefix}_depth"]
minCai = params[f"{prefix}_minCai"]

FARADAY = 96485 # Coulombs per mole.

# Calculate the contribution of calcium currents to cai change.
drive_channel = -10_000.0 * ica * gamma / (2 * FARADAY * depth)
return drive_channel - (u["CaCon_i"] + minCai) / decay
Loading