Skip to content

Commit

Permalink
Copy dataset in optimizers to avoid changing it
Browse files Browse the repository at this point in the history
  • Loading branch information
khurram-ghani committed Nov 24, 2023
1 parent 25da01b commit 292faaa
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
3 changes: 1 addition & 2 deletions tests/unit/test_ask_tell_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations

import copy
from typing import Mapping, Optional

import numpy.testing as npt
Expand Down Expand Up @@ -503,7 +502,7 @@ def update(self, dataset: Dataset) -> None:

observer = mk_batch_observer(lambda x: Dataset(x, x))
rule = FixedAcquisitionRule(query_points)
ask_tell = AskTellOptimizer(search_space, copy.deepcopy(init_data), models, rule)
ask_tell = AskTellOptimizer(search_space, init_data, models, rule)

points = ask_tell.ask()
new_data = observer(points)
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/test_bayesian_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations

import copy
import tempfile
from collections.abc import Mapping
from pathlib import Path
Expand Down Expand Up @@ -303,7 +302,7 @@ def update(self, dataset: Dataset) -> None:

optimizer = BayesianOptimizer(lambda x: Dataset(x, x), search_space)
rule = FixedAcquisitionRule(query_points)
optimizer.optimize(1, copy.deepcopy(init_data), models, rule).final_result.unwrap()
optimizer.optimize(1, init_data, models, rule).final_result.unwrap()


@pytest.mark.parametrize("mode", ["early", "fail", "full"])
Expand Down
3 changes: 3 additions & 0 deletions trieste/ask_tell_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,9 @@ def __init__(
if not datasets or not models:
raise ValueError("dicts of datasets and models must be populated.")

# Copy the dataset so we don't change the one provided by the user.
datasets = deepcopy(datasets)

if isinstance(datasets, Dataset):
datasets = {OBJECTIVE: datasets}
if not isinstance(models, Mapping):
Expand Down
3 changes: 3 additions & 0 deletions trieste/bayesian_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,9 @@ def optimize(
- ``datasets`` or ``models`` are empty
- the default `acquisition_rule` is used and the tags are not `OBJECTIVE`.
"""
# Copy the dataset so we don't change the one provided by the user.
datasets = copy.deepcopy(datasets)

if isinstance(datasets, Dataset):
datasets = {OBJECTIVE: datasets}
if not isinstance(models, Mapping):
Expand Down

0 comments on commit 292faaa

Please sign in to comment.