Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Connectivity_matrix_connect update #489

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
173 changes: 105 additions & 68 deletions jaxley/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,19 @@ def sample_comp(cell_view: "View", num: int = 1, replace=True) -> "CompartmentVi
return np.random.choice(cell_view._comps_in_view, num, replace=replace)


def get_random_post_comps(post_cell_view: "View", num_post: int) -> "CompartmentView":
"""Sample global compartment indices from all postsynaptic cells."""
global_post_comp_indices = (
post_cell_view.nodes.groupby("global_cell_index")
.sample(num_post, replace=True)
.index.to_numpy()
)
global_post_comp_indices = global_post_comp_indices.reshape(
(-1, num_post), order="F"
).ravel()
return global_post_comp_indices


def connect(
pre: "View",
post: "View",
Expand All @@ -44,33 +57,40 @@ def fully_connect(
pre_cell_view: "View",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we rename fully_connect to fully_connect_cells? Might be less confusing, since otherwise connections are made comp 2 comp

post_cell_view: "View",
synapse_type: "Synapse",
random_post_comp: bool = False,
kyralianaka marked this conversation as resolved.
Show resolved Hide resolved
):
"""Appends multiple connections which build a fully connected layer.

Connections are from branch 0 location 0 to a randomly chosen branch and loc.
Connections are from branch 0 location 0 of the pre-synaptic cell to branch 0
location 0 of the post-synaptic cell unless random_post_comp=True.

Args:
pre_cell_view: View of the presynaptic cell.
post_cell_view: View of the postsynaptic cell.
synapse_type: The synapse to append.
random_post_comp: If True, randomly samples the postsynaptic compartments.
"""
# Get pre- and postsynaptic cell indices.
num_pre = len(pre_cell_view._cells_in_view)
num_post = len(post_cell_view._cells_in_view)

# Infer indices of (random) postsynaptic compartments.
global_post_indices = (
post_cell_view.nodes.groupby("global_cell_index")
.sample(num_pre, replace=True)
.index.to_numpy()
)
global_post_indices = global_post_indices.reshape((-1, num_pre), order="F").ravel()
post_rows = post_cell_view.nodes.loc[global_post_indices]

# Pre-synapse is at the zero-eth branch and zero-eth compartment.
pre_rows = pre_cell_view.scope("local").branch(0).comp(0).nodes.copy()
# Repeat rows `num_post` times. See SO 50788508.
pre_rows = pre_rows.loc[pre_rows.index.repeat(num_post)].reset_index(drop=True)
# Get a view of the zeroeth compartment of each cell as the pre compartments
pre_comps = pre_cell_view.scope("local").branch(0).comp(0).nodes.copy()
# Repeat rows `num_post` times
pre_rows = pre_comps.loc[pre_comps.index.repeat(num_post)].reset_index(drop=True)

if random_post_comp:
global_post_comp_indices = get_random_post_comps(post_cell_view, num_pre)
else:
# Post-synapse also at the zero-eth branch and zero-eth compartment
to_idx = np.tile(range(0, num_post), num_pre)
global_post_comp_indices = (
post_cell_view.nodes.groupby("global_cell_index").first()[
"global_comp_index"
]
).to_numpy()
global_post_comp_indices = global_post_comp_indices[to_idx]
post_rows = post_cell_view.nodes.loc[global_post_comp_indices]

pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)

Expand All @@ -80,45 +100,62 @@ def sparse_connect(
post_cell_view: "View",
synapse_type: "Synapse",
p: float,
random_post_comp: bool = False,
kyralianaka marked this conversation as resolved.
Show resolved Hide resolved
):
"""Appends multiple connections which build a sparse, randomly connected layer.

Connections are from branch 0 location 0 to a randomly chosen branch and loc.
Connections are from branch 0 location 0 of the pre-synaptic cell to branch 0
location 0 of the post-synaptic cell unless random_post_comp=True.

Args:
pre_cell_view: View of the presynaptic cell.
post_cell_view: View of the postsynaptic cell.
synapse_type: The synapse to append.
p: Probability of connection.
random_post_comp: If True, randomly samples the postsynaptic compartments.
"""
# Get pre- and postsynaptic cell indices.
pre_cell_inds = pre_cell_view._cells_in_view
post_cell_inds = post_cell_view._cells_in_view
num_pre = len(pre_cell_inds)
num_post = len(post_cell_inds)

