Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Vectorize recording during integrate #561

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

ntolley
Copy link
Contributor

@ntolley ntolley commented Dec 20, 2024

While testing the recording from many many states, I noticed I was experiencing some serious performance hits during simulation. @jnsbck suggested that this may be due to a for loop over recordings that occurs during the integrate call. This is an attempt to vectorize that update only indexing each unique state once.

jaxley/integrate.py Outdated Show resolved Hide resolved
@ntolley
Copy link
Contributor Author

ntolley commented Dec 20, 2024

Here's some code to see how recordings impact speed: the comparison is pretty extreme but it gets the point across:

import jaxley as jx
import time
from jaxley.channels import Na
from jaxley.synapses import IonotropicSynapse
from jaxley.connect import fully_connect
from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

cell = jx.Cell()
cell.insert(Na())

sim_time_list, array_time_list = list(), list()

net = jx.Network([cell for _ in range(100)])
fully_connect(net, net, IonotropicSynapse())

params = net.get_parameters()

# Small number of recordings (4)
net.delete_recordings()
net.cell(range(2)).record('i_IonotropicSynapse')

start_time = time.time()
v = jx.integrate(net, params=params, t_max=10.0, delta_t=0.025)
simulate_time = time.time() - start_time
print(simulate_time)

# Huge number of recordings (10,000)
net.delete_recordings()
net.record('i_IonotropicSynapse')

start_time = time.time()
v = jx.integrate(net, params=params, t_max=10.0, delta_t=0.025)
simulate_time = time.time() - start_time
print(simulate_time)

@ntolley
Copy link
Contributor Author

ntolley commented Dec 20, 2024

And finally here's my timing results for the two different branches. There's still a slight slowdown with more recordings, but it's an order of magnitude faster when recording from 10,000 synapses so I'd say it's an improvement 😄

Here's the results for a 10 ms simulation

main record_speedup
4 recordings 0.51 s 0.50 s
10,000 recordings 10.9 s 1.15 s

@ntolley ntolley changed the title WIP: Vectorize recording during integrate [MRG] Vectorize recording during integrate Dec 20, 2024
@ntolley
Copy link
Contributor Author

ntolley commented Dec 20, 2024

@michaeldeistler @jnsbck since this is mainly a performance boost I'm not sure how it should be tested. I feel like testing the execution time directly could be very brittle for running tests locally. Unless you have some ideas, perhaps it isn't necessary?

@ntolley
Copy link
Contributor Author

ntolley commented Dec 20, 2024

Unfortunately that performance hit does scale with time, here's the results for a 100 ms simulation

main record_speedup
4 recordings 0.55 s 0.63 s
10,000 recordings 16.33 s 6.85 s

So there's still some optimizations to be made...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant