Skip to content

Commit

Permalink
Test GPflux predictor categorical predict.
Browse files Browse the repository at this point in the history
  • Loading branch information
avullo committed Sep 17, 2024
1 parent 9e0ea23 commit 03d0e0f
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion tests/unit/models/gpflux/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from tests.util.misc import random_seed
from trieste.data import Dataset
from trieste.models.gpflux import GPfluxPredictor
from trieste.space import CategoricalSearchSpace, EncoderFunction, one_hot_encoder
from trieste.types import TensorType


Expand All @@ -37,8 +38,9 @@ def __init__(
self,
optimizer: tf_keras.optimizers.Optimizer | None = None,
likelihood: gpflow.likelihoods.Likelihood = gpflow.likelihoods.Gaussian(0.01),
encoder: EncoderFunction | None = None,
):
super().__init__(optimizer=optimizer)
super().__init__(optimizer=optimizer, encoder=encoder)

if optimizer is None:
self._optimizer = tf_keras.optimizers.Adam()
Expand Down Expand Up @@ -150,3 +152,14 @@ def test_gpflux_predictor_get_observation_noise_raises_for_non_gaussian_likeliho

with pytest.raises(NotImplementedError):
model.get_observation_noise()


def test_gpflux_categorical_predict() -> None:
search_space = CategoricalSearchSpace(["Red", "Green", "Blue"])
query_points = search_space.sample(10)
model = _QuadraticPredictor(encoder=one_hot_encoder(search_space))
mean, variance = model.predict(query_points)
assert mean.shape == [10, 1]
assert variance.shape == [10, 1]
npt.assert_allclose(mean, [[1.0]] * 10, rtol=0.01)
npt.assert_allclose(variance, [[1.0]] * 10, rtol=0.01)

0 comments on commit 03d0e0f

Please sign in to comment.