Skip to content

Commit

Permalink
Fix min window type check (#523)
Browse files Browse the repository at this point in the history
* fix: replace dict with Mapping

* fix: replace list with Sequence

* fix: add type hint

* fix: does not accept None
  • Loading branch information
phi-friday authored Oct 1, 2024
1 parent f63372e commit c48e566
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions bayes_opt/domain_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from warnings import warn

Expand All @@ -16,8 +17,6 @@
from bayes_opt.target_space import TargetSpace

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping, Sequence

from numpy.typing import NDArray

Float = np.floating[Any]
Expand Down Expand Up @@ -66,12 +65,14 @@ def __init__(
gamma_osc: float = 0.7,
gamma_pan: float = 1.0,
eta: float = 0.9,
minimum_window: NDArray[Float] | Sequence[float] | float | Mapping[str, float] | None = 0.0,
minimum_window: NDArray[Float] | Sequence[float] | Mapping[str, float] | float = 0.0,
) -> None:
self.gamma_osc = gamma_osc
self.gamma_pan = gamma_pan
self.eta = eta
if isinstance(minimum_window, dict):

self.minimum_window_value: NDArray[Float] | Sequence[float] | float
if isinstance(minimum_window, Mapping):
self.minimum_window_value = [
item[1] for item in sorted(minimum_window.items(), key=lambda x: x[0])
]
Expand All @@ -90,8 +91,9 @@ def initialize(self, target_space: TargetSpace) -> None:
self.original_bounds = np.copy(target_space.bounds)
self.bounds = [self.original_bounds]

self.minimum_window: NDArray[Float] | Sequence[float]
# Set the minimum window to an array of length bounds
if isinstance(self.minimum_window_value, (list, np.ndarray)):
if isinstance(self.minimum_window_value, (Sequence, np.ndarray)):
if len(self.minimum_window_value) != len(target_space.bounds):
error_msg = "Length of minimum_window must be the same as the number of parameters"
raise ValueError(error_msg)
Expand Down

0 comments on commit c48e566

Please sign in to comment.