Skip to content

Commit

Permalink
Add kernel parameter to builders
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Oct 8, 2024
1 parent ae55899 commit 9dbe21b
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions trieste/models/gpflow/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def build_sgpr(
trainable_likelihood: bool = False,
num_inducing_points: Optional[int] = None,
trainable_inducing_points: bool = False,
kernel: Optional[gpflow.kernels.Kernel] = None,
) -> SGPR:
"""
Build a :class:`~gpflow.models.SGPR` model with sensible initial parameters and
Expand Down Expand Up @@ -205,11 +206,14 @@ def build_sgpr(
``MAX_NUM_INDUCING_POINTS``, whichever is smaller.
:param trainable_inducing_points: If set to `True` inducing points will be set to
be trainable. This option should be used with caution. By default set to `False`.
:param kernel: The kernel to use in the model, defaults to letting the function set up a
:class:`~gpflow.kernels.Matern52` kernel.
:return: An :class:`~gpflow.models.SGPR` model.
"""
empirical_mean, empirical_variance, _ = _get_data_stats(data)

kernel = _get_kernel(empirical_variance, search_space, kernel_priors, kernel_priors)
if kernel is None:
kernel = _get_kernel(empirical_variance, search_space, kernel_priors, kernel_priors)
mean = _get_mean_function(empirical_mean)

inducing_points = gpflow.inducing_variables.InducingPoints(
Expand All @@ -228,10 +232,11 @@ def build_sgpr(

def build_vgp_classifier(
data: Dataset,
search_space: SearchSpace,
search_space: Optional[SearchSpace] = None,
kernel_priors: bool = True,
noise_free: bool = False,
kernel_variance: Optional[float] = None,
kernel: Optional[gpflow.kernels.Kernel] = None,
) -> VGP:
"""
Build a :class:`~gpflow.models.VGP` binary classification model with sensible initial
Expand Down Expand Up @@ -264,6 +269,8 @@ def build_vgp_classifier(
certain value. If left unspecified (default), the kernel variance is set to
``CLASSIFICATION_KERNEL_VARIANCE_NOISE_FREE`` in the ``noise_free`` case and to
``CLASSIFICATION_KERNEL_VARIANCE`` otherwise.
:param kernel: The kernel to use in the model, defaults to letting the function set up a
:class:`~gpflow.kernels.Matern52` kernel.
:return: A :class:`~gpflow.models.VGP` model.
"""
if kernel_variance is not None:
Expand All @@ -281,7 +288,13 @@ def build_vgp_classifier(
add_prior_to_variance = kernel_priors

model_likelihood = gpflow.likelihoods.Bernoulli()
kernel = _get_kernel(variance, search_space, kernel_priors, add_prior_to_variance)
if kernel is None:
if search_space is None:
raise ValueError(
"'build_gpr' function requires one of 'search_space' or 'kernel' arguments,"
" but got neither"
)
kernel = _get_kernel(variance, search_space, kernel_priors, add_prior_to_variance)
mean = _get_mean_function(tf.constant(0.0, dtype=gpflow.default_float()))

model = VGP(data.astuple(), kernel, model_likelihood, mean_function=mean)
Expand All @@ -300,6 +313,7 @@ def build_svgp(
trainable_likelihood: bool = False,
num_inducing_points: Optional[int] = None,
trainable_inducing_points: bool = False,
kernel: Optional[gpflow.kernels.Kernel] = None,
) -> SVGP:
"""
Build a :class:`~gpflow.models.SVGP` model with sensible initial parameters and
Expand Down Expand Up @@ -348,6 +362,8 @@ def build_svgp(
``MAX_NUM_INDUCING_POINTS``, whichever is smaller.
:param trainable_inducing_points: If set to `True` inducing points will be set to
be trainable. This option should be used with caution. By default set to `False`.
:param kernel: The kernel to use in the model, defaults to letting the function set up a
:class:`~gpflow.kernels.Matern52` kernel.
:return: An :class:`~gpflow.models.SVGP` model.
"""
empirical_mean, empirical_variance, num_data_points = _get_data_stats(data)
Expand All @@ -359,7 +375,8 @@ def build_svgp(
else:
model_likelihood = gpflow.likelihoods.Gaussian()

kernel = _get_kernel(empirical_variance, search_space, kernel_priors, kernel_priors)
if kernel is None:
kernel = _get_kernel(empirical_variance, search_space, kernel_priors, kernel_priors)
mean = _get_mean_function(empirical_mean)

inducing_points = _get_inducing_points(search_space, num_inducing_points)
Expand Down

0 comments on commit 9dbe21b

Please sign in to comment.