Hi Kirk, thanks for reaching out! I modified your code a bit, and the following works for me:

```python
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad

import jaxley as jx

cell = jx.Cell()
cell.stimulate(jx.step_current(1.0, 1.0, 0.1, 0.025, 5.0))
cell.record()
cell.make_trainable("radius")
params = cell.get_parameters()

# Pre-compute the locations of the grid on which the KDE will be
# evaluated.
observed_v = jx.integrate(cell, params=params)
observed_dv = (observed_v[:, 2:] - observed_v[:, :-2]) / 2
xmin = jnp.squeeze(jnp.min(observed_v, axis=1))
xmax = jnp.squeeze(jnp.max(observed_v, axis=1))
ymin = jnp.squeeze(jnp.min(observed_dv, axis=1))
ymax = jnp.squeeze(jnp.max(observed_dv, axis=1))
X, Y = jnp.mgrid[xmin:xmax:100j, ymin:ymax:100j]
positions = jnp.vstack([X.ravel(), Y.ravel()])


def summary_stats(v):
    dv = (v[:, 2:] - v[:, :-2]) / 2
    v_dv = jnp.vstack([v[:, 1:-1], dv])
    kernel = jax.scipy.stats.gaussian_kde(v_dv)
    return jnp.asarray(kernel(positions))


observed_ss = summary_stats(observed_v)

# Modify the parameters such that the loss is not 0.0 (which is boring).
params[0]["radius"] = params[0]["radius"].at[0].set(2.0)


def simulate(params):
    return jx.integrate(cell, params=params)


def loss_from_v(v):
    ss = summary_stats(v)
    return jnp.sum(jnp.sqrt(jnp.abs(ss - observed_ss)))


def loss_fn(params):
    v = simulate(params)
    return loss_from_v(v)


gradient_fn = jit(value_and_grad(loss_fn))
gradient_fn(params)
```

What did I change? Methods like …

I hope this helps! Let me know if you have more questions or if things are not behaving as expected; user feedback is super important for us!
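The concretisation error in the original post comes from `jnp.mgrid`: it reads its slice bounds as concrete values when it builds the grid, and under `jit` the bounds `xmin`, `xmax`, `ymin`, `ymax` are tracers. Without `jit` they are ordinary arrays, which is also why the original code runs once `jit` is removed. The minimal sketch below reproduces this without jaxley, using a sine wave as a hypothetical stand-in for a voltage trace; `phase_plane` and `grid_from` are illustrative helpers, not part of jaxley or of the reply's code.

```python
import jax
import jax.numpy as jnp
from jax.errors import ConcretizationTypeError
from jax.scipy.stats import gaussian_kde


def phase_plane(v):
    # Stack (v, dv) into the (2, n) layout that gaussian_kde expects.
    dv = (v[2:] - v[:-2]) / 2
    return jnp.vstack([v[1:-1], dv])


def grid_from(data, n=20):
    # jnp.mgrid needs its slice bounds to be concrete, not traced.
    X, Y = jnp.mgrid[data[0].min():data[0].max():n * 1j,
                     data[1].min():data[1].max():n * 1j]
    return jnp.vstack([X.ravel(), Y.ravel()])


# Hypothetical stand-in for a recorded voltage trace.
t = jnp.linspace(0.0, 4.0 * jnp.pi, 200)
observed = phase_plane(jnp.sin(t))

# Building the grid inside jit fails: the bounds are tracers there.
try:
    jax.jit(grid_from)(observed)
    jit_failed = False
except ConcretizationTypeError:
    jit_failed = True

# The pattern used in the reply: build the grid once, eagerly, from the
# observed data, and let the jitted function close over the concrete result.
positions = grid_from(observed)


@jax.jit
def density(v):
    return gaussian_kde(phase_plane(v))(positions)


observed_density = density(jnp.sin(t))
other_density = density(0.8 * jnp.sin(t))
```

Because `positions` is computed eagerly, `density` can be jitted and re-run on new traces, and the grid stays fixed to the range of the observed data.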
I'm trying to test jaxley with a phase-plane density loss function but am getting JAX tracer concretisation errors. I thought I'd followed JAX's requirements, and everything uses `jnp`-specific functions. The initial run that produces the target trace works fine, and subsequent (looped) runs work if I remove the `jit` wrapper from this function, but then it's very slow.
PHASE PLANE DENSITY LOSS FUNCTION: replaces `summary_stats()` and `loss_from_v()` in 01 synthetic (see below for the full code to reproduce):
```python
def summary_stats(v):
    dv = (v[:,2:] - v[:,:-2])/2
    xmin = jnp.squeeze(jnp.min(v, axis=1))
    xmax = jnp.squeeze(jnp.max(v, axis=1))
    ymin = jnp.squeeze(jnp.min(dv, axis=1))
    ymax = jnp.squeeze(jnp.max(dv, axis=1))
    X, Y = jnp.mgrid[xmin:xmax:100j, ymin:ymax:100j]  # DEBUG: this line causes Concretisation error
    positions = jnp.vstack([X.ravel(), Y.ravel()])
    v_dv = jnp.vstack([jnp.asarray(v[:,1:-1]), jnp.asarray(dv)])
    kernel = jax.scipy.stats.gaussian_kde(v_dv)
    return jnp.asarray(kernel(positions))


def loss_from_v(v):
    ss = summary_stats(v)
    return jnp.sum(jnp.sqrt(jnp.abs((ss - x_o_ss))), axis=0)
```
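If the grid really should adapt to each simulated trace, one possible jit-compatible variant (an untested assumption, not what the reply does) swaps `jnp.mgrid` for `jnp.linspace` plus `jnp.meshgrid`: `jnp.linspace` accepts traced start/stop values as long as the number of points is static. The sketch assumes a single recording, so `v` has shape `(1, n_time)`; `n_grid` and `summary_stats_jit` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import gaussian_kde


def summary_stats(v, n_grid=100):
    # Assumes v has shape (1, n_time), i.e. a single recorded trace.
    dv = (v[:, 2:] - v[:, :-2]) / 2
    xmin, xmax = jnp.min(v), jnp.max(v)
    ymin, ymax = jnp.min(dv), jnp.max(dv)
    # linspace works with traced bounds when n_grid is static, so this
    # replaces jnp.mgrid[xmin:xmax:100j, ymin:ymax:100j].
    xs = jnp.linspace(xmin, xmax, n_grid)
    ys = jnp.linspace(ymin, ymax, n_grid)
    X, Y = jnp.meshgrid(xs, ys, indexing="ij")  # same layout as jnp.mgrid
    positions = jnp.vstack([X.ravel(), Y.ravel()])
    v_dv = jnp.vstack([v[:, 1:-1], dv])
    return gaussian_kde(v_dv)(positions)


# n_grid sets array shapes, so it has to be a static argument under jit.
summary_stats_jit = jax.jit(summary_stats, static_argnames="n_grid")
```

One caveat: with a per-trace grid, entry `i` of `ss` and of `x_o_ss` correspond to different `(v, dv)` locations, so the element-wise difference in `loss_from_v` compares densities at different points. A fixed grid built from the observed trace, as in the reply, avoids this.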
FULL CODE TO REPRODUCE (just 01 synthetic from the original jaxley paper, with `summary_stats()` and `loss_from_v()` replaced by the code above). Unrelated code is commented out but retained for clarity.
01_synthetic(testing with phase plane density loss).txt