AskTell support for indices to differently sized datasets #884

Draft · wants to merge 2 commits into base: develop
18 changes: 8 additions & 10 deletions tests/unit/test_ask_tell_optimization.py
@@ -189,7 +189,7 @@ def test_ask_tell_optimizer_returns_complete_state(
assert_datasets_allclose(state.record.dataset, init_dataset)
assert isinstance(state.record.model, type(model))
assert state.record.acquisition_state is None
assert state.local_data_ixs is not None
assert isinstance(state.local_data_ixs, Sequence)
assert state.local_data_len == 2
npt.assert_array_equal(
state.local_data_ixs,
@@ -229,8 +229,8 @@ def test_ask_tell_optimizer_loads_from_state(

assert_datasets_allclose(new_state.record.dataset, old_state.record.dataset)
assert old_state.record.model is new_state.record.model
assert new_state.local_data_ixs is not None
assert old_state.local_data_ixs is not None
assert isinstance(new_state.local_data_ixs, Sequence)
assert isinstance(old_state.local_data_ixs, Sequence)
npt.assert_array_equal(new_state.local_data_ixs, old_state.local_data_ixs)
assert old_state.local_data_len == new_state.local_data_len == len(init_dataset.query_points)

@@ -948,15 +948,13 @@ def test_ask_tell_optimizer_dataset_len_variables(
assert AskTellOptimizer.dataset_len({"tag1": dataset, "tag2": dataset}) == 2


def test_ask_tell_optimizer_dataset_len_raises_on_inconsistently_sized_datasets(
def test_ask_tell_optimizer_dataset_len_returns_dict_on_inconsistently_sized_datasets(
init_dataset: Dataset,
) -> None:
with pytest.raises(ValueError):
AskTellOptimizer.dataset_len(
{"tag": init_dataset, "empty": Dataset(tf.zeros([0, 2]), tf.zeros([0, 2]))}
)
with pytest.raises(ValueError):
AskTellOptimizer.dataset_len({})
assert AskTellOptimizer.dataset_len(
{"tag": init_dataset, "empty": Dataset(tf.zeros([0, 2]), tf.zeros([0, 2]))}
) == {"tag": 2, "empty": 0}
assert AskTellOptimizer.dataset_len({}) == {}


@pytest.mark.parametrize("optimizer", OPTIMIZERS)
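For reference, a minimal sketch of the behaviour the updated test exercises (import paths follow the files touched in this diff; the dataset shapes are illustrative): dataset_len now returns a per-tag mapping instead of raising when the datasets have different sizes, and still returns a single int when they agree.

import tensorflow as tf

from trieste.ask_tell_optimization import AskTellOptimizer
from trieste.data import Dataset

full = Dataset(tf.zeros([2, 2]), tf.zeros([2, 2]))   # two query points
empty = Dataset(tf.zeros([0, 2]), tf.zeros([0, 2]))  # no query points

# consistent sizes: still a single int
assert AskTellOptimizer.dataset_len({"tag1": full, "tag2": full}) == 2

# inconsistent sizes: now a per-tag mapping rather than a ValueError
assert AskTellOptimizer.dataset_len({"tag": full, "empty": empty}) == {"tag": 2, "empty": 0}
assert AskTellOptimizer.dataset_len({}) == {}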
25 changes: 17 additions & 8 deletions trieste/acquisition/utils.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import copy
import functools
from typing import Dict, Mapping, Optional, Sequence, Tuple, Union
@@ -162,7 +164,9 @@ def copy_to_local_models(
def with_local_datasets(
datasets: Mapping[Tag, Dataset],
num_local_datasets: int,
local_dataset_indices: Optional[Sequence[TensorType]] = None,
local_dataset_indices: Optional[
Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]
] = None,
) -> Dict[Tag, Dataset]:
"""
Helper method to add local datasets if they do not already exist, by copying global datasets
@@ -174,17 +178,22 @@ def with_local_datasets(
the global datasets should be copied. If None then the entire datasets are copied.
:return: The updated mapping of datasets.
"""
if local_dataset_indices is not None and len(local_dataset_indices) != num_local_datasets:
raise ValueError(
f"local_dataset_indices should have {num_local_datasets} entries, "
f"has {len(local_dataset_indices)}"
)
if isinstance(local_dataset_indices, Sequence):
local_dataset_indices = {tag: local_dataset_indices for tag in datasets}

updated_datasets = {}
for tag in datasets:
updated_datasets[tag] = datasets[tag]
ltag = LocalizedTag.from_tag(tag)
if not ltag.is_local:
if local_dataset_indices is not None:
if tag not in local_dataset_indices:
raise ValueError(f"local_dataset_indices missing tag {tag}")
elif len(local_dataset_indices[tag]) != num_local_datasets:
raise ValueError(
f"local_dataset_indices for tag {tag} should have {num_local_datasets} "
f"entries, but has {len(local_dataset_indices[tag])}"
)
for i in range(num_local_datasets):
target_ltag = LocalizedTag(ltag.global_tag, i)
if target_ltag not in datasets:
@@ -194,10 +203,10 @@
# TODO: use sparse tensors instead
updated_datasets[target_ltag] = Dataset(
query_points=tf.gather(
datasets[tag].query_points, local_dataset_indices[i]
datasets[tag].query_points, local_dataset_indices[tag][i]
),
observations=tf.gather(
datasets[tag].observations, local_dataset_indices[i]
datasets[tag].observations, local_dataset_indices[tag][i]
),
)

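An illustrative use of the extended signature (a hedged sketch: the tags, sizes and index values below are invented, and the import path follows the module in this diff). with_local_datasets now also accepts a mapping from tag to per-local-dataset index sequences, so that differently sized global datasets can each supply their own indices.

import tensorflow as tf

from trieste.acquisition.utils import with_local_datasets
from trieste.data import Dataset

datasets = {
    "A": Dataset(tf.random.uniform([4, 1]), tf.random.uniform([4, 1])),
    "B": Dataset(tf.random.uniform([2, 1]), tf.random.uniform([2, 1])),
}

# one index tensor per local dataset, specified separately for each (differently sized) tag
local_ixs = {
    "A": [tf.constant([0, 1]), tf.constant([2, 3])],
    "B": [tf.constant([0]), tf.constant([1])],
}

updated = with_local_datasets(datasets, num_local_datasets=2, local_dataset_indices=local_ixs)
# `updated` keeps the global entries and adds a LocalizedTag entry per tag and local dataset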
129 changes: 92 additions & 37 deletions trieste/ask_tell_optimization.py
@@ -82,7 +82,7 @@ class AskTellOptimizerState(Generic[StateType, ProbabilisticModelType]):
record: Record[StateType, ProbabilisticModelType]
""" A record of the current state of the optimization. """

local_data_ixs: Optional[Sequence[TensorType]]
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]]
""" Indices to the local data, for LocalDatasetsAcquisitionRule rules
when `track_data` is `False`. """

@@ -108,7 +108,7 @@ def __init__(
*,
fit_model: bool = True,
track_data: bool = True,
local_data_ixs: Optional[Sequence[TensorType]] = None,
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
local_data_len: Optional[int] = None,
): ...

@@ -122,7 +122,7 @@ def __init__(
*,
fit_model: bool = True,
track_data: bool = True,
local_data_ixs: Optional[Sequence[TensorType]] = None,
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
local_data_len: Optional[int] = None,
): ...

@@ -139,7 +139,7 @@ def __init__(
*,
fit_model: bool = True,
track_data: bool = True,
local_data_ixs: Optional[Sequence[TensorType]] = None,
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
local_data_len: Optional[int] = None,
): ...

@@ -152,7 +152,7 @@ def __init__(
*,
fit_model: bool = True,
track_data: bool = True,
local_data_ixs: Optional[Sequence[TensorType]] = None,
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
local_data_len: Optional[int] = None,
): ...

@@ -166,7 +166,7 @@ def __init__(
*,
fit_model: bool = True,
track_data: bool = True,
local_data_ixs: Optional[Sequence[TensorType]] = None,
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
local_data_len: Optional[int] = None,
): ...

@@ -183,7 +183,7 @@ def __init__(
*,
fit_model: bool = True,
track_data: bool = True,
local_data_ixs: Optional[Sequence[TensorType]] = None,
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
local_data_len: Optional[int] = None,
): ...

@@ -204,7 +204,7 @@ def __init__(
*,
fit_model: bool = True,
track_data: bool = True,
local_data_ixs: Optional[Sequence[TensorType]] = None,
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
local_data_len: Optional[int] = None,
):
"""
@@ -225,9 +225,12 @@ def __init__(
updates to the global datasets (optionally using `local_data_ixs` and indices passed
in to `tell`).
:param local_data_ixs: Indices to the local data in the initial datasets. If unspecified,
assumes that the initial datasets are global.
assumes that the initial datasets are global. Can be a single sequence for all
datasets, or a mapping with separate values for each dataset.
:param local_data_len: Optional length of the data when the passed in `local_data_ixs`
were measured. If the data has increased since then, the indices are extended.
(Note that this is only supported when all datasets have the same length. If not,
then it is up to the caller to update the indices before initialization.)
:raise ValueError: If any of the following are true:
- the keys in ``datasets`` and ``models`` do not match
- ``datasets`` or ``models`` are empty
@@ -287,12 +290,41 @@ def __init__(
if self.track_data:
datasets = self._datasets = with_local_datasets(self._datasets, num_local_datasets)
else:
self._dataset_len = self.dataset_len(self._datasets)
if local_data_ixs is not None:
dataset_len = self.dataset_len(self._datasets)
self._dataset_len = dataset_len if isinstance(dataset_len, int) else None
self._dataset_ixs: list[TensorType] | Mapping[Tag, list[TensorType]]

if local_data_ixs is None:
# assume that the initial datasets are global
if isinstance(dataset_len, int):
self._dataset_ixs = [
tf.range(dataset_len) for _ in range(num_local_datasets)
]
else:
self._dataset_ixs = {
t: [tf.range(l) for _ in range(num_local_datasets)]
for t, l in dataset_len.items()
}

elif isinstance(local_data_ixs, Mapping):
self._dataset_ixs = {t: list(ixs) for t, ixs in local_data_ixs.items()}
if local_data_len is not None:
raise ValueError(
"Cannot infer new data points for datasets with different "
"local data indices. Pass in full indices instead."
)

else:
self._dataset_ixs = list(local_data_ixs)

if local_data_len is not None:
# infer new dataset indices from change in dataset sizes
num_new_points = self._dataset_len - local_data_len
if isinstance(dataset_len, Mapping):
raise ValueError(
"Cannot infer new data points for datasets with different "
"lengths. Pass in full indices instead."
)
num_new_points = dataset_len - local_data_len
if num_new_points < 0 or (
num_local_datasets > 0 and num_new_points % num_local_datasets != 0
):
Expand All @@ -310,10 +342,6 @@ def __init__(
],
-1,
)
else:
self._dataset_ixs = [
tf.range(self._dataset_len) for _ in range(num_local_datasets)
]

datasets = with_local_datasets(
self._datasets, num_local_datasets, self._dataset_ixs
@@ -375,7 +403,7 @@ def dataset(self) -> Dataset:
raise ValueError(f"Expected a single dataset, found {len(datasets)}")

@property
def local_data_ixs(self) -> Optional[Sequence[TensorType]]:
def local_data_ixs(self) -> Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]]:
"""Indices to the local data. Only stored for LocalDatasetsAcquisitionRule rules
when `track_data` is `False`."""
if isinstance(self._acquisition_rule, LocalDatasetsAcquisitionRule) and not self.track_data:
@@ -433,8 +461,8 @@ def acquisition_state(self) -> StateType | None:
return self._acquisition_state

@classmethod
def dataset_len(cls, datasets: Mapping[Tag, Dataset]) -> int:
"""Helper method for inferring the global dataset size."""
def dataset_len(cls, datasets: Mapping[Tag, Dataset]) -> int | Mapping[Tag, int]:
"""Helper method for inferring the global dataset size(s)."""
dataset_lens = {
tag: int(tf.shape(dataset.query_points)[0])
for tag, dataset in datasets.items()
@@ -444,9 +472,7 @@ def dataset_len(cls, datasets: Mapping[Tag, Dataset]) -> int:
if len(unique_lens) == 1:
return int(unique_lens[0])
else:
raise ValueError(
f"Expected unique global dataset size, got {unique_lens}: {dataset_lens}"
)
return dataset_lens

@classmethod
def from_record(
@@ -465,7 +491,7 @@ def from_record(
| None
) = None,
track_data: bool = True,
local_data_ixs: Optional[Sequence[TensorType]] = None,
local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
local_data_len: Optional[int] = None,
) -> AskTellOptimizerType:
"""Creates new :class:`~AskTellOptimizer` instance from provided optimization state.
@@ -634,14 +660,15 @@ def ask(self) -> TensorType:
def tell(
self,
new_data: Mapping[Tag, Dataset] | Dataset,
new_data_ixs: Optional[Sequence[TensorType]] = None,
new_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
) -> None:
"""Updates optimizer state with new data.

:param new_data: New observed data. If `track_data` is `False`, this refers to all
the data.
:param new_data_ixs: Indices to the new observed local data, if `track_data` is `False`.
If unspecified, inferred from the change in dataset sizes.
If unspecified, inferred from the change in dataset sizes (as long as all the
datasets have the same size).
:raise ValueError: If keys in ``new_data`` do not match those in already built dataset.
"""
if isinstance(new_data, Dataset):
@@ -670,10 +697,45 @@ def tell(
elif not isinstance(self._acquisition_rule, LocalDatasetsAcquisitionRule):
datasets = new_data
else:
num_local_datasets = len(self._dataset_ixs)
if new_data_ixs is None:
num_local_datasets = (
len(self._dataset_ixs)
if isinstance(self._dataset_ixs, Sequence)
else len(next(iter(self._dataset_ixs.values())))
)

if new_data_ixs is not None:
# use explicit indices
def update_ixs(ixs: list[TensorType], new_ixs: Sequence[TensorType]) -> None:
if len(ixs) != len(new_ixs):
raise ValueError(
f"new_data_ixs has {len(new_ixs)} entries, expected {len(ixs)}"
)
for i in range(len(ixs)):
ixs[i] = tf.concat([ixs[i], new_ixs[i]], -1)

if isinstance(new_data_ixs, Sequence) and isinstance(self._dataset_ixs, Mapping):
raise ValueError("separate new_data_ixs required for each dataset")
if isinstance(new_data_ixs, Mapping) and isinstance(self._dataset_ixs, Sequence):
self._dataset_ixs = {tag: list(self._dataset_ixs) for tag in self._datasets}
if isinstance(new_data_ixs, Mapping):
assert isinstance(self._dataset_ixs, Mapping)
for tag in self._datasets:
update_ixs(self._dataset_ixs[tag], new_data_ixs[tag])
else:
assert isinstance(self._dataset_ixs, list)
update_ixs(self._dataset_ixs, new_data_ixs)

else:
# infer dataset indices from change in dataset sizes
if isinstance(self._dataset_ixs, Mapping) or not isinstance(self._dataset_len, int):
raise NotImplementedError(
"new data indices cannot be inferred for datasets with different sizes"
)
new_dataset_len = self.dataset_len(new_data)
if not isinstance(new_dataset_len, int):
raise NotImplementedError(
"new data indices cannot be inferred for new data with different sizes"
)
num_new_points = new_dataset_len - self._dataset_len
if num_new_points < 0 or (
num_local_datasets > 0 and num_new_points % num_local_datasets != 0
@@ -690,17 +752,10 @@ def tell(
],
-1,
)
else:
# use explicit indices
if len(new_data_ixs) != num_local_datasets:
raise ValueError(
f"new_data_ixs has {len(new_data_ixs)} entries, "
f"expected {num_local_datasets}"
)
for i in range(num_local_datasets):
self._dataset_ixs[i] = tf.concat([self._dataset_ixs[i], new_data_ixs[i]], -1)

datasets = with_local_datasets(new_data, num_local_datasets, self._dataset_ixs)
self._dataset_len = self.dataset_len(datasets)
dataset_len = self.dataset_len(datasets)
self._dataset_len = dataset_len if isinstance(dataset_len, int) else None

filtered_datasets = self._acquisition_rule.filter_datasets(self._models, datasets)
if callable(filtered_datasets):
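To make the new per-tag bookkeeping in tell concrete, here is a small standalone sketch (plain TensorFlow only; the tags and index values are invented) of how a mapping-valued new_data_ixs extends the stored per-tag local indices, mirroring the tf.concat update the method applies when track_data is False:

import tensorflow as tf

# stored local-data indices: per tag, one index tensor per local dataset
dataset_ixs = {
    "A": [tf.constant([0, 1]), tf.constant([2, 3])],
    "B": [tf.constant([0]), tf.constant([1])],
}

# indices of the newly observed points, again per tag and per local dataset
new_data_ixs = {
    "A": [tf.constant([4]), tf.constant([5])],
    "B": [tf.constant([2]), tf.constant([3])],
}

for tag, new_ixs in new_data_ixs.items():
    ixs = dataset_ixs[tag]
    if len(ixs) != len(new_ixs):
        raise ValueError(f"new_data_ixs has {len(new_ixs)} entries, expected {len(ixs)}")
    for i in range(len(ixs)):
        ixs[i] = tf.concat([ixs[i], new_ixs[i]], -1)

print(dataset_ixs["A"][0].numpy())  # [0 1 4]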