diff --git a/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy_utils.py b/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy_utils.py index 9c12952a1..6bdcfb4ec 100644 --- a/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy_utils.py +++ b/vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy_utils.py @@ -115,7 +115,7 @@ class EagleStrategyUtils: _n_parameters: int = attr.field(init=False) _degrees_of_freedom: DefaultDict[vz.ParameterType, int] = attr.field( init=False, factory=lambda: collections.defaultdict(int)) - _original_metric_name: str = attr.field(init=False) + _original_metric_name: Optional[str] = attr.field(init=False) _goal: vz.ObjectiveMetricGoal = attr.field(init=False) def __attrs_post_init__(self): @@ -359,6 +359,8 @@ def get_metric(self, trial: vz.Trial) -> float: """Returns the trial metric.""" if trial.infeasible: return np.nan + if trial.final_measurement is None: + raise ValueError('Trial is not completed.') return trial.final_measurement.metrics[OBJECTIVE_NAME] # pytype: disable=bad-return-type def is_better_than( @@ -408,6 +410,8 @@ def standardize_trial_metric_name(self, trial: vz.Trial) -> vz.Trial: """Creates a new trial with canonical metric name.""" if trial.infeasible: return trial + if trial.final_measurement is None: + raise ValueError('Trial is not completed.') value = trial.final_measurement.metrics[self._original_metric_name].value new_trial = vz.Trial(parameters=trial.parameters, metadata=trial.metadata) new_trial.complete(