Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add emitter_kwargs to optimizer ask and tell #159

Merged
merged 14 commits into from
Jul 7, 2021
55 changes: 47 additions & 8 deletions ribs/optimizers/_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Provides the Optimizer."""
import itertools

import numpy as np
from threadpoolctl import threadpool_limits

Expand Down Expand Up @@ -81,12 +83,30 @@ def emitters(self):
in this optimizer."""
return self._emitters

def ask(self):
@staticmethod
def _process_emitter_kwargs(emitter_kwargs):
"""Converts emitter_kwargs to an iterable so it can zip with the
emitters."""
if emitter_kwargs is None:
return itertools.repeat({})
if isinstance(emitter_kwargs, dict):
return itertools.repeat(emitter_kwargs)
return emitter_kwargs # Assume it is a list/iterable of dicts.

def ask(self, emitter_kwargs=None):
"""Generates a batch of solutions by calling ask() on all emitters.

.. note:: The order of the solutions returned from this method is
important, so do not rearrange them.

Args:
emitter_kwargs (dict or list of dict): kwargs to pass to the
emitters' :meth:`~ribs.emitters.EmitterBase.ask` method. If one
dict is passed in, its kwargs are passed to all the emitters. If
a list of dicts is passed in, each dict is passed to each
emitter (e.g. ``dict[0]`` goes to :attr:`emitters` [0]).
Emitters are in the same order as they were when the optimizer
was constructed.
Returns:
(n_solutions, dim) array: An array of n solutions to evaluate. Each
row contains a single solution.
Expand All @@ -99,19 +119,25 @@ def ask(self):
self._asked = True

self._solutions = []
emitter_kwargs = self._process_emitter_kwargs(emitter_kwargs)

# Limit OpenBLAS to single thread. This is typically faster than
# multithreading because our data is too small.
with threadpool_limits(limits=1, user_api="blas"):
for i, emitter in enumerate(self._emitters):
emitter_sols = emitter.ask()
for i, (emitter,
kwargs) in enumerate(zip(self._emitters, emitter_kwargs)):
emitter_sols = emitter.ask(**kwargs)
self._solutions.append(emitter_sols)
self._num_emitted[i] = len(emitter_sols)

self._solutions = np.concatenate(self._solutions, axis=0)
return self._solutions

def tell(self, objective_values, behavior_values, metadata=None):
def tell(self,
objective_values,
behavior_values,
metadata=None,
emitter_kwargs=None):
"""Returns info for solutions from :meth:`ask`.

.. note:: The objective values, behavior values, and metadata must be in
Expand All @@ -127,6 +153,13 @@ def tell(self, objective_values, behavior_values, metadata=None):
this array contains a solution's coordinates in behavior space.
metadata ((n_solutions,) array): Each entry of this array contains
an object holding metadata for a solution.
emitter_kwargs (dict or list of dict): kwargs to pass to the
emitters' :meth:`~ribs.emitters.EmitterBase.tell` method. If one
dict is passed in, its kwargs are passed to all the emitters. If
a list of dicts is passed in, each dict is passed to each
emitter (e.g. ``dict[0]`` goes to :attr:`emitters` [0]).
Emitters are in the same order as they were when the optimizer
was constructed.
Raises:
RuntimeError: This method is called without first calling
:meth:`ask`.
Expand All @@ -135,6 +168,7 @@ def tell(self, objective_values, behavior_values, metadata=None):
raise RuntimeError("tell() was called without calling ask().")
self._asked = False

emitter_kwargs = self._process_emitter_kwargs(emitter_kwargs)
objective_values = np.asarray(objective_values)
behavior_values = np.asarray(behavior_values)
metadata = (np.empty(len(self._solutions), dtype=object)
Expand All @@ -145,9 +179,14 @@ def tell(self, objective_values, behavior_values, metadata=None):
with threadpool_limits(limits=1, user_api="blas"):
# Keep track of pos because emitters may have different batch sizes.
pos = 0
for emitter, n in zip(self._emitters, self._num_emitted):
for emitter, n, kwargs in zip(self._emitters, self._num_emitted,
emitter_kwargs):
end = pos + n
emitter.tell(self._solutions[pos:end],
objective_values[pos:end],
behavior_values[pos:end], metadata[pos:end])
emitter.tell(
self._solutions[pos:end],
objective_values[pos:end],
behavior_values[pos:end],
metadata[pos:end],
**kwargs,
)
pos = end
89 changes: 88 additions & 1 deletion tests/optimizers/optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from ribs.archives import GridArchive
from ribs.emitters import GaussianEmitter
from ribs.emitters import EmitterBase, GaussianEmitter
from ribs.optimizers import Optimizer

# pylint: disable = redefined-outer-name
Expand Down Expand Up @@ -126,3 +126,90 @@ def test_tell_fails_when_ask_not_called(optimizer_fixture):
optimizer, *_ = optimizer_fixture
with pytest.raises(RuntimeError):
optimizer.tell(None, None)


@pytest.fixture
def kwargs_fixture():
"""Fixture for testing emitter_kwargs in the optimizer."""

class KwargsEmitter(EmitterBase):
"""Emitter which takes in kwargs in its ask() and tell() methods.

ask() and tell() simply set self.arg to be the value of arg.
"""

def __init__(self, archive):
EmitterBase.__init__(self, archive, 3, None)
self.arg = None

def ask(self, arg=None):
self.arg = arg
return []

def tell(self,
solutions,
objective_values,
behavior_values,
metadata=None,
arg=None):
self.arg = arg

archive = GridArchive([100, 100], [(-1, 1), (-1, 1)])
emitters = [KwargsEmitter(archive) for _ in range(3)]
return emitters, Optimizer(archive, emitters)


def test_ask_with_no_emitter_kwargs(kwargs_fixture):
emitters, optimizer = kwargs_fixture
optimizer.ask(emitter_kwargs=None)
for e in emitters:
assert e.arg is None


def test_ask_with_dict_emitter_kwargs(kwargs_fixture):
emitters, optimizer = kwargs_fixture
optimizer.ask(emitter_kwargs={"arg": 42})
for e in emitters:
assert e.arg == 42


def test_ask_with_list_emitter_kwargs(kwargs_fixture):
emitters, optimizer = kwargs_fixture
optimizer.ask(emitter_kwargs=[{"arg": 1}, {"arg": 2}, {"arg": 3}])
for e, val in zip(emitters, [1, 2, 3]):
assert e.arg == val


def test_tell_with_no_emitter_kwargs(kwargs_fixture):
emitters, optimizer = kwargs_fixture
optimizer.ask()
optimizer.tell([], [], [], emitter_kwargs=None)
for e in emitters:
assert e.arg is None


def test_tell_with_dict_emitter_kwargs(kwargs_fixture):
emitters, optimizer = kwargs_fixture
optimizer.ask()
optimizer.tell([], [], [], emitter_kwargs={"arg": 42})
for e in emitters:
assert e.arg == 42


def test_tell_with_list_emitter_kwargs(kwargs_fixture):
emitters, optimizer = kwargs_fixture
optimizer.ask()
optimizer.tell(
[],
[],
[],
emitter_kwargs=[{
"arg": 1
}, {
"arg": 2
}, {
"arg": 3
}],
)
for e, val in zip(emitters, [1, 2, 3]):
assert e.arg == val