Skip to content

Commit

Permalink
Encoded DeepGPs.
Browse files Browse the repository at this point in the history
  • Loading branch information
avullo committed Sep 17, 2024
1 parent 1dd0857 commit e90f91c
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
24 changes: 18 additions & 6 deletions trieste/models/gpflux/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,29 +20,41 @@
from gpflow.base import Module
from gpflow.keras import tf_keras

from ...space import EncoderFunction
from ...types import TensorType
from ..interfaces import SupportsGetObservationNoise, SupportsPredictY
from ..interfaces import EncodedSupportsPredictY, SupportsGetObservationNoise
from ..optimizer import KerasOptimizer


class GPfluxPredictor(SupportsGetObservationNoise, SupportsPredictY, ABC):
class GPfluxPredictor(SupportsGetObservationNoise, EncodedSupportsPredictY, ABC):
"""
A trainable wrapper for a GPflux deep Gaussian process model. The code assumes subclasses
will use the Keras `fit` method for training, and so they should provide access to both a
`model_keras` and `model_gpflux`.
"""

def __init__(self, optimizer: KerasOptimizer | None = None):
def __init__(self, optimizer: KerasOptimizer | None = None, encoder: EncoderFunction | None = None):
"""
:param optimizer: The optimizer wrapper containing the optimizer with which to train the
model and arguments for the wrapper and the optimizer. The optimizer must
be an instance of a :class:`~tf.optimizers.Optimizer`. Defaults to
:class:`~tf.optimizers.Adam` optimizer with 0.01 learning rate.
:param encoder: Optional encoder with which to transform query points before
generating predictions.
"""
if optimizer is None:
optimizer = KerasOptimizer(tf_keras.optimizers.Adam(0.01))

self._optimizer = optimizer
self._encoder = encoder

@property
def encoder(self) -> EncoderFunction | None:
return self._encoder

@encoder.setter
def encoder(self, encoder: EncoderFunction | None) -> None:
self._encoder = encoder

@property
@abstractmethod
Expand All @@ -60,17 +72,17 @@ def optimizer(self) -> KerasOptimizer:
return self._optimizer

@inherit_check_shapes
def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
def predict_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
"""Note: unless otherwise noted, this returns the mean and variance of the last layer
conditioned on one sample from the previous layers."""
return self.model_gpflux.predict_f(query_points)

@abstractmethod
def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
def sample_encoded(self, query_points: TensorType, num_samples: int) -> TensorType:
raise NotImplementedError

@inherit_check_shapes
def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
def predict_y_encoded(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
"""Note: unless otherwise noted, this will return the prediction conditioned on one sample
from the lower layers."""
f_mean, f_var = self.model_gpflux.predict_f(query_points)
Expand Down
16 changes: 10 additions & 6 deletions trieste/models/gpflux/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@

from ... import logging
from ...data import Dataset
from ...space import EncoderFunction
from ...types import TensorType
from ..interfaces import (
EncodedTrainableProbabilisticModel,
HasReparamSampler,
HasTrajectorySampler,
ReparametrizationSampler,
TrainableProbabilisticModel,
TrajectorySampler,
)
from ..optimizer import KerasOptimizer
Expand All @@ -50,7 +51,7 @@


class DeepGaussianProcess(
GPfluxPredictor, TrainableProbabilisticModel, HasReparamSampler, HasTrajectorySampler
GPfluxPredictor, EncodedTrainableProbabilisticModel, HasReparamSampler, HasTrajectorySampler
):
"""
A :class:`TrainableProbabilisticModel` wrapper for a GPflux :class:`~gpflux.models.DeepGP` with
Expand All @@ -65,6 +66,7 @@ def __init__(
num_rff_features: int = 1000,
continuous_optimisation: bool = True,
compile_args: Optional[Mapping[str, Any]] = None,
encoder: EncoderFunction | None = None,
):
"""
:param model: The underlying GPflux deep Gaussian process model. Passing in a named closure
Expand All @@ -88,6 +90,8 @@ def __init__(
See https://keras.io/api/models/model_training_apis/#compile-method for a
list of possible arguments. The ``optimizer`` and ``metrics`` arguments
must not be included.
:param encoder: Optional encoder with which to transform query points before
generating predictions.
:raise ValueError: If ``model`` has unsupported layers, ``num_rff_features`` is less than 0,
if the ``optimizer`` is not of a supported type, or `compile_args` contains
disallowed arguments.
Expand All @@ -113,7 +117,7 @@ def __init__(
f"`LatentVariableLayer`, received {type(layer)} instead."
)

super().__init__(optimizer)
super().__init__(optimizer, encoder)

if num_rff_features <= 0:
raise ValueError(
Expand Down Expand Up @@ -305,7 +309,7 @@ def model_keras(self) -> tf_keras.Model:
return self._model_keras

@inherit_check_shapes
def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
def sample_encoded(self, query_points: TensorType, num_samples: int) -> TensorType:
trajectory = self.trajectory_sampler().get_trajectory()
expanded_query_points = tf.expand_dims(query_points, -2) # [N, 1, D]
tiled_query_points = tf.tile(expanded_query_points, [1, num_samples, 1]) # [N, S, D]
Expand All @@ -329,7 +333,7 @@ def trajectory_sampler(self) -> TrajectorySampler[GPfluxPredictor]:
"""
return DeepGaussianProcessDecoupledTrajectorySampler(self, self._num_rff_features)

def update(self, dataset: Dataset) -> None:
def update_encoded(self, dataset: Dataset) -> None:
inputs = dataset.query_points
new_num_data = inputs.shape[0]
self.model_gpflux.num_data = new_num_data
Expand Down Expand Up @@ -366,7 +370,7 @@ def update(self, dataset: Dataset) -> None:

inputs = layer(inputs)

def optimize(self, dataset: Dataset) -> tf_keras.callbacks.History:
def optimize_encoded(self, dataset: Dataset) -> tf_keras.callbacks.History:
"""
Optimize the model with the specified `dataset`.
:param dataset: The data with which to optimize the `model`.
Expand Down

0 comments on commit e90f91c

Please sign in to comment.