Skip to content

Commit

Permalink
Fix warning due to jax.config
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 6, 2023
1 parent 6574712 commit ebc8489
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 15 deletions.
6 changes: 3 additions & 3 deletions tests/neurax_vs_neuron/test_branch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from jax.config import config
import jax

config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")


import os
Expand Down
6 changes: 3 additions & 3 deletions tests/neurax_vs_neuron/test_cell.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from jax.config import config
import jax

config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")


import os
Expand Down
6 changes: 3 additions & 3 deletions tests/neurax_vs_neuron/test_comp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from jax.config import config
import jax

config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")


import os
Expand Down
6 changes: 3 additions & 3 deletions tests/test_cell_matches_branch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from jax.config import config
import jax

config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")


import jax.numpy as jnp
Expand Down
5 changes: 5 additions & 0 deletions tests/test_make_trainable.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

import jax.numpy as jnp

import neurax as nx
Expand Down
6 changes: 3 additions & 3 deletions tests/test_swc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from jax.config import config
import jax

config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

import os

Expand Down

0 comments on commit ebc8489

Please sign in to comment.