Skip to content

Commit

Permalink
Remove encoder from FeatureDecompositionTrajectorySampler
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Sep 10, 2024
1 parent a86cab5 commit 9be4510
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 16 deletions.
9 changes: 0 additions & 9 deletions trieste/models/gpflow/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from gpflux.math import compute_A_inv_b
from typing_extensions import Protocol, TypeGuard, runtime_checkable

from ...space import EncoderFunction
from ...types import TensorType
from ...utils import DEFAULTS, flatten_leading_dims
from ..interfaces import (
Expand All @@ -44,7 +43,6 @@
TrajectoryFunction,
TrajectoryFunctionClass,
TrajectorySampler,
get_encoder,
)

_IntTensorType = Union[tf.Tensor, int]
Expand Down Expand Up @@ -399,7 +397,6 @@ def get_trajectory(self) -> TrajectoryFunction:
feature_functions=self._feature_functions,
weight_sampler=weight_sampler,
mean_function=self._mean_function,
encoder=get_encoder(self._model),
)

def update_trajectory(self, trajectory: TrajectoryFunction) -> TrajectoryFunction:
Expand Down Expand Up @@ -876,18 +873,15 @@ def __init__(
feature_functions: Callable[[TensorType], TensorType],
weight_sampler: Callable[[int], TensorType],
mean_function: Callable[[TensorType], TensorType],
encoder: EncoderFunction | None = None,
):
"""
:param feature_functions: Set of feature function.
:param weight_sampler: New sampler that generates feature weight samples.
:param mean_function: The underlying model's mean function.
:param encoder: Optional encoder with which to transform input points.
"""
self._feature_functions = feature_functions
self._mean_function = mean_function
self._weight_sampler = weight_sampler
self._encoder = encoder
self._initialized = tf.Variable(False)

self._weights_sample = tf.Variable( # dummy init to be updated before trajectory evaluation
Expand All @@ -902,9 +896,6 @@ def __init__(
def __call__(self, inputs: TensorType) -> TensorType: # [N, B, D] -> [N, B, L]
"""Call trajectory function."""

if self._encoder is not None:
inputs = self._encoder(inputs)

if not self._initialized: # work out desired batch size from input
self._batch_size.assign(tf.shape(inputs)[-2]) # B
self.resample() # sample B feature weights
Expand Down
7 changes: 0 additions & 7 deletions trieste/models/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,10 +970,3 @@ def conditional_predict_y(
return self.conditional_predict_y_encoded(
self.encode(query_points), self.encode(additional_data)
)


def get_encoder(model: ProbabilisticModel) -> EncoderFunction | None:
"""Helper function for getting an encoder from model (which may or may not have one)."""
if isinstance(model, EncodedProbabilisticModel):
return model.encoder
return None

0 comments on commit 9be4510

Please sign in to comment.