
Clamping of Synapse: Following Synapse might also be clamped #485

Closed
deezer257 opened this issue Nov 7, 2024 · 4 comments · Fixed by #492

Comments

@deezer257
Contributor

deezer257 commented Nov 7, 2024

If I create a network where cells are connected via RibbonSynapses and apply a data clamp to one of the RibbonSynapses, not only is that synapse clamped to the defined value, but so is the next synapse in sequence (the one with the incremented index). It also doesn't matter which method I use to index the synapses (pre-/post-synaptic cell indexing or indexing via the edges). An example is attached:

JAX version: 0.4.35

import matplotlib.pyplot as plt
import jaxley as jx
import jax
from jaxley_mech.synapses.ribbon import RibbonSynapse
from jaxley.connect import connect
from jaxley_mech.channels.hodgkin52 import Leak, Na, K
import jax.numpy as jnp
import numpy as np


from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "gpu")  # requires a GPU; use "cpu" otherwise

# Create a new dummy network
comp_opt = jx.Compartment()
branch_opt = jx.Branch(comp_opt, nseg=1)
cell_opt = jx.Cell([branch_opt], [-1])
net_opt = jx.Network(cells=[cell_opt] * 4)


# Insert the leak, sodium and potassium channels into all cells
net_opt.insert(Leak())
net_opt.insert(Na())
net_opt.insert(K())

v_rest = -70

# Set the initial membrane voltage of all cells and the leak reversal
# potential (Leak_eLeak) to v_rest, so that the cells start at the resting
# potential and the leak current does not push them away from it
net_opt.set('v', v_rest)
net_opt.set('Leak_eLeak', v_rest)



# Name the four cells of the chain
first = net_opt.cell(0)
second = net_opt.cell(1)
third = net_opt.cell(2)
fourth = net_opt.cell(3)

# Connect the cells with the ribbon synapse
connect(first, second, RibbonSynapse(solver = "newton"))
connect(second, third, RibbonSynapse(solver = "newton"))
connect(third, fourth, RibbonSynapse(solver = "newton"))

net_opt.cell([0,1]).set('RibbonSynapse_V_half', v_rest/2)
net_opt.cell([1,2]).set('RibbonSynapse_V_half', v_rest/2)
net_opt.cell([2,3]).set('RibbonSynapse_V_half', v_rest/2)
net_opt.cell('all').set('RibbonSynapse_gS', 0)

# Parameters
time_max = 1000
dt = 1
time_steps = int((time_max + dt)  / dt)

# Time vector
time_vec = jnp.arange(0.0, time_max, dt)

# Generate a base signal with a Gaussian function
mean = (time_max / 2) - 100  # Center of the Gaussian
std_dev = 100  # Standard deviation of the Gaussian
base_signal = (np.exp(-0.5 * ((time_vec - mean) / std_dev) ** 2)) * 3 +2

# Stack time and signal into a 2D array of shape (1000, 2): column 0 is
# the time vector, column 1 is the base signal
inputs = jnp.array([time_vec, base_signal]).T

# Remove any previous recordings and stimuli
net_opt.delete_recordings()
net_opt.delete_stimuli()

#net_opt.cell([0,1]).record("RibbonSynapse_exo")
#net_opt.cell([1,2]).record("RibbonSynapse_exo")
#net_opt.cell([2,3]).record("RibbonSynapse_exo")

net_opt.RibbonSynapse.edge(0).record("RibbonSynapse_exo", verbose = False)
net_opt.RibbonSynapse.edge(1).record("RibbonSynapse_exo", verbose = False)
net_opt.RibbonSynapse.edge(2).record("RibbonSynapse_exo", verbose = False)

net_opt.cell(0).record("v")
net_opt.cell(1).record("v")
net_opt.cell(2).record("v")
net_opt.cell(3).record("v")

data_clamps = None
# Clamp edge 0 (cell 0 -> cell 1) to the base signal (column 1 of inputs)
#data_clamps = net_opt.cell([0,1]).data_clamp("RibbonSynapse_exo", inputs[:,1], data_clamps = data_clamps)
data_clamps = net_opt.edge(0).data_clamp("RibbonSynapse_exo", inputs[:,1], data_clamps = data_clamps)

# Integrate the network
s = jx.integrate(net_opt,
                 data_clamps = data_clamps,
                 solver = "bwd_euler")

fig, ax = plt.subplots(s.shape[0], 1, figsize=(10, 20))
# Increase space between subplots
plt.subplots_adjust(hspace=1)

# Loop over the subplots and plot each synapse
for i in range(s.shape[0]):
    ax[i].plot(s[i, :])
    if i < 3:
        ax[i].set_title(f"Synapse {i + 1}")
    else:
        ax[i].set_title(f"Membrane Voltage Cell {i - 3}")
        

[Figure: recorded RibbonSynapse_exo traces of synapses 1 to 3 and membrane voltages of cells 0 to 3]
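To check the symptom without eyeballing the plots, each recorded trace can be compared against the clamp signal. The sketch below is a hypothetical helper (`clamped_rows` is not part of jaxley); it assumes the recordings come as one trace per row, like `s` above, and that any extra leading samples (for example an initial state, which is an assumption to verify) can be skipped with `offset`. With the reported bug, clamping one synapse should return two row indices instead of one.

```python
def clamped_rows(recordings, clamp, offset=0, tol=1e-9):
    """Return the indices of recorded traces that follow `clamp`.

    recordings: sequence of 1D traces, one recorded variable per row
        (assumed layout, e.g. np.asarray(s).tolist()).
    clamp: the 1D signal that was passed to data_clamp.
    offset: leading samples to skip in each trace before comparing.
    tol: absolute tolerance for the element-wise comparison.
    """
    matches = []
    for row, trace in enumerate(recordings):
        # Align the trace with the clamp signal over their overlap.
        aligned = list(trace)[offset:offset + len(clamp)]
        if not aligned:
            continue
        if all(abs(a - c) <= tol for a, c in zip(aligned, clamp)):
            matches.append(row)
    return matches


# Synthetic data mimicking the bug: synapse 2 follows the clamp as well.
clamp = [1.0] * 5
recordings = [
    [1.0] * 5,                  # synapse 1: clamped, as intended
    [1.0] * 5,                  # synapse 2: unexpectedly clamped too
    [0.2, 0.3, 0.4, 0.5, 0.6],  # synapse 3: evolves freely
]
print(clamped_rows(recordings, clamp))  # [0, 1]
```

For the example above, a call such as `clamped_rows(np.asarray(s).tolist(), np.asarray(inputs[:, 1]).tolist())` would list every recorded row that tracks the clamp; the `tol` may need loosening if the clamp and the recording use different float precisions.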

@deezer257 deezer257 reopened this Nov 7, 2024
@jnsbck
Contributor

jnsbck commented Nov 7, 2024

Hey, thanks for reporting this. Three things that would be really helpful to clarify first:

  1. Are you on the most recent commit?
  2. Did you check whether this also applies to other synapses, e.g. IonotropicSynapse? If it does not, I suspect this might be an issue in jaxley_mech.
  3. If neither 1. nor 2. explains it, could you come up with a more minimal example, i.e. remove everything that's not strictly necessary but still produces this behavior? That would be really helpful for debugging this.

@deezer257
Contributor Author

deezer257 commented Nov 7, 2024

  1. Yes: I cloned the main branch of the repository and installed it with pip in editable mode on 07.11.2024 at 16:00. The problem with the RibbonSynapse persists.

  2. I tried to clamp the parameters of the IonotropicSynapse, but the library doesn't accept them in data_clamp (see the error below).

  3. A shorter MWE for IonotropicSynapse:

import matplotlib.pyplot as plt
import jaxley as jx
import jax
from jaxley.synapses import IonotropicSynapse
from jaxley.connect import connect
from jaxley_mech.channels.hodgkin52 import Leak, Na, K
import jax.numpy as jnp
import numpy as np

# Create a new dummy network
comp_opt = jx.Compartment()
branch_opt = jx.Branch(comp_opt, nseg=1)
cell_opt = jx.Cell([branch_opt], [-1])
net_opt = jx.Network(cells=[cell_opt] * 4)

# Insert the leak, sodium and potassium channels into all cells
net_opt.insert(Leak())
net_opt.insert(Na())
net_opt.insert(K())

# Connect the cells with the synapses
connect(net_opt.cell(0), net_opt.cell(1), IonotropicSynapse())
connect(net_opt.cell(1), net_opt.cell(2), IonotropicSynapse())
connect(net_opt.cell(2), net_opt.cell(3), IonotropicSynapse())

# Set the conductance of all synapses to zero
net_opt.cell('all').set('IonotropicSynapse_gS', 0)

inputs = jnp.ones(1000)

net_opt.delete_recordings()
net_opt.delete_stimuli()

# Record the conductance of the synapses
net_opt.cell([0,1]).record("IonotropicSynapse_gS")
net_opt.cell([1,2]).record("IonotropicSynapse_gS")
net_opt.cell([2,3]).record("IonotropicSynapse_gS")

# Clamp the conductance of the synapse between cell 0 and cell 1
data_clamps = None
data_clamps = net_opt.cell([0,1]).data_clamp("IonotropicSynapse_gS", inputs, data_clamps = data_clamps)

# Integrate the network
s = jx.integrate(net_opt, 
                 data_clamps = data_clamps,
                 solver = "bwd_euler")

fig, ax = plt.subplots(2, 1, figsize=(5, 5))
ax[0].plot(s[0, :])
ax[1].plot(s[1, :])

This always yields the following error, which doesn't make much sense to me, since I was able to set IonotropicSynapse_gS to 0 earlier:
[Screenshot of the error message]

I did exactly the same with the RibbonSynapse, and there data_clamp worked.
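One possible explanation, not confirmed in this thread, is that data_clamp only accepts state variables, whereas IonotropicSynapse_gS is a parameter: it can be changed with set, but it does not evolve in time. RibbonSynapse_exo, which could be clamped, would then presumably be a state of RibbonSynapse. The plain-Python toy below only models that distinction; `ToySynapse` and its `data_clamp` method are hypothetical and are not jaxley's implementation. If the assumption holds, clamping a state of IonotropicSynapse (to my knowledge `IonotropicSynapse_s`) instead of gS would be the way to test whether the extra clamped synapse also shows up with jaxley's built-in synapses.

```python
class ToySynapse:
    """Hypothetical stand-in for a synapse that keeps params and states apart."""

    def __init__(self, name, params, states):
        # Parameters: constant during a simulation, changed via something like `set`.
        self.params = {f"{name}_{key}": val for key, val in params.items()}
        # States: evolve over time; in this toy model only these can be clamped.
        self.states = {f"{name}_{key}": val for key, val in states.items()}

    def data_clamp(self, key, values):
        # Assumption being illustrated: a data clamp overrides the time course
        # of a state, so asking to clamp a parameter is rejected.
        if key not in self.states:
            reason = "is a parameter" if key in self.params else "is unknown"
            raise KeyError(f"{key} {reason}; only states can be clamped")
        return {key: list(values)}


syn = ToySynapse("IonotropicSynapse", params={"gS": 0.0}, states={"s": 0.2})
print(syn.data_clamp("IonotropicSynapse_s", [1.0, 1.0]))  # {'IonotropicSynapse_s': [1.0, 1.0]}
try:
    syn.data_clamp("IonotropicSynapse_gS", [1.0, 1.0])
except KeyError as err:
    print("rejected:", err)
```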

And this is the shorter MWE for the RibbonSynapse:
import matplotlib.pyplot as plt
import jaxley as jx
import jax
from jaxley_mech.synapses.ribbon import RibbonSynapse
from jaxley.connect import connect
from jaxley_mech.channels.hodgkin52 import Leak, Na, K
import jax.numpy as jnp
import numpy as np


from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "gpu")  # requires a GPU; use "cpu" otherwise

# Create a new dummy network
comp_opt = jx.Compartment()
branch_opt = jx.Branch(comp_opt, nseg=1)
cell_opt = jx.Cell([branch_opt], [-1])
net_opt = jx.Network(cells=[cell_opt] * 4)


# Insert the leak, sodium and potassium channels into all cells
net_opt.insert(Leak())
net_opt.insert(Na())
net_opt.insert(K())


# Connect the cells with the ribbon synapses
connect(net_opt.cell(0), net_opt.cell(1), RibbonSynapse(solver = "newton"))
connect(net_opt.cell(1), net_opt.cell(2), RibbonSynapse(solver = "newton"))
connect(net_opt.cell(2), net_opt.cell(3), RibbonSynapse(solver = "newton"))

# Set the conductance of all synapses to zero
net_opt.cell('all').set('RibbonSynapse_gS', 0)


inputs = jnp.ones(1000)

net_opt.delete_recordings()
net_opt.delete_stimuli()

# Record the RibbonSynapse_exo state of each synapse
net_opt.cell([0,1]).record("RibbonSynapse_exo")
net_opt.cell([1,2]).record("RibbonSynapse_exo")
net_opt.cell([2,3]).record("RibbonSynapse_exo")

# Clamp RibbonSynapse_exo of the synapse between cell 0 and cell 1
data_clamps = None
data_clamps = net_opt.cell([0,1]).data_clamp("RibbonSynapse_exo", inputs, data_clamps = data_clamps)


# Integrate the network
s = jx.integrate(net_opt,
                 data_clamps = data_clamps,
                 solver = "bwd_euler")

fig, ax = plt.subplots(2, 1, figsize=(5, 5))
ax[0].plot(s[0, :])
ax[1].plot(s[1, :])

The same problem occurred with the data-clamped synapse. The figure shows the recorded variables of the synapses.
[Figure: recorded RibbonSynapse_exo traces of the first two synapses]

@deezer257
Contributor Author

I think I found the origin of the problem, which is more of an inconsistency. With data_clamp, you have to use the index of the presynaptic cell to clamp the corresponding synapse, e.g. cell(pre_synaptic). However, synapse indexing was changed to something like cell([pre_synaptic_cell, post_synaptic_cell]), while data_clamp kept the old behavior, so this might have caused the problem.
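The suspected mix-up can be illustrated with a toy model. This is a hypothetical sketch of the hypothesis above, not jaxley's actual code (the real fix is tracked in #492). Edges are stored as (pre, post) cell pairs for the chain 0 -> 1 -> 2 -> 3. If a clamp collects the cells spanned by the selection and then treats each of them as the presynaptic cell of a synapse, selecting edge 0 (cells 0 and 1) clamps edges 0 and 1, which matches the observed behavior.

```python
# Toy chain from the examples: edge i connects cell i -> cell i + 1.
EDGES = [(0, 1), (1, 2), (2, 3)]


def cells_of_edges(edge_ids):
    """Cells touched by the selected edges (pre and post), as a sorted list."""
    return sorted({cell for e in edge_ids for cell in EDGES[e]})


def clamp_targets_suspected(edge_ids):
    # Hypothesized old behavior: each selected *cell* is treated as the
    # presynaptic cell of a synapse, so every edge leaving any of those
    # cells gets clamped.
    cells = set(cells_of_edges(edge_ids))
    return [e for e, (pre, _post) in enumerate(EDGES) if pre in cells]


def clamp_targets_intended(edge_ids):
    # Intended behavior: clamp exactly the selected edges.
    return sorted(set(edge_ids))


print(clamp_targets_suspected([0]))  # [0, 1]: edge 1 is clamped as well
print(clamp_targets_intended([0]))   # [0]
```

Under this toy model, clamping only the last edge (cells 2 and 3) would not leak, because cell 3 is not the presynaptic cell of any edge; trying that with the MWE would be a cheap way to probe the hypothesis.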

@jnsbck jnsbck linked a pull request Nov 8, 2024 that will close this issue
@jnsbck
Contributor

jnsbck commented Nov 8, 2024

Thanks for checking up on this, I found the likely culprit. Will fix asap. You can keep track of the progress in #492.
