Skip to content

Commit

Permalink
Fix ask_tell global dataset_len detection (#853)
Browse files Browse the repository at this point in the history
  • Loading branch information
khurram-ghani authored Jun 6, 2024
1 parent 6f2a33a commit f612e90
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions tests/unit/test_ask_tell_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,7 @@ def test_ask_tell_optimizer_dataset_len_variables(
dataset = init_dataset

assert AskTellOptimizer.dataset_len({"tag": dataset}) == 2
assert AskTellOptimizer.dataset_len({"tag1": dataset, "tag2": dataset}) == 2


def test_ask_tell_optimizer_dataset_len_raises_on_inconsistently_sized_datasets(
Expand Down
4 changes: 2 additions & 2 deletions trieste/ask_tell_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,8 @@ def dataset_len(cls, datasets: Mapping[Tag, Dataset]) -> int:
for tag, dataset in datasets.items()
if not LocalizedTag.from_tag(tag).is_local
]
unique_lens, unique_idxs = tf.unique(dataset_lens)
if len(unique_idxs) == 1:
unique_lens, _ = tf.unique(dataset_lens)
if len(unique_lens) == 1:
return int(unique_lens[0])
else:
raise ValueError(f"Expected unique global dataset size, got {unique_lens}")
Expand Down

0 comments on commit f612e90

Please sign in to comment.