A global state for Jaxley modules #476
Even better would be if the `net` could be used like this:

```python
import jax
from jax import grad
import jax.numpy as jnp
import jaxley as jx

net = jx.Network(...)
net.make_trainable("radius")

# First, we can run `net.set()` between `jx.integrate` calls and it
# will not recompile.
@jax.jit
def simulate(net):
    return jx.integrate(net)

# First run, needs compilation.
v1 = simulate(net)

# Now, modify the module, but we need no re-compilation!
net.set("HH_gNa", 0.2)
v2 = simulate(net)

# Second, there is no more need for `.data_set()` or `.data_stimulate()`!
# We can just `.set()` or `.stimulate()`.
def modified_loss(net, value):
    net.set("HH_gK", value)
    return loss_fn(net)

gradient_fn = grad(modified_loss, argnums=(0, 1))
grad_val = gradient_fn(net, 2.0)

# Importing functions from a Python module also becomes much easier.
# This did not work previously because the `loss_fn` would rely on a
# net being in global scope.
# E.g., in `myfile.py`:
def loss_fn(net):
    return jnp.sum(jx.integrate(net))

# ...and in the Jupyter notebook:
from myfile import loss_fn
loss = loss_fn(net)

# We also support `jx.rebuild(others)` if the `net` itself is still to be
# modified within the Python module. E.g., in `myfile.py`:
def modified_loss_in_module(net, value):
    net.set("HH_gK", value)
    return loss_fn(net)

# ...and the following in a Jupyter notebook:
from myfile import modified_loss_in_module
gradient_fn = grad(modified_loss_in_module, argnums=(0, 1))
grad_val = gradient_fn(net, 2.0)

# Finally, following inox (and more reminiscent of the current API), one
# can also split (or partition) the `net`.
# (`nn.Parameter` follows inox's convention, e.g. `import inox.nn as nn`.)
static, params, others = net.partition(nn.Parameter)

def loss_fn(params, others):
    net = static(params, others)
    return jnp.sum(jx.integrate(net))

# If one wanted to change parameters within the loss function in this
# interface, one would do:
static, params, others = net.partition(nn.Parameter)

def loss_fn(params, others, value):
    net = static(params, others)
    net.set("radius", value)
    return jnp.sum(jx.integrate(net))
```
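For the first part of the listing above (calling `.set()` between jitted calls without recompilation) to work, the module's parameter values would presumably have to be pytree leaves, so that `jax.jit` caches on the pytree structure rather than on the values. This is my assumption about the underlying mechanism, not something stated in the issue; the sketch below uses a toy `TinyModule` and a stand-in `simulate`, neither of which is a Jaxley API.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class TinyModule:
    """Toy stand-in for a Jaxley module; parameter values are pytree leaves."""

    def __init__(self, params):
        self.params = params  # dict of settable parameter values

    def set(self, key, value):
        self.params[key] = value

    def tree_flatten(self):
        # Leaves = parameter values (dynamic); aux data = the keys (static).
        keys = tuple(sorted(self.params))
        return tuple(self.params[k] for k in keys), keys

    @classmethod
    def tree_unflatten(cls, keys, leaves):
        return cls(dict(zip(keys, leaves)))


@jax.jit
def simulate(module):
    # Stand-in for `jx.integrate(module)`.
    return jnp.tanh(module.params["HH_gNa"]) * module.params["radius"]


net = TinyModule({"HH_gNa": 0.12, "radius": 1.0})
v1 = simulate(net)       # first call: traces and compiles
net.set("HH_gNa", 0.2)   # plain Python mutation outside of jit
v2 = simulate(net)       # same structure, new leaf values: no re-trace
```

Under this scheme, only a change to the set of keys (the static structure) would trigger re-tracing; changing leaf values would not.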
For this, we should consider relying on […]. A small benefit of relying on […]
Nice! I like the idea!
Just playing around with this idea for now:
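As a rough illustration of the partition variant with plain pytrees (a sketch only, not the prototype referenced in this comment; the dict-based `state`, `trainable`, and `combine` are made up for the example):

```python
import jax
import jax.numpy as jnp

# Split a dict-of-arrays "module state" into trainable params and everything
# else, then recombine inside the loss. No Jaxley or inox APIs are used here.
state = {"radius": jnp.array(1.0), "HH_gK": jnp.array(0.05), "dt": 0.025}
trainable = {"radius"}

params = {k: v for k, v in state.items() if k in trainable}
others = {k: v for k, v in state.items() if k not in trainable}


def combine(params, others):
    return {**others, **params}


def loss_fn(params, others):
    net_state = combine(params, others)
    # Stand-in for `jnp.sum(jx.integrate(net))`.
    return net_state["radius"] ** 2 * net_state["HH_gK"]


grads = jax.grad(loss_fn)(params, others)  # gradients only w.r.t. `params`
```

The point is the same as in the `static, params, others` example above: gradients flow only through the partition that is passed as the differentiated argument.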