diff --git a/HISTORY.md b/HISTORY.md index e246a29ea..0a29d4b5f 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -6,6 +6,8 @@ #### API +- **Backwards-incompatible:** BanditScheduler: Add emitter_pool and active attr; + remove emitters attr ({pr}`494`) - Add DensityRanker for density descent search ({pr}`483`) - Add NoveltyRanker for novelty search ({pr}`477`) - Add proximity_archive_plot for visualizing ProximityArchive ({pr}`476`, diff --git a/ribs/schedulers/_bandit_scheduler.py b/ribs/schedulers/_bandit_scheduler.py index 51ba855b3..d2e938753 100644 --- a/ribs/schedulers/_bandit_scheduler.py +++ b/ribs/schedulers/_bandit_scheduler.py @@ -4,6 +4,7 @@ import numpy as np +from ribs._utils import readonly from ribs.schedulers._scheduler import Scheduler @@ -172,10 +173,16 @@ def archive(self): return self._archive @property - def emitters(self): - """list of ribs.archives.EmitterBase: Emitters for generating solutions - in this scheduler.""" - return self._active_arr + def emitter_pool(self): + """list of ribs.archives.EmitterBase: The pool of emitters available in + the scheduler.""" + return self._emitter_pool + + @property + def active(self): + """numpy.ndarray: Boolean array indicating which emitters in the + :attr:`emitter_pool` are currently active.""" + return readonly(self._active_arr.view()) @property def result_archive(self): diff --git a/tests/schedulers/scheduler_test.py b/tests/schedulers/scheduler_test.py index 5e190756f..29cf78415 100644 --- a/tests/schedulers/scheduler_test.py +++ b/tests/schedulers/scheduler_test.py @@ -34,6 +34,29 @@ def add_mode(request): return request.param +@pytest.mark.parametrize("scheduler_type", ["Scheduler", "BanditScheduler"]) +def test_attributes(scheduler_type): + archive = GridArchive(solution_dim=2, + dims=[100, 100], + ranges=[(-1, 1), (-1, 1)], + threshold_min=1.0, + learning_rate=1.0) + emitters = [GaussianEmitter(archive, sigma=1, x0=[0.0, 0.0], batch_size=4)] + + if scheduler_type == "Scheduler": + scheduler = Scheduler(archive, emitters) + + assert scheduler.archive == archive + assert scheduler.emitters == emitters + else: + scheduler = BanditScheduler(archive, emitters, 1) + + assert scheduler.archive == archive + assert scheduler.emitter_pool == emitters + assert len(scheduler.active) == len(scheduler.emitter_pool) + assert not np.any(scheduler.active) + + def test_init_fails_with_non_list(): archive = GridArchive(solution_dim=2, dims=[100, 100], @@ -361,11 +384,11 @@ def test_constant_active_emitters_bandit_scheduler(): for _ in range(num_loops): solutions = scheduler.ask() - assert scheduler.emitters.sum() == expected_active + assert scheduler.active.sum() == expected_active # Mock objective and measures for tell objective = rng.random(len(solutions)) measures = rng.random((len(solutions), 2)) scheduler.tell(objective, measures) - assert scheduler.emitters.sum() == expected_active + assert scheduler.active.sum() == expected_active