make sure failed walkers / failed single point evaluations still show up as reset in walker table
svandenhaute committed Dec 2, 2024
1 parent ee8d5b5 commit 75158ad
Showing 3 changed files with 15 additions and 9 deletions.
8 changes: 6 additions & 2 deletions psiflow/learning.py
@@ -17,7 +17,7 @@
 from psiflow.models import Model
 from psiflow.reference import Reference, evaluate
 from psiflow.sampling import SimulationOutput, Walker, sample
-from psiflow.utils.apps import boolean_or, setup_logger, unpack_i
+from psiflow.utils.apps import boolean_or, setup_logger, unpack_i, isnan
 
 logger = setup_logger(__name__)
 
@@ -80,7 +80,11 @@ def evaluate_outputs(
             errors[i],
             np.array(error_thresholds_for_reset, dtype=float),
         )
-        reset = boolean_or(error_discard, error_reset)
+        reset = boolean_or(
+            error_discard,
+            error_reset,
+            isnan(errors[i]),
+        )
 
         _ = assign_identifier(state, identifier, error_discard)
         assigned = unpack_i(_, 0)
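
For context, the reset flag in evaluate_outputs now also fires when a walker's errors are NaN, which is presumably what a failed single point evaluation produces. A minimal plain-Python sketch of the combined condition, with boolean_or_plain and isnan_plain as hypothetical stand-ins for the Parsl apps used above:

```python
import numpy as np

# hypothetical stand-ins for the boolean_or and isnan Parsl apps
def boolean_or_plain(*flags: bool) -> bool:
    # reset if any of the conditions holds
    return any(bool(f) for f in flags)

def isnan_plain(errors) -> bool:
    # True if the errors contain any NaN (e.g. after a failed evaluation)
    return bool(np.any(np.isnan(errors)))

errors_failed = np.array([np.nan, np.nan])  # failed single point evaluation
errors_ok = np.array([0.002, 0.05])         # successful evaluation

print(boolean_or_plain(False, False, isnan_plain(errors_failed)))  # True -> reset
print(boolean_or_plain(False, False, isnan_plain(errors_ok)))      # False -> keep
```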
8 changes: 8 additions & 0 deletions psiflow/utils/apps.py
@@ -134,3 +134,11 @@ def _concatenate(*arrays: np.ndarray) -> np.ndarray:
 
 
 concatenate = python_app(_concatenate, executors=["default_threads"])
+
+
+@typeguard.typechecked
+def _isnan(a: Union[float, np.ndarray]) -> bool:
+    return bool(np.any(np.isnan(a)))
+
+
+isnan = python_app(_isnan, executors=['default_threads'])
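
The new helper accepts either a scalar or an array and reports whether anything in it is NaN. A quick sketch of the expected behaviour, calling a plain copy of the function directly instead of going through the python_app wrapper:

```python
import numpy as np

# plain copy of the _isnan helper added above (without the Parsl wrapper)
def _isnan(a) -> bool:
    return bool(np.any(np.isnan(a)))

assert _isnan(float("nan")) is True             # scalar NaN
assert _isnan(np.array([0.1, np.nan])) is True  # any NaN element counts
assert _isnan(np.array([0.1, 0.2])) is False    # all finite
assert _isnan(0.0) is False
```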
8 changes: 1 addition & 7 deletions tests/test_learning.py
@@ -158,8 +158,6 @@ def test_evaluate_outputs(dataset):
     outputs[3].state = new_nullstate()
     outputs[7].status = 2  # should be null state
 
-    # resets for 3 and 7 happen in sample() method, not in evaluate_outputs!
-
     identifier = 3
     identifier, data, resets = evaluate_outputs(
         outputs,
@@ -185,11 +183,7 @@
         error_thresholds_for_discard=[0.0, 0.0],
         metrics=Metrics(),
     )
-    for i in range(10):
-        if i not in [3, 7]:
-            assert resets[i].result()  # already reset
-        else:
-            assert not resets[i].result()
+    assert all([r.result() for r in resets])
 
 
 def test_wandb():
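
With both discard thresholds at 0.0, every walker is now expected to come back flagged for reset: finite errors exceed the zero thresholds, and walkers 3 and 7 (null state / failed status) are caught by the new isnan check instead of slipping through. A hypothetical reconstruction of that expectation, assuming the threshold check is a simple elementwise exceedance and that failed evaluations yield NaN errors:

```python
import numpy as np

# hypothetical sketch: reset if any error is NaN or exceeds the (zero) threshold
def should_reset(errors, threshold=0.0) -> bool:
    errors = np.asarray(errors, dtype=float)
    return bool(np.any(np.isnan(errors)) or np.any(errors > threshold))

per_walker_errors = [[0.1, 0.2]] * 10
per_walker_errors[3] = [np.nan, np.nan]  # null state
per_walker_errors[7] = [np.nan, np.nan]  # failed single point evaluation
assert all(should_reset(e) for e in per_walker_errors)
```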