num_connections = np.random.binomial(num_pre * num_post, p)
pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections)
post_syn_neurons = np.random.choice(post_cell_inds, size=num_connections)

# Sort the synapses only for convenience of inspecting `.edges`.
sorting = np.argsort(pre_syn_neurons)
pre_syn_neurons = pre_syn_neurons[sorting]
post_syn_neurons = post_syn_neurons[sorting]

# Post-synapse is a randomly chosen branch and compartment.
global_post_indices = [
sample_comp(post_cell_view.scope("global").cell(cell_idx))
for cell_idx in post_syn_neurons
]
global_post_indices = (
np.hstack(global_post_indices) if len(global_post_indices) > 1 else []
)
post_rows = post_cell_view.base.nodes.loc[global_post_indices]
num_pre = len(pre_cell_view._cells_in_view)
num_post = len(post_cell_view._cells_in_view)

# Pre-synapse is at the zero-eth branch and zero-eth compartment.
global_pre_indices = pre_cell_view.base._cumsum_ncomp_per_cell[pre_syn_neurons]
pre_rows = pre_cell_view.base.nodes.loc[global_pre_indices]
# Generate random cxns via Bernoulli trials (no duplicates), done in blocks of the
# connectivity matrix to save memory and time (smaller cut size saves memory,
# larger saves time)
cut_size = 100 # --> (100, 100) dim blocks
pre_inds, post_inds = [], []
for i in range((num_pre + cut_size - 1) // cut_size):
for j in range((num_post + cut_size - 1) // cut_size):
block = np.random.binomial(1, p, size=(cut_size, cut_size))
block_pre, block_post = np.where(block)
block_pre += i * cut_size # block inds --> full adj mat inds
block_post += j * cut_size # block inds --> full adj mat inds
pre_inds.append(block_pre)
post_inds.append(block_post)
pre_post_inds = np.stack(
(np.concatenate(pre_inds), np.concatenate(post_inds)), axis=1
)
# Filter out connections where either pre or post index is out of range
pre_post_inds = pre_post_inds[
(pre_post_inds[:, 0] < num_pre) & (pre_post_inds[:, 1] < num_post)
]
from_idx, to_idx = pre_post_inds[:, 0], pre_post_inds[:, 1]

# Pre-synapse at the zero-eth branch and zero-eth compartment
global_pre_comp_indices = (
pre_cell_view.nodes.groupby("global_cell_index").first()["global_comp_index"]
).to_numpy()
pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes

if random_post_comp:
global_post_comp_indices = get_random_post_comps(post_cell_view, num_pre)
else:
# Post-synapse also at the zero-eth branch and zero-eth compartment
global_post_comp_indices = (
post_cell_view.nodes.groupby("global_cell_index").first()[
"global_comp_index"
]
).to_numpy()
post_rows = post_cell_view.select(nodes=global_post_comp_indices[to_idx]).nodes

if len(pre_rows) > 0:
pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)
Expand All @@ -129,49 +166,49 @@ def connectivity_matrix_connect(
post_cell_view: "View",
synapse_type: "Synapse",
connectivity_matrix: np.ndarray[bool],
random_post_comp: bool = False,
kyralianaka marked this conversation as resolved.
Show resolved Hide resolved
):
"""Appends multiple connections which build a custom connected network.
"""Appends multiple connections according to a custom connectivity matrix.

Connects pre- and postsynaptic cells according to a custom connectivity matrix.
Entries > 0 in the matrix indicate a connection between the corresponding cells.
Connections are from branch 0 location 0 to a randomly chosen branch and loc.
Connections are from branch 0 location 0 of the pre-synaptic cell to branch 0
location 0 of the post-synaptic cell unless random_post_comp=True.

Args:
pre_cell_view: View of the presynaptic cell.
post_cell_view: View of the postsynaptic cell.
synapse_type: The synapse to append.
connectivity_matrix: A boolean matrix indicating the connections between cells.
random_post_comp: If True, randomly samples the postsynaptic compartments.
"""
# Get pre- and postsynaptic cell indices.
pre_cell_inds = pre_cell_view._cells_in_view
post_cell_inds = post_cell_view._cells_in_view
# setting scope ensure that this works indep of current scope
pre_nodes = pre_cell_view.scope("local").branch(0).comp(0).nodes
pre_nodes["index"] = pre_nodes.index
pre_cell_nodes = pre_nodes.set_index("global_cell_index")
# Get pre- and postsynaptic cell indices
num_pre = len(pre_cell_view._cells_in_view)
num_post = len(post_cell_view._cells_in_view)

assert connectivity_matrix.shape == (
len(pre_cell_inds),
len(post_cell_inds),
num_pre,
num_post,
), "Connectivity matrix must have shape (num_pre, num_post)."
assert connectivity_matrix.dtype == bool, "Connectivity matrix must be boolean."

# get connection pairs from connectivity matrix
# Get pre to post connection pairs from connectivity matrix
from_idx, to_idx = np.where(connectivity_matrix)
pre_cell_inds = pre_cell_inds[from_idx]
post_cell_inds = post_cell_inds[to_idx]

# Sample random postsynaptic compartments (global comp indices).
global_post_indices = np.hstack(
[
sample_comp(post_cell_view.scope("global").cell(cell_idx))
for cell_idx in post_cell_inds
]
)
post_rows = post_cell_view.nodes.loc[global_post_indices]

# Pre-synapse is at the zero-eth branch and zero-eth compartment.
global_pre_indices = pre_cell_nodes.loc[pre_cell_inds, "index"].to_numpy()
pre_rows = pre_cell_view.select(nodes=global_pre_indices).nodes
# Pre-synapse at the zero-eth branch and zero-eth compartment
global_pre_comp_indices = (
pre_cell_view.nodes.groupby("global_cell_index").first()["global_comp_index"]
).to_numpy()
pre_rows = pre_cell_view.select(nodes=global_pre_comp_indices[from_idx]).nodes

if random_post_comp:
global_post_comp_indices = get_random_post_comps(post_cell_view, num_pre)
else:
# Post-synapse also at the zero-eth branch and zero-eth compartment
global_post_comp_indices = (
post_cell_view.nodes.groupby("global_cell_index").first()[
"global_comp_index"
]
).to_numpy()
post_rows = post_cell_view.select(nodes=global_post_comp_indices[to_idx]).nodes

pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)
8 changes: 4 additions & 4 deletions tests/jaxley_identical/test_basic_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,13 +302,13 @@ def test_complex_net(voltage_solver, SimpleNet):
_ = np.random.seed(0)
pre = net.cell([0, 1, 2])
post = net.cell([3, 4, 5])
fully_connect(pre, post, IonotropicSynapse())
fully_connect(pre, post, TestSynapse())
fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True)
fully_connect(pre, post, TestSynapse(), random_post_comp=True)

pre = net.cell([3, 4, 5])
post = net.cell(6)
fully_connect(pre, post, IonotropicSynapse())
fully_connect(pre, post, TestSynapse())
fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True)
fully_connect(pre, post, TestSynapse(), random_post_comp=True)

area = 2 * pi * 10.0 * 1.0
point_process_to_dist_factor = 100_000.0 / area
Expand Down
8 changes: 4 additions & 4 deletions tests/jaxley_identical/test_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ def test_network_grad(SimpleNet):
_ = np.random.seed(0)
pre = net.cell([0, 1, 2])
post = net.cell([3, 4, 5])
fully_connect(pre, post, IonotropicSynapse())
fully_connect(pre, post, TestSynapse())
fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True)
fully_connect(pre, post, TestSynapse(), random_post_comp=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a test for random_post_comp=False? At least for fully_connect, but also for the others? At least to check that they do not break?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of the tests in test_connection.py are with random_post_comp=False, but the tests that used connection.py everywhere else (test_grad.py and test_basic_modules.py) use random_post_comp=True with fully connect (so that the simulation results are the same as before). I could add tests for random_post_comp=True to test_connection.py -- would this then be enough coverage?


pre = net.cell([3, 4, 5])
post = net.cell(6)
fully_connect(pre, post, IonotropicSynapse())
fully_connect(pre, post, TestSynapse())
fully_connect(pre, post, IonotropicSynapse(), random_post_comp=True)
fully_connect(pre, post, TestSynapse(), random_post_comp=True)

area = 2 * pi * 10.0 * 1.0
point_process_to_dist_factor = 100_000.0 / area
Expand Down
Loading