Skip to content

Commit

Permalink
Small refactor of plotting (#539)
Browse files Browse the repository at this point in the history
* enh: refactor of network plotting and plotting kwargs

* fix: make tests pass

* fix: rm nx dep

* enh: ref single point computation

* fix: rm networkx dep, update changelog

* fix: fix typo
  • Loading branch information
jnsbck authored Dec 3, 2024
1 parent dddd1b7 commit a22ee42
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 175 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@ net.record("i_IonotropicSynapse")
- Regression tests can be done locally by running `NEW_BASELINE=1 pytest -m regression` i.e. on `main` and then `pytest -m regression` on `feature`, which will produce a test report (printed to the console and saved to .txt).
- If a PR introduces new baseline tests or reduces runtimes, then a new baseline can be created by commenting "/update_regression_baselines" on the PR.

- refactor plotting (#539, @jnsbck).
- rm networkx dependency
- add `Network.arrange_in_layers`
- disentangle moving of cells and plotting in `Network.vis`. To get the same as `net.vis(layers=[3,3])`, one now has to do:
```python
net.arrange_in_layers([3,3])
net.vis()
```

# 0.5.0

### API changes
Expand Down
23 changes: 14 additions & 9 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2086,10 +2086,10 @@ def _get_external_input(
def vis(
self,
ax: Optional[Axes] = None,
col: str = "k",
color: str = "k",
dims: Tuple[int] = (0, 1),
type: str = "line",
morph_plot_kwargs: Dict = {},
**kwargs,
) -> Axes:
"""Visualize the module.
Expand All @@ -2102,24 +2102,29 @@ def vis(
- `scatter`: All traced points, are plotted as scatter points.
- `comp`: Plots the compartmentalized morphology, including radius
and shape. (shows the true compartment lengths per default, but this can
be changed via the `morph_plot_kwargs`, for details see
be changed via the `kwargs`, for details see
`jaxley.utils.plot_utils.plot_comps`).
- `morph`: Reconstructs the 3D shape of the traced morphology. For details see
`jaxley.utils.plot_utils.plot_morph`. Warning: For 3D plots and morphologies
with many traced points this can be very slow.
Args:
ax: An axis into which to plot.
col: The color for all branches.
color: The color for all branches.
dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of
two of them.
type: The type of plot. One of ["line", "scatter", "comp", "morph"].
morph_plot_kwargs: Keyword arguments passed to the plotting function.
kwargs: Keyword arguments passed to the plotting function.
"""
res = 100 if "resolution" not in kwargs else kwargs.pop("resolution")
if "comp" in type.lower():
return plot_comps(self, dims=dims, ax=ax, col=col, **morph_plot_kwargs)
return plot_comps(
self, dims=dims, ax=ax, color=color, resolution=res, **kwargs
)
if "morph" in type.lower():
return plot_morph(self, dims=dims, ax=ax, col=col, **morph_plot_kwargs)
return plot_morph(
self, dims=dims, ax=ax, color=color, resolution=res, **kwargs
)

assert not np.any(
[np.isnan(xyzr[:, dims]).all() for xyzr in self.xyzr]
Expand All @@ -2128,10 +2133,10 @@ def vis(
ax = plot_graph(
self.xyzr,
dims=dims,
col=col,
color=color,
ax=ax,
type=type,
morph_plot_kwargs=morph_plot_kwargs,
**kwargs,
)

return ax
Expand Down
217 changes: 87 additions & 130 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import itertools
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union
from warnings import warn

import jax.numpy as jnp
import networkx as nx
import numpy as np
import pandas as pd
from jax import vmap
from matplotlib import pyplot as plt
from matplotlib.axes import Axes

from jaxley.modules.base import Module
Expand Down Expand Up @@ -383,171 +384,127 @@ def _synapse_currents(

return states, (syn_voltage_terms, syn_constant_terms)

def arrange_in_layers(
self,
layers: List[int],
within_layer_offset: float = 500.0,
between_layer_offset: float = 1500.0,
vertical_layers: bool = False,
):
"""Arrange the cells in the network to form layers.
Moves the cells in the network to arrange them into layers.
Args:
layers: List of integers specifying the number of cells in each layer.
within_layer_offset: Offset between cells within the same layer.
between_layer_offset: Offset between layers.
vertical_layers: If True, layers are arranged vertically.
"""
assert (
np.sum(layers) == self.shape[0]
), "The number of cells in the layers must match the number of cells in the network."
cells_in_layers = [
list(range(sum(layers[:i]), sum(layers[: i + 1])))
for i in range(len(layers))
]

for l, cell_inds in enumerate(cells_in_layers):
layer = self.cell(cell_inds)
for i, cell in enumerate(layer.cells):
if vertical_layers:
x_offset = (i - (len(cell_inds) - 1) / 2) * within_layer_offset
y_offset = (len(layers) - 1 - l) * between_layer_offset
else:
x_offset = l * between_layer_offset
y_offset = (i - (len(cell_inds) - 1) / 2) * within_layer_offset

cell.move_to(x=x_offset, y=y_offset, z=0)

def vis(
self,
detail: str = "full",
ax: Optional[Axes] = None,
col: str = "k",
synapse_col: str = "b",
color: str = "k",
synapse_color: str = "b",
dims: Tuple[int] = (0, 1),
type: str = "line",
layers: Optional[List] = None,
morph_plot_kwargs: Dict = {},
cell_plot_kwargs: Dict = {},
synapse_plot_kwargs: Dict = {},
synapse_scatter_kwargs: Dict = {},
networkx_options: Dict = {},
layer_kwargs: Dict = {},
) -> Axes:
"""Visualize the module.
Args:
detail: Either of [point, full]. `point` visualizes every neuron in the
network as a dot (and it uses `networkx` to obtain cell positions).
network as a dot.
`full` plots the full morphology of every neuron. It requires that
`compute_xyz()` has been run and allows for indivual neurons to be
moved with `.move()`.
col: The color in which cells are plotted. Only takes effect if
color: The color in which cells are plotted. Only takes effect if
`detail='full'`.
type: Either `line` or `scatter`. Only takes effect if `detail='full'`.
synapse_col: The color in which synapses are plotted. Only takes effect if
synapse_color: The color in which synapses are plotted. Only takes effect if
`detail='full'`.
dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of
two of them.
layers: Allows to plot the network in layers. Should provide the number of
neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input
neurons, 10 hidden layer neurons, and 1 output neuron.
morph_plot_kwargs: Keyword arguments passed to the plotting function for
cell_plot_kwargs: Keyword arguments passed to the plotting function for
cell morphologies. Only takes effect for `detail='full'`.
synapse_plot_kwargs: Keyword arguments passed to the plotting function for
syanpses. Only takes effect for `detail='full'`.
synapse_scatter_kwargs: Keyword arguments passed to the scatter function
for the end point of synapses. Only takes effect for `detail='full'`.
networkx_options: Options passed to `networkx.draw()`. Only takes effect if
`detail='point'`.
layer_kwargs: Only used if `layers` is specified and if `detail='full'`.
Can have the following entries: `within_layer_offset` (float),
`between_layer_offset` (float), `vertical_layers` (bool).
"""
if detail == "point":
graph = self._build_graph(layers)
xyz0 = self.cell(0).xyzr[0][:, :3]
same_xyz = np.all([np.all(xyz0 == cell.xyzr[0][:, :3]) for cell in self.cells])
if same_xyz:
warn(
"Same coordinates for all cells. Consider using `move`, `move_to` or `arrange_in_layers` to move them."
)

if ax is None:
fig = plt.figure(figsize=(3, 3))
ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection="3d")

if layers is not None:
pos = nx.multipartite_layout(graph, subset_key="layer")
nx.draw(graph, pos, with_labels=True, **networkx_options)
else:
nx.draw(graph, with_labels=True, **networkx_options)
# detail="point" -> pos taken to be the mean of all traced points on the cell.
cell_to_point_xyz = lambda cell: np.mean(np.vstack(cell.xyzr)[:, :3], axis=0)

dims_np = np.asarray(dims)
if detail == "point":
for cell in self.cells:
pos = cell_to_point_xyz(cell)[dims_np]
ax.scatter(*pos, color=color, **cell_plot_kwargs)
elif detail == "full":
if layers is not None:
# Assemble cells in the network into layers.
global_counter = 0
layers_config = {
"within_layer_offset": 500.0,
"between_layer_offset": 1500.0,
"vertical_layers": False,
}
layers_config.update(layer_kwargs)
for layer_ind, num_in_layer in enumerate(layers):
for ind_within_layer in range(num_in_layer):
if layers_config["vertical_layers"]:
x_offset = (
ind_within_layer - (num_in_layer - 1) / 2
) * layers_config["within_layer_offset"]
y_offset = (len(layers) - 1 - layer_ind) * layers_config[
"between_layer_offset"
]
else:
x_offset = layer_ind * layers_config["between_layer_offset"]
y_offset = (
ind_within_layer - (num_in_layer - 1) / 2
) * layers_config["within_layer_offset"]

self.cell(global_counter).move_to(x=x_offset, y=y_offset, z=0)
global_counter += 1
ax = super().vis(
dims=dims,
col=col,
ax=ax,
type=type,
morph_plot_kwargs=morph_plot_kwargs,
dims=dims, color=color, ax=ax, type=type, **cell_plot_kwargs
)
else:
raise ValueError("detail must be in {full, point}.")

pre_locs = self.edges["pre_locs"].to_numpy()
post_locs = self.edges["post_locs"].to_numpy()
pre_comp = self.edges["pre_global_comp_index"].to_numpy()
nodes = self.nodes.set_index("global_comp_index")
pre_branch = nodes.loc[pre_comp, "global_branch_index"].to_numpy()
post_comp = self.edges["post_global_comp_index"].to_numpy()
post_branch = nodes.loc[post_comp, "global_branch_index"].to_numpy()

dims_np = np.asarray(dims)

for pre_loc, post_loc, pre_b, post_b in zip(
pre_locs, post_locs, pre_branch, post_branch
):
pre_coord = self.xyzr[pre_b]
if len(pre_coord) == 2:
# If only start and end point of a branch are traced, perform a
# linear interpolation to get the synpase location.
pre_coord = pre_coord[0] + (pre_coord[1] - pre_coord[0]) * pre_loc
else:
# If densely traced, use intermediate trace values for synapse loc.
middle_ind = int((len(pre_coord) - 1) * pre_loc)
pre_coord = pre_coord[middle_ind]

post_coord = self.xyzr[post_b]
if len(post_coord) == 2:
nodes = self.nodes.set_index("global_comp_index")
for i, edge in self.edges.iterrows():
prepost_locs = []
for prepost in ["pre", "post"]:
loc, comp = edge[[prepost + "_locs", prepost + "_global_comp_index"]]
branch = nodes.loc[comp, "global_branch_index"]
cell = nodes.loc[comp, "global_cell_index"]
branch_xyz = self.xyzr[branch]

xyz_loc = branch_xyz
if detail == "point":
xyz_loc = cell_to_point_xyz(self.cell(cell))
elif len(branch_xyz) == 2:
# If only start and end point of a branch are traced, perform a
# linear interpolation to get the synpase location.
post_coord = (
post_coord[0] + (post_coord[1] - post_coord[0]) * post_loc
)
xyz_loc = branch_xyz[0] + (branch_xyz[1] - branch_xyz[0]) * loc
else:
# If densely traced, use intermediate trace values for synapse loc.
middle_ind = int((len(post_coord) - 1) * post_loc)
post_coord = post_coord[middle_ind]

coords = np.stack([pre_coord[dims_np], post_coord[dims_np]]).T
ax.plot(
coords[0],
coords[1],
c=synapse_col,
**synapse_plot_kwargs,
)
ax.scatter(
post_coord[dims_np[0]],
post_coord[dims_np[1]],
c=synapse_col,
**synapse_scatter_kwargs,
)
else:
raise ValueError("detail must be in {full, point}.")
middle_ind = int((len(branch_xyz) - 1) * loc)
xyz_loc = xyz_loc[middle_ind]

return ax

def _build_graph(self, layers: Optional[List] = None, **options):
graph = nx.DiGraph()

def build_extents(*subset_sizes):
return nx.utils.pairwise(itertools.accumulate((0,) + subset_sizes))

if layers is not None:
extents = build_extents(*layers)
layers = [range(start, end) for start, end in extents]
for i, layer in enumerate(layers):
graph.add_nodes_from(layer, layer=i)
else:
graph.add_nodes_from(range(len(self._cells_in_view)))

pre_comp = self.edges["pre_global_comp_index"].to_numpy()
nodes = self.nodes.set_index("global_comp_index")
pre_cell = nodes.loc[pre_comp, "global_cell_index"].to_numpy()
post_comp = self.edges["post_global_comp_index"].to_numpy()
post_cell = nodes.loc[post_comp, "global_cell_index"].to_numpy()
prepost_locs.append(xyz_loc)
prepost_locs = np.stack(prepost_locs).T

inds = np.stack([pre_cell, post_cell]).T
graph.add_edges_from(inds)
ax.plot(*prepost_locs[dims_np], color=synapse_color, **synapse_plot_kwargs)

return graph
return ax

def _infer_synapse_type_ind(self, synapse_name):
syn_names = self.base.synapse_names
Expand Down
Loading

0 comments on commit a22ee42

Please sign in to comment.