From 75158ad01d046a1c1e44e0d698c912d732865192 Mon Sep 17 00:00:00 2001 From: Sander Vandenhaute Date: Mon, 2 Dec 2024 17:11:32 +0100 Subject: [PATCH] make sure failed walkers / failed single point evaluations still show up as reset in walker table --- psiflow/learning.py | 8 ++++++-- psiflow/utils/apps.py | 8 ++++++++ tests/test_learning.py | 8 +------- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/psiflow/learning.py b/psiflow/learning.py index 40e1e9d..1a5b0b7 100644 --- a/psiflow/learning.py +++ b/psiflow/learning.py @@ -17,7 +17,7 @@ from psiflow.models import Model from psiflow.reference import Reference, evaluate from psiflow.sampling import SimulationOutput, Walker, sample -from psiflow.utils.apps import boolean_or, setup_logger, unpack_i +from psiflow.utils.apps import boolean_or, setup_logger, unpack_i, isnan logger = setup_logger(__name__) @@ -80,7 +80,11 @@ def evaluate_outputs( errors[i], np.array(error_thresholds_for_reset, dtype=float), ) - reset = boolean_or(error_discard, error_reset) + reset = boolean_or( + error_discard, + error_reset, + isnan(errors[i]), + ) _ = assign_identifier(state, identifier, error_discard) assigned = unpack_i(_, 0) diff --git a/psiflow/utils/apps.py b/psiflow/utils/apps.py index 2d40fc7..7bf807d 100644 --- a/psiflow/utils/apps.py +++ b/psiflow/utils/apps.py @@ -134,3 +134,11 @@ def _concatenate(*arrays: np.ndarray) -> np.ndarray: concatenate = python_app(_concatenate, executors=["default_threads"]) + + +@typeguard.typechecked +def _isnan(a: Union[float, np.ndarray]) -> bool: + return bool(np.any(np.isnan(a))) + + +isnan = python_app(_isnan, executors=['default_threads']) diff --git a/tests/test_learning.py b/tests/test_learning.py index ea9254e..b45d53b 100644 --- a/tests/test_learning.py +++ b/tests/test_learning.py @@ -158,8 +158,6 @@ def test_evaluate_outputs(dataset): outputs[3].state = new_nullstate() outputs[7].status = 2 # should be null state - # resets for 3 and 7 happen in sample() method, not in evaluate_outputs! - identifier = 3 identifier, data, resets = evaluate_outputs( outputs, @@ -185,11 +183,7 @@ def test_evaluate_outputs(dataset): error_thresholds_for_discard=[0.0, 0.0], metrics=Metrics(), ) - for i in range(10): - if i not in [3, 7]: - assert resets[i].result() # already reset - else: - assert not resets[i].result() + assert all([r.result() for r in resets]) def test_wandb():