A global state for Jaxley modules #476

Open
michaeldeistler opened this issue Oct 29, 2024 · 3 comments
michaeldeistler commented Oct 29, 2024

Just playing around with this idea for now:

net = jx.Network(...)
net.make_trainable("radius")


# First, we can call `net.set()` between `jx.integrate` calls without
# triggering recompilation.
@jax.jit
def simulate(trainables, others):
    return jx.integrate(trainables, others)

# First run, needs compilation.
v1 = simulate(net.trainables, net.others)

# Now modify the module; no re-compilation is needed!
net.set("HH_gNa", 0.2)
v2 = simulate(net.trainables, net.others)


# Second, there is no more need for `.data_set()` or `.data_stimulate()`!
# We can just `.set()` or `.stimulate()`.
def modified_loss(value, trainables, others):
    net.set("HH_gK", value)
    return loss_fn(trainables, net.others)

gradient_fn = grad(modified_loss, argnums=(0, 1))
grad_val = gradient_fn(2.0, net.trainables, net.others)


# Importing the functions from a Python module also becomes much easier.
# This did not work previously because the `loss_fn` would rely on a
# net being in global scope.
# E.g., in `myfile.py`:
def loss_fn(trainables, others):
    return jnp.sum(jx.integrate(trainables, others))

# ...and in the jupyter notebook:
from myfile import loss_fn
loss = loss_fn(net.trainables, net.others)


# We also support jx.rebuild(others) if the `net` itself is still to be
# modified within the Python module. E.g., in `myfile.py`:
def modified_loss_in_module(value, trainables, others):
    net = jx.rebuild(others)  # Can also be achieved by reading a pickled net.
    net.set("HH_gK", value)
    return loss_fn(trainables, net.others)

# ...and the following in a jupyter notebook.
from myfile import modified_loss_in_module
gradient_fn = grad(modified_loss_in_module, argnums=(0, 1))
grad_val = gradient_fn(2.0, net.trainables, net.others)
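
A minimal sketch (plain JAX, not the proposed Jaxley API; the names `f`, `trainables`, and `others` here are illustrative) of the mechanism the above relies on: as long as only the values of the arrays passed to a jitted function change, and not their shapes or dtypes, the function is not retraced or recompiled.

import jax
import jax.numpy as jnp

@jax.jit
def f(trainables, others):
    # This Python-level print only executes while JAX traces the function,
    # i.e., whenever it is (re)compiled.
    print("tracing")
    return jnp.sum(trainables["radius"] * others["HH_gNa"])

trainables = {"radius": jnp.array([1.0, 2.0])}
others = {"HH_gNa": jnp.array([0.12, 0.12])}

v1 = f(trainables, others)                 # Prints "tracing": first call compiles.
others["HH_gNa"] = jnp.array([0.2, 0.2])   # Analogous to `net.set("HH_gNa", 0.2)`.
v2 = f(trainables, others)                 # No print: same structure, no recompilation.
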
michaeldeistler commented Oct 30, 2024

Even better would be if the net itself were a pytree that can be passed around:

net = jx.Network(...)
net.make_trainable("radius")


# First, we can call `net.set()` between `jx.integrate` calls without
# triggering recompilation.
@jax.jit
def simulate(net):
    return jx.integrate(net)

# First run, needs compilation.
v1 = simulate(net)

# Now modify the module; no re-compilation is needed!
net.set("HH_gNa", 0.2)
v2 = simulate(net)


# Second, there is no more need for `.data_set()` or `.data_stimulate()`!
# We can just `.set()` or `.stimulate()`.
def modified_loss(net, value):
    net.set("HH_gK", value)
    return loss_fn(net)

gradient_fn = grad(modified_loss, argnums=(0, 1))
grad_val = gradient_fn(net, 2.0)


# Importing the functions from a Python module also becomes much easier.
# This did not work previously because the `loss_fn` would rely on a
# net being in global scope.
# E.g., in `myfile.py`:
def loss_fn(net):
    return jnp.sum(jx.integrate(net))

# ...and in the jupyter notebook:
from myfile import loss_fn
loss = loss_fn(net)


# There is also no more need for `jx.rebuild(others)` when the `net` itself is
# to be modified within a Python module; the `net` can just be passed in.
# E.g., in `myfile.py`:
def modified_loss_in_module(net, value):
    net.set("HH_gK", value)
    return loss_fn(net)

# ...and the following in a jupyter notebook.
from myfile import modified_loss_in_module
gradient_fn = grad(modified_loss_in_module, argnums=(0, 1))
grad_val = gradient_fn(net, 2.0)


# Finally, following inox (and more reminiscent of the current API), one
# can also split (or partition) the `net`:
static, params, others = net.partition(nn.Parameter)

def loss_fn(params, others):
    net = static(params, others)
    return jnp.sum(jx.integrate(net))


# If one wanted to change parameters within the loss function in this
# interface, one would do:
static, params, others = net.partition(nn.Parameter)

def loss_fn(params, others, value):
    net = static(params, others)
    net.set("radius", value)
    return jnp.sum(jx.integrate(net))
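
One plain-JAX way to make such a module a pytree would be `jax.tree_util.register_pytree_node_class`; the class, its fields, and the `simulate` body below are hypothetical placeholders, not the actual Jaxley implementation.

import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Net:
    def __init__(self, params, topology):
        self.params = params      # Dict of arrays: the traced pytree leaves.
        self.topology = topology  # Static (hashable) auxiliary data.

    def tree_flatten(self):
        return (self.params,), self.topology

    @classmethod
    def tree_unflatten(cls, topology, children):
        return cls(children[0], topology)

@jax.jit
def simulate(net):
    # `net` is rebuilt from its leaves inside the trace; changing leaf values
    # between calls does not trigger recompilation.
    return jnp.sum(net.params["HH_gNa"])

net = Net({"HH_gNa": jnp.array([0.12])}, topology="single-compartment")
v = simulate(net)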

michaeldeistler commented Oct 30, 2024

For this, we should consider relying on inox or on flax nnx. IMO they are, at this point, quite similar in their API. Obviously, flax nnx is much larger and will surely be maintained, whereas inox is minimal and could allow us to actually get into the weeds.

A small benefit of relying on flax nnx would be that users of flax nnx will already be familiar with part of the API of Jaxley.
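
As a rough illustration of the flax nnx side (assuming the current nnx API; `nnx.Linear` merely stands in for whatever state a Jaxley module would hold), its split/merge pattern is close to the partition idea sketched above:

import jax
import jax.numpy as jnp
from flax import nnx

class Model(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(2, 1, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

model = Model(nnx.Rngs(0))

# Split the module into a static graph definition and a pytree of state.
graphdef, state = nnx.split(model)

def loss_fn(state, x):
    model = nnx.merge(graphdef, state)  # Rebuild the module inside the traced fn.
    return jnp.sum(model(x) ** 2)

grads = jax.grad(loss_fn)(state, jnp.ones((1, 2)))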

jnsbck commented Nov 4, 2024

Nice! I like the idea!
Lemme know if you want to brainstorm specifics! :)
