From 6e127bb625f112f47249da6196c5d87f76612a60 Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Thu, 9 Sep 2021 18:42:46 +0100 Subject: [PATCH 01/19] Initial attempt at creating global inducing layer. Untested. --- gpflux/layers/__init__.py | 1 + gpflux/layers/gi_gp_layer.py | 315 +++++++++++++++++++++++++++++++++++ 2 files changed, 316 insertions(+) create mode 100644 gpflux/layers/gi_gp_layer.py diff --git a/gpflux/layers/__init__.py b/gpflux/layers/__init__.py index 34f7e073..881386b6 100644 --- a/gpflux/layers/__init__.py +++ b/gpflux/layers/__init__.py @@ -19,6 +19,7 @@ from gpflux.layers import basis_functions from gpflux.layers.bayesian_dense_layer import BayesianDenseLayer from gpflux.layers.gp_layer import GPLayer +from gpflux.layers.gi_gp_layer import GIGP from gpflux.layers.latent_variable_layer import LatentVariableLayer, LayerWithObservations from gpflux.layers.likelihood_layer import LikelihoodLayer from gpflux.layers.trackable_layer import TrackableLayer diff --git a/gpflux/layers/gi_gp_layer.py b/gpflux/layers/gi_gp_layer.py new file mode 100644 index 00000000..fbcf35e1 --- /dev/null +++ b/gpflux/layers/gi_gp_layer.py @@ -0,0 +1,315 @@ +# +# Copyright (c) 2021 The GPflux Contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +This module provides :class:`GIGPLayer`, which implements a 'global inducing' point posterior for +a GP layer. Currently restricted to single-output kernels, inducing points, etc... See Ober and +Aitchison (2021) for details. +""" + +import warnings +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import tensorflow as tf +import tensorflow_probability as tfp +import math + +from gpflow import Parameter, default_float +from gpflow.base import TensorType +from gpflow.kernels import SquaredExponential +from gpflow.mean_functions import Identity, MeanFunction +from gpflow.utilities.bijectors import triangular, positive + +from gpflux.sampling.sample import Sample + + +class BatchingSquaredExponential(SquaredExponential): + """Implementation of squared exponential kernel that batches in the following way: given X with + shape [..., N, D], and X2 with shape [..., M, D], we return [..., N, M] instead of the current + behavior, which returns [..., N, ..., M]""" + + def K_r2(self, r2): + return self.variance * tf.exp(-0.5 * r2) + + def scaled_squared_euclid_dist(self, X, X2=None): + X = self.scale(X) + X2 = self.scale(X2) + + if X2 is None: + X2 = X + + Xs = tf.reduce_sum((X**2), -1)[..., :, None] + X2s = tf.reduce_sum((X2**2), -1)[..., None, :] + return Xs + X2s - 2*X@tf.linalg.adjoint(X2s) + + +class GIGPLayer(tf.keras.layers.Layer): + """ + A sparse variational multioutput GP layer. This layer holds the kernel, + inducing variables and variational distribution, and mean function. + """ + + num_data: int + """ + The number of points in the training dataset. This information is used to + obtain the correct scaling between the data-fit and the KL term in the + evidence lower bound (ELBO). 
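    In this layer, the sample-based KL estimate from :meth:`prior_kl` is divided by
    ``num_data`` in :meth:`call` before being registered as a per-datapoint loss.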
+ """ + + v: Parameter + r""" + The pseudo-targets. + """ + + L_loc: Parameter + r""" + The lower-triangular Cholesky factor of the precision of ``q(u|v)``. + """ + + L_scale: Parameter + r""" + Scale parameter for L + """ + + def __init__( + self, + num_latent_gps: int, + num_data: int, + num_inducing: int, + input_dim: int, + inducing_targets: Optional[tf.Tensor] = None, + prec_init: Optional[float] = 1., + mean_function: Optional[MeanFunction] = None, + *, + name: Optional[str] = None, + verbose: bool = True, + ): + """ + :param kernel: The multioutput kernel for this layer. + :param inducing_variable: The inducing features for this layer. + :param num_data: The number of points in the training dataset (see :attr:`num_data`). + :param mean_function: The mean function that will be applied to the + inputs. Default: :class:`~gpflow.mean_functions.Identity`. + + .. note:: The Identity mean function requires the input and output + dimensionality of this layer to be the same. If you want to + change the dimensionality in a layer, you may want to provide a + :class:`~gpflow.mean_functions.Linear` mean function instead. + + :param num_samples: The number of samples to draw when converting the + :class:`~tfp.layers.DistributionLambda` into a `tf.Tensor`, see + :meth:`_convert_to_tensor_fn`. Will be stored in the + :attr:`num_samples` attribute. If `None` (the default), draw a + single sample without prefixing the sample shape (see + :class:`tfp.distributions.Distribution`'s `sample() + `_ + method). + :param full_cov: Sets default behaviour of calling this layer + (:attr:`full_cov` attribute): + If `False` (the default), only predict marginals (diagonal + of covariance) with respect to inputs. + If `True`, predict full covariance over inputs. + :param full_output_cov: Sets default behaviour of calling this layer + (:attr:`full_output_cov` attribute): + If `False` (the default), only predict marginals (diagonal + of covariance) with respect to outputs. + If `True`, predict full covariance over outputs. + :param num_latent_gps: The number of (latent) GPs in the layer + (which can be different from the number of outputs, e.g. with a + :class:`~gpflow.kernels.LinearCoregionalization` kernel). + This is used to determine the size of the + variational parameters :attr:`q_mu` and :attr:`q_sqrt`. + If possible, it is inferred from the *kernel* and *inducing_variable*. + :param whiten: If `True` (the default), uses the whitened parameterisation + of the inducing variables; see :attr:`whiten`. + :param name: The name of this layer. + :param verbose: The verbosity mode. Set this parameter to `True` + to show debug information. + """ + + super().__init__( + make_distribution_fn=self._make_distribution_fn, + convert_to_tensor_fn=self._convert_to_tensor_fn, + dtype=default_float(), + name=name, + ) + + self.kernel = BatchingSquaredExponential(lengthscales=[1.]*input_dim) + + self.num_data = num_data + + if mean_function is None: + mean_function = Identity() + if verbose: + warnings.warn( + "Beware, no mean function was specified in the construction of the `GPLayer` " + "so the default `gpflow.mean_functions.Identity` is being used. " + "This mean function will only work if the input dimensionality " + "matches the number of latent Gaussian processes in the layer." 
+ ) + self.mean_function = mean_function + + self.verbose = verbose + + self.num_latent_gps = num_latent_gps + + self.num_inducing = num_inducing + + if inducing_targets is None: + inducing_targets = np.zeros((self.num_latent_gps, num_inducing, 1)) + + self.v = Parameter( + inducing_targets, + dtype=default_float(), + name=f"{self.name}_v" if self.name else "v", + ) + + self.L_loc = Parameter( + np.stack([np.eye(num_inducing) for _ in range(self.num_latent_gps)]), + transform=triangular(), + dtype=default_float(), + name=f"{self.name}_L_loc" if self.name else "L_loc", + ) # [num_latent_gps, num_inducing, num_inducing] + + self.L_scale = Parameter( + np.sqrt(prec_init)*np.ones((self.num_latent_gps, 1, 1)), + transform=positive(), + dtype=default_float(), + name=f"{self.name}_L_scale" if self.name else "L_scale" + ) + + @property + def L(self): + norm = tf.reshape(tf.reduce_mean(tf.linalg.diag_part(self.L_loc), axis=-1), [-1, 1, 1]) + return self.L_loc * self.L_scale/norm + + def mvnormal_log_prob(self, sigma_L, X): + in_features = tf.shape(X)[-2] + out_features = tf.shape(X)[-1] + trace_quad = tf.reduce_sum(tf.linalg.triangular_solve(sigma_L, X)**2, [-1, -2]) + logdet_term = 2.0*tf.reduce_sum(tf.math.log(tf.linalg.diag_part(sigma_L)), -1) + return -0.5*trace_quad - 0.5*out_features*(logdet_term + in_features*math.log(2*math.pi)) + + def call( + self, + inputs: TensorType, + *args: List[Any], + **kwargs: Dict[str, Any] + ) -> tf.Tensor: + """ + Sample-based propagation of both inducing points and function values. + """ + assert len(tf.shape(inputs)) == 3 + + mean_function = self.mean_function(inputs) + + Kuu = self.kernel(inputs[..., :self.num_inducing]) + Kuf = self.kernel(inputs[..., :self.num_inducing, self.num_inducing:]) + Kfu = tf.linalg.adjoint(Kuf) + Kff = self.kernel.K_diag(inputs[..., self.num_inducing:]) + + Kuu, Kuf, Kfu = tf.expand_dims(Kuu, 1), tf.expand_dims(Kuf, 1), tf.expand_dims(Kfu, 1) + + (S, _, _, _) = tf.shape(Kuu) + S = S.numpy() + + Iuu = tf.eye(self.num_inducing, dtype=default_float()) + + L = self.L + LT = tf.linalg.adjoint(L) + + KuuL = Kuu @ L + + lKlpI = LT @ KuuL + Iuu + chol_lKlpI = tf.linalg.cholesky(lKlpI) + Sigma = Kuu - KuuL @ tf.linalg.cholesky_solve(chol_lKlpI, tf.linalg.adjoint(KuuL)) + + eps_1 = tf.random.normal( + [S, self.num_latent_gps, self.num_inducing, 1], + dtype=default_float() + ) + eps_2 = tf.random.normal( + [S, self.num_latent_gps, self.num_inducing, 1], + dtype=default_float() + ) + + chol_Kuu = tf.linalg.cholesky(Kuu) + chol_Kuu_T = tf.linalg.adjoint(chol_Kuu) + + inv_Kuu_noise = tf.linalg.triangular_solve(chol_Kuu_T, eps_1, lower=False) + L_noise = L @ eps_2 + prec_noise = inv_Kuu_noise + L_noise + u = Sigma @ ((L @ LT) @ self.v + prec_noise) + + if kwargs.get("training"): + loss_per_datapoint = self.prior_kl(LT, chol_lKlpI, u) / self.num_data + else: + # TF quirk: add_loss must always add a tensor to compile + loss_per_datapoint = tf.constant(0.0, dtype=default_float()) + self.add_loss(loss_per_datapoint) + + # Metric names should be unique; otherwise they get overwritten if you + # have multiple with the same name + name = f"{self.name}_prior_kl" if self.name else "prior_kl" + self.add_metric(loss_per_datapoint, name=name, aggregation="mean") + + #### f|u + Kfu_invKuu = tf.linalg.adjoint(tf.linalg.cholesky_solve(chol_Kuu, Kuf)) + Ef = tf.linalg.adjoint(tf.squeeze((Kfu_invKuu @ u), -1)) + Vf = Kff - tf.squeeze(tf.reduce_sum((Kfu_invKuu*Kfu), -1), 1) + + eps_f = tf.random.normal( + tf.shape(Ef), + dtype=default_float() + ) + + f_samples = 
Ef + tf.math.sqrt(Vf)[..., None]*eps_f + + all_samples = tf.concat( + [ + tf.linalg.adjoint(tf.squeeze(u, -1)), + f_samples, + ], + axis=-2 + ) + + return all_samples + mean_function + + def prior_kl( + self, + LT: tf.Tensor, + chol_lKlpI: tf.Tensor, + u: tf.Tensor, + ) -> tf.Tensor: + r""" + Returns the KL divergence ``KL[q(u)∥p(u)]`` from the prior ``p(u)`` to + the variational distribution ``q(u)``. If this layer uses the + :attr:`whiten`\ ed representation, returns ``KL[q(v)∥p(v)]``. + """ + lv = LT @ self.v + + logP = tf.reduce_sum(self.mvnormal_log_prob(chol_lKlpI, lv), -1) + logQ = tf.reduce_sum(tfp.distributions.Normal(LT@u, 1.).log_prob(lv), [-1, -2, -3]) + + logpq = logP - logQ + + return -tf.reduce_mean(logpq) + + def sample(self) -> Sample: + """ + .. todo:: TODO: Document this. + """ + raise NotImplementedError("TODO") From bdfef7c5bd747d43f0b72c4f55a4c56bc403de08 Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Fri, 10 Sep 2021 11:28:21 +0100 Subject: [PATCH 02/19] First attempt at full GI DGP model. Untested. --- gpflux/layers/__init__.py | 4 +- gpflux/layers/gi_gp_layer.py | 2 +- gpflux/layers/likelihood_layer.py | 36 +++- gpflux/models/gi_deep_gp.py | 301 ++++++++++++++++++++++++++++++ 4 files changed, 339 insertions(+), 4 deletions(-) create mode 100644 gpflux/models/gi_deep_gp.py diff --git a/gpflux/layers/__init__.py b/gpflux/layers/__init__.py index 881386b6..a7b83072 100644 --- a/gpflux/layers/__init__.py +++ b/gpflux/layers/__init__.py @@ -19,7 +19,7 @@ from gpflux.layers import basis_functions from gpflux.layers.bayesian_dense_layer import BayesianDenseLayer from gpflux.layers.gp_layer import GPLayer -from gpflux.layers.gi_gp_layer import GIGP +from gpflux.layers.gi_gp_layer import GIGPLayer from gpflux.layers.latent_variable_layer import LatentVariableLayer, LayerWithObservations -from gpflux.layers.likelihood_layer import LikelihoodLayer +from gpflux.layers.likelihood_layer import LikelihoodLayer, SampleBasedGaussianLikelihoodLayer from gpflux.layers.trackable_layer import TrackableLayer diff --git a/gpflux/layers/gi_gp_layer.py b/gpflux/layers/gi_gp_layer.py index fbcf35e1..883a1019 100644 --- a/gpflux/layers/gi_gp_layer.py +++ b/gpflux/layers/gi_gp_layer.py @@ -266,7 +266,7 @@ def call( name = f"{self.name}_prior_kl" if self.name else "prior_kl" self.add_metric(loss_per_datapoint, name=name, aggregation="mean") - #### f|u + #### f|u - possibility of using a conditional here? 
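+        # The lines below compute the conditional directly:
+        #   E[f|u]   = Kfu Kuu^{-1} u
+        #   Var[f|u] = Kff - diag(Kfu Kuu^{-1} Kuf)
+        # A GPflow `conditional` could potentially replace them, as noted above.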
Kfu_invKuu = tf.linalg.adjoint(tf.linalg.cholesky_solve(chol_Kuu, Kuf)) Ef = tf.linalg.adjoint(tf.squeeze((Kfu_invKuu @ u), -1)) Vf = Kff - tf.squeeze(tf.reduce_sum((Kfu_invKuu*Kfu), -1), 1) diff --git a/gpflux/layers/likelihood_layer.py b/gpflux/layers/likelihood_layer.py index 85d68777..e57c3408 100644 --- a/gpflux/layers/likelihood_layer.py +++ b/gpflux/layers/likelihood_layer.py @@ -23,13 +23,47 @@ import tensorflow_probability as tfp from tensorflow_probability.python.util.deferred_tensor import TensorMetaClass -from gpflow import default_float +from gpflow import default_float, Parameter from gpflow.base import TensorType from gpflow.likelihoods import Likelihood +from gpflow.utilities.bijectors import positive from gpflux.layers.trackable_layer import TrackableLayer +class SampleBasedGaussianLikelihoodLayer(TrackableLayer): + def __init__(self, variance: float = 1., variance_lower_bound=1e-8): + super().__init__(dtype=default_float()) + + if variance <= variance_lower_bound: + raise ValueError( + f"The variance of the Gaussian likelihood must be strictly greater than " + f"{variance_lower_bound}" + ) + + self.variance = Parameter(variance, transform=positive(lower=variance_lower_bound)) + + def call( + self, + inputs: TensorType, + targets: Optional[TensorType] = None, + training: bool = None, + ) -> tfp.distributions.Normal: + likelihood_dist = tfp.distributions.Normal(inputs, self.variance, dtype=default_float()) + + if training: + assert targets is not None + loss_per_datapoint = tf.reduce_mean( + -likelihood_dist.log_prob(tf.expand_dims(targets, 0)) + ) + else: + loss_per_datapoint = tf.constant(0.0, dtype=default_float()) + + self.add_loss(loss_per_datapoint) + + return likelihood_dist + + class LikelihoodLayer(TrackableLayer): r""" A Keras layer that wraps a GPflow :class:`~gpflow.likelihoods.Likelihood`. This layer expects a diff --git a/gpflux/models/gi_deep_gp.py b/gpflux/models/gi_deep_gp.py new file mode 100644 index 00000000..8113cb3f --- /dev/null +++ b/gpflux/models/gi_deep_gp.py @@ -0,0 +1,301 @@ +# +# Copyright (c) 2021 The GPflux Contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" This module provides an implementation of global inducing points for deep GPs. """ + +import itertools +from typing import List, Optional, Tuple, Type, Union + +import tensorflow as tf + +import gpflow +from gpflow import Parameter, default_float +from gpflow.base import Module, TensorType + +import gpflux +from gpflux.layers import LayerWithObservations, LikelihoodLayer, SampleBasedGaussianLikelihoodLayer +from gpflux.sampling.sample import Sample + + +class GIDeepGP(Module): + + f_layers: List[tf.keras.layers.Layer] + """ A list of all layers in this DeepGP (just :attr:`likelihood_layer` is separate). """ + + likelihood_layer: gpflux.layers.LikelihoodLayer + """ The likelihood layer. """ + + """ + The default for the *model_class* argument of :meth:`as_training_model` and + :meth:`as_prediction_model`. 
This must have the same semantics as `tf.keras.Model`, + that is, it must accept a list of inputs and an output. This could be + `tf.keras.Model` itself or `gpflux.optimization.NatGradModel` (but not, for + example, `tf.keras.Sequential`). + """ + + num_data: int + """ + The number of points in the training dataset. This information is used to + obtain correct scaling between the data-fit and the KL term in the evidence + lower bound (:meth:`elbo`). + """ + + def __init__( + self, + f_layers: List[tf.keras.layers.Layer], + num_inducing: int, + *, + likelihood_var: Optional[float] = 1.0, + inducing_init: Optional[tf.Tensor] = None, + inducing_shape: Optional[List[int]] = None, + input_dim: Optional[int] = None, + target_dim: Optional[int] = None, + default_model_class: Type[tf.keras.Model] = tf.keras.Model, + num_data: Optional[int] = None, + num_train_samples: Optional[int] = 10, + num_test_samples: Optional[int] = 100, + ): + """ + :param f_layers: The layers ``[f₁, f₂, …, fₙ]`` describing the latent + function ``f(x) = fₙ(⋯ (f₂(f₁(x))))``. + :param likelihood: The layer for the likelihood ``p(y|f)``. If this is a + GPflow likelihood, it will be wrapped in a :class:`~gpflux.layers.LikelihoodLayer`. + Alternatively, you can provide a :class:`~gpflux.layers.LikelihoodLayer` explicitly. + :param input_dim: The input dimensionality. + :param target_dim: The target dimensionality. + :param default_model_class: The default for the *model_class* argument of + :meth:`as_training_model` and :meth:`as_prediction_model`; + see the :attr:`default_model_class` attribute. + :param num_data: The number of points in the training dataset; see the + :attr:`num_data` attribute. + If you do not specify a value for this parameter explicitly, it is automatically + detected from the :attr:`~gpflux.layers.GPLayer.num_data` attribute in the GP layers. + """ + self.inputs = tf.keras.Input((input_dim,), name="inputs") + self.targets = tf.keras.Input((target_dim,), name="targets") + self.f_layers = f_layers + self.likelihood_layer = SampleBasedGaussianLikelihoodLayer(variance=likelihood_var) + self.num_inducing = num_inducing + self.default_model_class = default_model_class + self.num_data = self._validate_num_data(f_layers, num_data) + + assert (inducing_init is not None) != (inducing_shape is not None) + + if inducing_init is None: + self.inducing_data = Parameter(inducing_shape, dtype=default_float()) + else: + self.inducing_data = Parameter(inducing_init, dtype=default_float()) + + assert num_inducing == tf.shape(self.inducing_data)[0] + self.rank = 1 + len(tf.shape(self.inducing_data)) + + self.num_train_samples = num_train_samples + self.num_test_samples = num_test_samples + + @staticmethod + def _validate_num_data( + f_layers: List[tf.keras.layers.Layer], num_data: Optional[int] = None + ) -> int: + """ + Check that the :attr:`~gpflux.layers.gp_layer.GPLayer.num_data` + attributes of all layers in *f_layers* are consistent with each other + and with the (optional) *num_data* argument. + + :returns: The validated number of datapoints. 
+ """ + for i, layer in enumerate(f_layers): + layer_num_data = getattr(layer, "num_data", None) + if num_data is None: + num_data = layer_num_data + else: + if layer_num_data is not None and num_data != layer_num_data: + raise ValueError( + f"f_layers[{i}].num_data is inconsistent with num_data={num_data}" + ) + if num_data is None: + raise ValueError("Could not determine num_data; please provide explicitly") + return num_data + + def _inducing_add(self, inputs: TensorType): + assert self.rank == len(tf.shape(inputs)) + + inducing_data = tf.tile( + tf.expand_dims(self.inducing_data, 0), + [tf.shape(inputs)[0], *tf.shape(self.inducing_data)] + ) + x = tf.concat([inducing_data, inputs], 1) + + return x + + def _inducing_remove(self, inputs: TensorType): + return inputs[:, self.num_inducing:] + + def _evaluate_deep_gp( + self, + inputs: TensorType, + targets: Optional[TensorType], + training: Optional[bool] = None, + ) -> tf.Tensor: + """ + Evaluate ``f(x) = fₙ(⋯ (f₂(f₁(x))))`` on the *inputs* argument. + + Layers that inherit from :class:`~gpflux.layers.LayerWithObservations` + are passed the additional keyword argument ``observations=[inputs, + targets]`` if *targets* contains a value, or ``observations=None`` when + *targets* is `None`. + """ + features = inputs + + # NOTE: we cannot rely on the `training` flag here, as the correct + # symbolic graph needs to be constructed at "build" time (before either + # fit() or predict() get called). + if targets is not None: + observations = [inputs, targets] + num_samples = self.num_train_samples + else: + # TODO would it be better to simply pass [inputs, None] in this case? + observations = None + num_samples = self.num_test_samples + + features = tf.tile(tf.expand_dims(features, 0), [num_samples, *[1]*len(tf.shape(features))]) + features = self._inducing_add(features) + + for layer in self.f_layers: + if isinstance(layer, LayerWithObservations): + raise NotImplementedError("Latent variable layers not yet supported") + else: + features = layer(features, training=training) + return self._inducing_remove(features) + + def _evaluate_likelihood( + self, + f_outputs: TensorType, + targets: Optional[TensorType], + training: Optional[bool] = None, + ) -> tf.Tensor: + """ + Call the `likelihood_layer` on *f_outputs*, which adds the + corresponding layer loss when training. + """ + return self.likelihood_layer(f_outputs, targets=targets, training=training) + + def call( + self, + inputs: TensorType, + targets: Optional[TensorType] = None, + training: Optional[bool] = None, + ) -> tf.Tensor: + f_outputs = self._evaluate_deep_gp(inputs, targets=targets, training=training) + y_outputs = self._evaluate_likelihood(f_outputs, targets=targets, training=training) + return y_outputs + + # def predict_f(self, inputs: TensorType) -> Tuple[tf.Tensor, tf.Tensor]: + # """ + # :returns: The mean and variance (not the scale!) of ``f``, for compatibility with GPflow + # models. + # + # .. note:: This method does **not** support ``full_cov`` or ``full_output_cov``. + # """ + # f_distribution = self._evaluate_deep_gp(inputs, targets=None) + # return f_distribution.loc, f_distribution.scale.diag ** 2 + + def elbo(self, data: Tuple[TensorType, TensorType]) -> tf.Tensor: + """ + :returns: The ELBO (not the per-datapoint loss!), for compatibility with GPflow models. 
+ """ + X, Y = data + _ = self.call(X, Y, training=True) + all_losses = [ + loss + for layer in itertools.chain(self.f_layers, [self.likelihood_layer]) + for loss in layer.losses + ] + return -tf.reduce_sum(all_losses) * self.num_data + + def _get_model_class(self, model_class: Optional[Type[tf.keras.Model]]) -> Type[tf.keras.Model]: + if model_class is not None: + return model_class + else: + return self.default_model_class + + def as_training_model( + self, model_class: Optional[Type[tf.keras.Model]] = None + ) -> tf.keras.Model: + r""" + Construct a `tf.keras.Model` instance that requires you to provide both ``inputs`` + and ``targets`` to its call. This information is required for + training the model, because the ``targets`` need to be passed to the `likelihood_layer` (and + to :class:`~gpflux.layers.LayerWithObservations` instances such as + :class:`~gpflux.layers.LatentVariableLayer`\ s, if present). + + When compiling the returned model, do **not** provide any additional + losses (this is handled by the :attr:`likelihood_layer`). + + Train with + + .. code-block:: python + + model.compile(optimizer) # do NOT pass a loss here + model.fit({"inputs": X, "targets": Y}, ...) + + See `Keras's Endpoint layer pattern + `_ + for more details. + + .. note:: Use `as_prediction_model` if you want only to predict, and do not want to pass in + a dummy array for the targets. + + :param model_class: The model class to use; overrides `default_model_class`. + """ + model_class = self._get_model_class(model_class) + outputs = self.call(self.inputs, self.targets) + return model_class([self.inputs, self.targets], outputs) + + def as_prediction_model( + self, model_class: Optional[Type[tf.keras.Model]] = None + ) -> tf.keras.Model: + """ + Construct a `tf.keras.Model` instance that requires only ``inputs``, + which means you do not have to provide dummy target values when + predicting at test points. + + Predict with + + .. code-block:: python + + model.predict(Xtest, ...) + + .. note:: The returned model will not support training; for that, use `as_training_model`. + + :param model_class: The model class to use; overrides `default_model_class`. + """ + model_class = self._get_model_class(model_class) + outputs = self.call(self.inputs) + return model_class(self.inputs, outputs) + + +def sample_dgp(model: GIDeepGP) -> Sample: # TODO: should this be part of a [Vanilla]DeepGP class? + function_draws = [layer.sample() for layer in model.f_layers] + # TODO: error check that all layers implement .sample()? + + class ChainedSample(Sample): + """ This class chains samples from consecutive layers. 
""" + + def __call__(self, X: TensorType) -> tf.Tensor: + for f in function_draws: + X = f(X) + return X + + return ChainedSample() From 4185a9249ec754576abbd51b6c890e724cf64489 Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Fri, 10 Sep 2021 15:58:05 +0100 Subject: [PATCH 03/19] Bug fixes --- gpflux/layers/gi_gp_layer.py | 23 +++++++++++------------ gpflux/layers/likelihood_layer.py | 2 +- gpflux/models/__init__.py | 1 + gpflux/models/gi_deep_gp.py | 26 +++++++++++++------------- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/gpflux/layers/gi_gp_layer.py b/gpflux/layers/gi_gp_layer.py index 883a1019..6f456604 100644 --- a/gpflux/layers/gi_gp_layer.py +++ b/gpflux/layers/gi_gp_layer.py @@ -22,6 +22,7 @@ import warnings from typing import Any, Dict, List, Optional, Tuple +import gpflow import numpy as np import tensorflow as tf import tensorflow_probability as tfp @@ -53,7 +54,7 @@ def scaled_squared_euclid_dist(self, X, X2=None): Xs = tf.reduce_sum((X**2), -1)[..., :, None] X2s = tf.reduce_sum((X2**2), -1)[..., None, :] - return Xs + X2s - 2*X@tf.linalg.adjoint(X2s) + return Xs + X2s - 2*X@tf.linalg.adjoint(X2) class GIGPLayer(tf.keras.layers.Layer): @@ -141,14 +142,14 @@ def __init__( """ super().__init__( - make_distribution_fn=self._make_distribution_fn, - convert_to_tensor_fn=self._convert_to_tensor_fn, dtype=default_float(), name=name, ) self.kernel = BatchingSquaredExponential(lengthscales=[1.]*input_dim) + self.input_dim = input_dim + self.num_data = num_data if mean_function is None: @@ -197,8 +198,8 @@ def L(self): return self.L_loc * self.L_scale/norm def mvnormal_log_prob(self, sigma_L, X): - in_features = tf.shape(X)[-2] - out_features = tf.shape(X)[-1] + in_features = self.input_dim + out_features = self.num_latent_gps trace_quad = tf.reduce_sum(tf.linalg.triangular_solve(sigma_L, X)**2, [-1, -2]) logdet_term = 2.0*tf.reduce_sum(tf.math.log(tf.linalg.diag_part(sigma_L)), -1) return -0.5*trace_quad - 0.5*out_features*(logdet_term + in_features*math.log(2*math.pi)) @@ -212,19 +213,16 @@ def call( """ Sample-based propagation of both inducing points and function values. 
""" - assert len(tf.shape(inputs)) == 3 - mean_function = self.mean_function(inputs) - Kuu = self.kernel(inputs[..., :self.num_inducing]) - Kuf = self.kernel(inputs[..., :self.num_inducing, self.num_inducing:]) + Kuu = self.kernel(inputs[..., :self.num_inducing, :]) + Kuf = self.kernel(inputs[..., :self.num_inducing, :], inputs[..., self.num_inducing:, :]) Kfu = tf.linalg.adjoint(Kuf) - Kff = self.kernel.K_diag(inputs[..., self.num_inducing:]) + Kff = self.kernel.K_diag(inputs[..., self.num_inducing:, :]) Kuu, Kuf, Kfu = tf.expand_dims(Kuu, 1), tf.expand_dims(Kuf, 1), tf.expand_dims(Kfu, 1) - (S, _, _, _) = tf.shape(Kuu) - S = S.numpy() + S = tf.shape(Kuu)[0] Iuu = tf.eye(self.num_inducing, dtype=default_float()) @@ -246,6 +244,7 @@ def call( dtype=default_float() ) + Kuu = Kuu + gpflow.default_jitter()*tf.eye(self.num_inducing, dtype=default_float()) chol_Kuu = tf.linalg.cholesky(Kuu) chol_Kuu_T = tf.linalg.adjoint(chol_Kuu) diff --git a/gpflux/layers/likelihood_layer.py b/gpflux/layers/likelihood_layer.py index e57c3408..ef6ef69d 100644 --- a/gpflux/layers/likelihood_layer.py +++ b/gpflux/layers/likelihood_layer.py @@ -49,7 +49,7 @@ def call( targets: Optional[TensorType] = None, training: bool = None, ) -> tfp.distributions.Normal: - likelihood_dist = tfp.distributions.Normal(inputs, self.variance, dtype=default_float()) + likelihood_dist = tfp.distributions.Normal(inputs, tf.sqrt(self.variance)) if training: assert targets is not None diff --git a/gpflux/models/__init__.py b/gpflux/models/__init__.py index 479c4c57..176b1819 100644 --- a/gpflux/models/__init__.py +++ b/gpflux/models/__init__.py @@ -17,3 +17,4 @@ Base model classes implemented in GPflux """ from gpflux.models.deep_gp import DeepGP +from gpflux.models.gi_deep_gp import GIDeepGP diff --git a/gpflux/models/gi_deep_gp.py b/gpflux/models/gi_deep_gp.py index 8113cb3f..5928addd 100644 --- a/gpflux/models/gi_deep_gp.py +++ b/gpflux/models/gi_deep_gp.py @@ -19,6 +19,7 @@ from typing import List, Optional, Tuple, Type, Union import tensorflow as tf +import tensorflow_probability as tfp import gpflow from gpflow import Parameter, default_float @@ -129,11 +130,10 @@ def _validate_num_data( return num_data def _inducing_add(self, inputs: TensorType): - assert self.rank == len(tf.shape(inputs)) inducing_data = tf.tile( tf.expand_dims(self.inducing_data, 0), - [tf.shape(inputs)[0], *tf.shape(self.inducing_data)] + [tf.shape(inputs)[0], *[1]*(self.rank-1)] ) x = tf.concat([inducing_data, inputs], 1) @@ -169,7 +169,8 @@ def _evaluate_deep_gp( observations = None num_samples = self.num_test_samples - features = tf.tile(tf.expand_dims(features, 0), [num_samples, *[1]*len(tf.shape(features))]) + features = tf.tile(tf.expand_dims(features, 0), + [num_samples, *[1]*(self.rank-1)]) features = self._inducing_add(features) for layer in self.f_layers: @@ -196,20 +197,19 @@ def call( inputs: TensorType, targets: Optional[TensorType] = None, training: Optional[bool] = None, - ) -> tf.Tensor: + ) -> tfp.distributions.Distribution: f_outputs = self._evaluate_deep_gp(inputs, targets=targets, training=training) y_outputs = self._evaluate_likelihood(f_outputs, targets=targets, training=training) return y_outputs - # def predict_f(self, inputs: TensorType) -> Tuple[tf.Tensor, tf.Tensor]: - # """ - # :returns: The mean and variance (not the scale!) of ``f``, for compatibility with GPflow - # models. - # - # .. note:: This method does **not** support ``full_cov`` or ``full_output_cov``. 
- # """ - # f_distribution = self._evaluate_deep_gp(inputs, targets=None) - # return f_distribution.loc, f_distribution.scale.diag ** 2 + def predict_f(self, inputs: TensorType) -> Tuple[tf.Tensor, tf.Tensor]: + """ + :returns: The mean and variance (not the scale!) of ``f``, for compatibility with GPflow + models. + + .. note:: This method does **not** support ``full_cov`` or ``full_output_cov``. + """ + raise NotImplementedError("TODO") def elbo(self, data: Tuple[TensorType, TensorType]) -> tf.Tensor: """ From e747c592e17578f3882a96e5ba343fe555bc68a6 Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Fri, 10 Sep 2021 16:59:52 +0100 Subject: [PATCH 04/19] Added consistent sampling --- gpflux/layers/gi_gp_layer.py | 56 ++++++++++++++++++++++++++---------- gpflux/models/gi_deep_gp.py | 13 +++++++++ 2 files changed, 54 insertions(+), 15 deletions(-) diff --git a/gpflux/layers/gi_gp_layer.py b/gpflux/layers/gi_gp_layer.py index 6f456604..d155ee42 100644 --- a/gpflux/layers/gi_gp_layer.py +++ b/gpflux/layers/gi_gp_layer.py @@ -28,7 +28,7 @@ import tensorflow_probability as tfp import math -from gpflow import Parameter, default_float +from gpflow import Parameter, default_float, default_jitter from gpflow.base import TensorType from gpflow.kernels import SquaredExponential from gpflow.mean_functions import Identity, MeanFunction @@ -244,7 +244,7 @@ def call( dtype=default_float() ) - Kuu = Kuu + gpflow.default_jitter()*tf.eye(self.num_inducing, dtype=default_float()) + Kuu = Kuu + default_jitter()*tf.eye(self.num_inducing, dtype=default_float()) chol_Kuu = tf.linalg.cholesky(Kuu) chol_Kuu_T = tf.linalg.adjoint(chol_Kuu) @@ -265,17 +265,10 @@ def call( name = f"{self.name}_prior_kl" if self.name else "prior_kl" self.add_metric(loss_per_datapoint, name=name, aggregation="mean") - #### f|u - possibility of using a conditional here? - Kfu_invKuu = tf.linalg.adjoint(tf.linalg.cholesky_solve(chol_Kuu, Kuf)) - Ef = tf.linalg.adjoint(tf.squeeze((Kfu_invKuu @ u), -1)) - Vf = Kff - tf.squeeze(tf.reduce_sum((Kfu_invKuu*Kfu), -1), 1) - - eps_f = tf.random.normal( - tf.shape(Ef), - dtype=default_float() - ) - - f_samples = Ef + tf.math.sqrt(Vf)[..., None]*eps_f + if kwargs.get("full_cov"): + f_samples = self.sample_conditional(u, Kff, Kuf, chol_Kuu, inputs=inputs, full_cov=True) + else: + f_samples = self.sample_conditional(u, Kff, Kuf, chol_Kuu) all_samples = tf.concat( [ @@ -287,6 +280,38 @@ def call( return all_samples + mean_function + def sample_conditional( + self, + u: TensorType, + Kff: TensorType, + Kuf: TensorType, + chol_Kuu: TensorType, + inputs: Optional[TensorType] = None, + full_cov: bool = False, + ) -> tf.Tensor: + Kfu_invKuu = tf.linalg.adjoint(tf.linalg.cholesky_solve(chol_Kuu, Kuf)) + Ef = tf.linalg.adjoint(tf.squeeze((Kfu_invKuu @ u), -1)) + + eps_f = tf.random.normal( + tf.shape(Ef), + dtype=default_float() + ) + + if full_cov: + assert inputs is not None + Kff = self.kernel(inputs[..., self.num_inducing:, :]) + Vf = Kff - tf.squeeze(Kfu_invKuu @ Kuf, 1) + Vf = Vf + default_jitter()*tf.eye(tf.shape(Vf)[-1], dtype=default_float()) + chol_Vf = tf.linalg.cholesky(Vf) + + var_part = chol_Vf @ eps_f + else: + Vf = Kff - tf.squeeze(tf.reduce_sum((Kfu_invKuu*tf.linalg.adjoint(Kuf)), -1), 1) + + var_part = tf.math.sqrt(Vf)[..., None]*eps_f + + return Ef + var_part + def prior_kl( self, LT: tf.Tensor, @@ -307,8 +332,9 @@ def prior_kl( return -tf.reduce_mean(logpq) - def sample(self) -> Sample: + def sample(self, inputs: TensorType) -> tf.Tensor: """ .. todo:: TODO: Document this. 
""" - raise NotImplementedError("TODO") + + return self.call(inputs, kwargs={"training": None, "full_cov": True}) diff --git a/gpflux/models/gi_deep_gp.py b/gpflux/models/gi_deep_gp.py index 5928addd..2952ab5b 100644 --- a/gpflux/models/gi_deep_gp.py +++ b/gpflux/models/gi_deep_gp.py @@ -285,6 +285,19 @@ def as_prediction_model( outputs = self.call(self.inputs) return model_class(self.inputs, outputs) + def sample(self, inputs: TensorType, num_samples: int) -> tf.Tensor: + features = tf.tile(tf.expand_dims(inputs, 0), + [num_samples, *[1]*(self.rank-1)]) + features = self._inducing_add(features) + + for layer in self.f_layers: + if isinstance(layer, LayerWithObservations): + raise NotImplementedError("Latent variable layers not yet supported") + else: + features = layer(features, training=None, full_cov=True) + + return self._inducing_remove(features) + def sample_dgp(model: GIDeepGP) -> Sample: # TODO: should this be part of a [Vanilla]DeepGP class? function_draws = [layer.sample() for layer in model.f_layers] From f359bb97a26af762b6e55f618e4ba107df7eb2b4 Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Mon, 13 Sep 2021 14:45:13 +0100 Subject: [PATCH 05/19] Added bool for consistent vs non-consistent sampling --- gpflux/models/gi_deep_gp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gpflux/models/gi_deep_gp.py b/gpflux/models/gi_deep_gp.py index 2952ab5b..3cc7055c 100644 --- a/gpflux/models/gi_deep_gp.py +++ b/gpflux/models/gi_deep_gp.py @@ -285,7 +285,7 @@ def as_prediction_model( outputs = self.call(self.inputs) return model_class(self.inputs, outputs) - def sample(self, inputs: TensorType, num_samples: int) -> tf.Tensor: + def sample(self, inputs: TensorType, num_samples: int, consistent: bool = False) -> tf.Tensor: features = tf.tile(tf.expand_dims(inputs, 0), [num_samples, *[1]*(self.rank-1)]) features = self._inducing_add(features) @@ -294,7 +294,7 @@ def sample(self, inputs: TensorType, num_samples: int) -> tf.Tensor: if isinstance(layer, LayerWithObservations): raise NotImplementedError("Latent variable layers not yet supported") else: - features = layer(features, training=None, full_cov=True) + features = layer(features, training=None, full_cov=consistent) return self._inducing_remove(features) From 75f13b2b705e5494c4706391c4886fc015158475 Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Tue, 14 Sep 2021 14:06:24 +0100 Subject: [PATCH 06/19] Added kernel variance init --- gpflux/layers/gi_gp_layer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/gpflux/layers/gi_gp_layer.py b/gpflux/layers/gi_gp_layer.py index d155ee42..459f905c 100644 --- a/gpflux/layers/gi_gp_layer.py +++ b/gpflux/layers/gi_gp_layer.py @@ -94,6 +94,7 @@ def __init__( inducing_targets: Optional[tf.Tensor] = None, prec_init: Optional[float] = 1., mean_function: Optional[MeanFunction] = None, + kernel_variance_init: Optional[float] = 1., *, name: Optional[str] = None, verbose: bool = True, @@ -146,7 +147,8 @@ def __init__( name=name, ) - self.kernel = BatchingSquaredExponential(lengthscales=[1.]*input_dim) + self.kernel = BatchingSquaredExponential( + lengthscales=[1.]*input_dim, variance=kernel_variance_init) self.input_dim = input_dim @@ -186,7 +188,7 @@ def __init__( ) # [num_latent_gps, num_inducing, num_inducing] self.L_scale = Parameter( - np.sqrt(prec_init)*np.ones((self.num_latent_gps, 1, 1)), + tf.sqrt(self.kernel.variance)*np.sqrt(prec_init)*np.ones((self.num_latent_gps, 1, 1)), transform=positive(), dtype=default_float(), 
name=f"{self.name}_L_scale" if self.name else "L_scale" From a119826c9f3d76f68b20269c343059d10398c7df Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Tue, 14 Sep 2021 14:07:00 +0100 Subject: [PATCH 07/19] Fixed bug in init for inducing points --- gpflux/models/gi_deep_gp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpflux/models/gi_deep_gp.py b/gpflux/models/gi_deep_gp.py index 3cc7055c..a7fe1a0a 100644 --- a/gpflux/models/gi_deep_gp.py +++ b/gpflux/models/gi_deep_gp.py @@ -95,7 +95,7 @@ def __init__( assert (inducing_init is not None) != (inducing_shape is not None) if inducing_init is None: - self.inducing_data = Parameter(inducing_shape, dtype=default_float()) + self.inducing_data = Parameter(tf.random.normal(inducing_shape), dtype=default_float()) else: self.inducing_data = Parameter(inducing_init, dtype=default_float()) From d2788e8fbfe595639e35f17349212be6b37dd711 Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Tue, 14 Sep 2021 15:55:28 +0100 Subject: [PATCH 08/19] Documentation for GI GP layer --- gpflux/layers/gi_gp_layer.py | 114 +++++++++++++++++++++-------------- 1 file changed, 69 insertions(+), 45 deletions(-) diff --git a/gpflux/layers/gi_gp_layer.py b/gpflux/layers/gi_gp_layer.py index 459f905c..a6093740 100644 --- a/gpflux/layers/gi_gp_layer.py +++ b/gpflux/layers/gi_gp_layer.py @@ -15,14 +15,13 @@ # """ This module provides :class:`GIGPLayer`, which implements a 'global inducing' point posterior for -a GP layer. Currently restricted to single-output kernels, inducing points, etc... See Ober and -Aitchison (2021) for details. +a GP layer. Currently restricted to squared exponential kernel, inducing points, etc... See Ober and +Aitchison (2021): https://arxiv.org/abs/2005.08140 for details. """ import warnings -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional -import gpflow import numpy as np import tensorflow as tf import tensorflow_probability as tfp @@ -34,8 +33,6 @@ from gpflow.mean_functions import Identity, MeanFunction from gpflow.utilities.bijectors import triangular, positive -from gpflux.sampling.sample import Sample - class BatchingSquaredExponential(SquaredExponential): """Implementation of squared exponential kernel that batches in the following way: given X with @@ -77,20 +74,20 @@ class GIGPLayer(tf.keras.layers.Layer): L_loc: Parameter r""" - The lower-triangular Cholesky factor of the precision of ``q(u|v)``. + The lower-triangular Cholesky factor of the precision of ``q(v|u)``. """ L_scale: Parameter r""" - Scale parameter for L + Scale parameter for L. """ def __init__( self, + input_dim: int, num_latent_gps: int, num_data: int, num_inducing: int, - input_dim: int, inducing_targets: Optional[tf.Tensor] = None, prec_init: Optional[float] = 1., mean_function: Optional[MeanFunction] = None, @@ -100,9 +97,16 @@ def __init__( verbose: bool = True, ): """ - :param kernel: The multioutput kernel for this layer. - :param inducing_variable: The inducing features for this layer. + :param input_dim: The dimension of the input for this layer. + :param num_latent_gps: The number of latent GPs in this layer (i.e. the output dimension). + Unlike for the :class:`GPLayer`, this must be provided for global inducing layers. :param num_data: The number of points in the training dataset (see :attr:`num_data`). + :param num_inducing: The number of inducing points; for global inducing this should be the + same for the whole model. + :param inducing_targets: An optional initialization for `v`. 
The most useful case for this + is the last layer, where it may be initialized to (a subset of) the output data. + :param prec_init: Initialization for the precision parameter. See Ober and Aitchison (2021) + for more details. :param mean_function: The mean function that will be applied to the inputs. Default: :class:`~gpflow.mean_functions.Identity`. @@ -111,32 +115,7 @@ def __init__( change the dimensionality in a layer, you may want to provide a :class:`~gpflow.mean_functions.Linear` mean function instead. - :param num_samples: The number of samples to draw when converting the - :class:`~tfp.layers.DistributionLambda` into a `tf.Tensor`, see - :meth:`_convert_to_tensor_fn`. Will be stored in the - :attr:`num_samples` attribute. If `None` (the default), draw a - single sample without prefixing the sample shape (see - :class:`tfp.distributions.Distribution`'s `sample() - `_ - method). - :param full_cov: Sets default behaviour of calling this layer - (:attr:`full_cov` attribute): - If `False` (the default), only predict marginals (diagonal - of covariance) with respect to inputs. - If `True`, predict full covariance over inputs. - :param full_output_cov: Sets default behaviour of calling this layer - (:attr:`full_output_cov` attribute): - If `False` (the default), only predict marginals (diagonal - of covariance) with respect to outputs. - If `True`, predict full covariance over outputs. - :param num_latent_gps: The number of (latent) GPs in the layer - (which can be different from the number of outputs, e.g. with a - :class:`~gpflow.kernels.LinearCoregionalization` kernel). - This is used to determine the size of the - variational parameters :attr:`q_mu` and :attr:`q_sqrt`. - If possible, it is inferred from the *kernel* and *inducing_variable*. - :param whiten: If `True` (the default), uses the whitened parameterisation - of the inducing variables; see :attr:`whiten`. + :param kernel_variance_init: Initialization for the kernel variance :param name: The name of this layer. :param verbose: The verbosity mode. Set this parameter to `True` to show debug information. @@ -147,6 +126,11 @@ def __init__( name=name, ) + if kernel_variance_init <= 0: + raise ValueError("Kernel variance must be positive.") + if prec_init <= 0: + raise ValueError("Precision init must be positive") + self.kernel = BatchingSquaredExponential( lengthscales=[1.]*input_dim, variance=kernel_variance_init) @@ -173,6 +157,10 @@ def __init__( if inducing_targets is None: inducing_targets = np.zeros((self.num_latent_gps, num_inducing, 1)) + elif tf.rank(inducing_targets) == 2: + inducing_targets = tf.expand_dims(tf.linalg.adjoint(inducing_targets), -1) + if inducing_targets.shape != (self.num_latent_gps, num_inducing, 1): + raise ValueError("Incorrect shape was provided for the inducing targets.") self.v = Parameter( inducing_targets, @@ -195,11 +183,24 @@ def __init__( ) @property - def L(self): + def L(self) -> tf.Tensor: + """ + :return: the Cholesky of the precision hyperparameter. We parameterize L using L_loc and + L_scale to achieve greater stability during optimization. + """ norm = tf.reshape(tf.reduce_mean(tf.linalg.diag_part(self.L_loc), axis=-1), [-1, 1, 1]) return self.L_loc * self.L_scale/norm - def mvnormal_log_prob(self, sigma_L, X): + def mvnormal_log_prob(self, sigma_L: TensorType, X: TensorType) -> tf.Tensor: + """ + Calculates the log probability of a zero-mean multivariate Gaussian with covariance sigma + and evaluation points X, with batching of both the covariance and X. 
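+        The quadratic term is evaluated with a triangular solve against ``sigma_L``, and
+        the log-determinant is taken from the diagonal of ``sigma_L``.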
+ + TODO: look into whether this can be replaced with a tfp.distributions.Distribution + :param sigma_L: Cholesky of covariance sigma, shape [..., 1, D, D] + :param X: evaluation point for log_prob, shape [..., M, D, 1] + :return: the log probability, shape [..., M] + """ in_features = self.input_dim out_features = self.num_latent_gps trace_quad = tf.reduce_sum(tf.linalg.triangular_solve(sigma_L, X)**2, [-1, -2]) @@ -213,7 +214,8 @@ def call( **kwargs: Dict[str, Any] ) -> tf.Tensor: """ - Sample-based propagation of both inducing points and function values. + Sample-based propagation of both inducing points and function values. See Ober & Aitchison + (2021) for details. """ mean_function = self.mean_function(inputs) @@ -291,6 +293,19 @@ def sample_conditional( inputs: Optional[TensorType] = None, full_cov: bool = False, ) -> tf.Tensor: + """ + Samples function values f based off samples of u. + + :param u: Samples of the inducing points, shape [S, Lout, M, 1] + :param Kff: The diag of the kernel evaluated at input function values, shape [S, N] + :param Kuf: The kernel evaluated between inducing locations and input function values, shape + [S, 1, M, N] + :param chol_Kuu: Cholesky factor of kernel evaluated for inducing points, shape [S, 1, M, M] + :param inputs: Input data points, required for full_cov = True, shape [S, N, Lin] + :param full_cov: Whether to use the full covariance predictive, which gives consistent + samples if true + :return: samples of f, shape [S, M, Lout] + """ Kfu_invKuu = tf.linalg.adjoint(tf.linalg.cholesky_solve(chol_Kuu, Kuf)) Ef = tf.linalg.adjoint(tf.squeeze((Kfu_invKuu @ u), -1)) @@ -320,10 +335,14 @@ def prior_kl( chol_lKlpI: tf.Tensor, u: tf.Tensor, ) -> tf.Tensor: - r""" - Returns the KL divergence ``KL[q(u)∥p(u)]`` from the prior ``p(u)`` to - the variational distribution ``q(u)``. If this layer uses the - :attr:`whiten`\ ed representation, returns ``KL[q(v)∥p(v)]``. + """ + Returns sample-based estimates of the KL divergence between the approximate posterior and + the prior, KL(q(u)||p(u)). + + :param LT: transpose of L, shape [Lout, M, M] + :param chol_lKlpI: Cholesky of LT @ Kuu @ L + I, shape [S, Lout, M, M] + :param u: Samples of the inducing points, shape [S, Lout, M, 1] + :return: Samples-based estimate of the KL, shape [] """ lv = LT @ self.v @@ -336,7 +355,12 @@ def prior_kl( def sample(self, inputs: TensorType) -> tf.Tensor: """ - .. todo:: TODO: Document this. + Sample consistent functions from the layer. + + TODO: Note that this not follow the behavior of the :class:`Sample` in the rest of GPflux. + + :param inputs: [..., Lin] + :return: consistent samples, shape [..., Lout] """ return self.call(inputs, kwargs={"training": None, "full_cov": True}) From fa064f13b8a308455fc22b247d2dfe28fd5b3d5a Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Tue, 14 Sep 2021 16:12:21 +0100 Subject: [PATCH 09/19] Improved documentation of how KL is computed. --- gpflux/layers/gi_gp_layer.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/gpflux/layers/gi_gp_layer.py b/gpflux/layers/gi_gp_layer.py index a6093740..d45c7a1a 100644 --- a/gpflux/layers/gi_gp_layer.py +++ b/gpflux/layers/gi_gp_layer.py @@ -69,7 +69,11 @@ class GIGPLayer(tf.keras.layers.Layer): v: Parameter r""" - The pseudo-targets. + The pseudo-targets. Note that this does not have the same meaning as in much of the GP + literature, where it represents a whitened version of the inducing variables. 
While we do use + whitened representations to compute the KL, we maintain the use of `u` throughout for the + inducing variables, leaving `v` for the pseudo-targets, which follows the notation of Ober & + Aitchison (2021). """ L_loc: Parameter @@ -337,7 +341,21 @@ def prior_kl( ) -> tf.Tensor: """ Returns sample-based estimates of the KL divergence between the approximate posterior and - the prior, KL(q(u)||p(u)). + the prior, KL(q(u)||p(u)). Note that we use a whitened representation to compute the KL: + + P = L LT + u = N(0, K) + v | u = N(u, P^{-1}) + u | v = N(S P u, S) - this is the approximate posterior form for u, where + S = (K^{-1} + P)^{-1} = K - K L (LT K L + I)^{-1} LT K + + To compute the KL: + lu = LT u + lv = LT v + lv | u = N(lu, I) + lv = N(0, LT K L + I) + + P(u)/P(u|lv) = P(lv)/P(lv|u) :param LT: transpose of L, shape [Lout, M, M] :param chol_lKlpI: Cholesky of LT @ Kuu @ L + I, shape [S, Lout, M, M] From 0079d10ba0f5129e2ca54ad298add5bd95034377 Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Wed, 15 Sep 2021 14:40:36 +0100 Subject: [PATCH 10/19] Fixed likelihood layer to allow Keras compatibility --- gpflux/layers/likelihood_layer.py | 61 +++++++++++++++++++++++++++---- 1 file changed, 54 insertions(+), 7 deletions(-) diff --git a/gpflux/layers/likelihood_layer.py b/gpflux/layers/likelihood_layer.py index ef6ef69d..0c6668c2 100644 --- a/gpflux/layers/likelihood_layer.py +++ b/gpflux/layers/likelihood_layer.py @@ -17,7 +17,7 @@ A Keras Layer that wraps a likelihood, while containing the necessary operations for training. """ -from typing import Optional +from typing import Optional, List, Dict, Any import tensorflow as tf import tensorflow_probability as tfp @@ -31,9 +31,20 @@ from gpflux.layers.trackable_layer import TrackableLayer -class SampleBasedGaussianLikelihoodLayer(TrackableLayer): +class SampleBasedGaussianLikelihoodLayer(tfp.layers.DistributionLambda): + """ + A `DistributionLambda` layer that provides support for sample-based Gaussian likelihoods, + for instance with the global inducing posterior. It creates a :class:`tfp.distributions.Normal` + centered at the sample-based predictions with scale equal to the (learnable) noise standard + deviation. For :meth:`convert_to_tensor_fn`, it simply returns the location parameter of the + normal distribution, which in this case is the output samples. + """ def __init__(self, variance: float = 1., variance_lower_bound=1e-8): - super().__init__(dtype=default_float()) + super().__init__( + make_distribution_fn=self._make_distribution_fn, + convert_to_tensor_fn=self._convert_to_tensor_fn, + dtype=default_float() + ) if variance <= variance_lower_bound: raise ValueError( @@ -47,11 +58,25 @@ def call( self, inputs: TensorType, targets: Optional[TensorType] = None, - training: bool = None, + *args: List[Any], + **kwargs: Dict[str, Any], ) -> tfp.distributions.Normal: - likelihood_dist = tfp.distributions.Normal(inputs, tf.sqrt(self.variance)) + """ + The default behaviour upon calling this layer. - if training: + This method calls the `tfp.layers.DistributionLambda` super-class `call` method, which + constructs a `tfp.distributions.Distribution` for the output distributions at the input + points (see :meth:`_make_distribution_fn`). + You can pass this distribution to `tf.convert_to_tensor`, which will return the location of + the distribution (see :meth:`_convert_to_tensor_fn`). + + This method also adds a layer-specific loss function, giving the expected log likelihood + of the model. 
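+        During training, the loss added is the negative log probability of *targets*
+        under this distribution, averaged over samples, datapoints and output
+        dimensions.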
+ """ + outputs = super().call(inputs, *args, **kwargs) + likelihood_dist = outputs[0] + + if kwargs.get("training"): assert targets is not None loss_per_datapoint = tf.reduce_mean( -likelihood_dist.log_prob(tf.expand_dims(targets, 0)) @@ -61,7 +86,29 @@ def call( self.add_loss(loss_per_datapoint) - return likelihood_dist + return outputs + + def _make_distribution_fn(self, inputs: TensorType) -> tfp.distributions.Distribution: + """ + Construct a :class:`tfp.distributions.Normal` instance with mean at the inputs to the layer + (which will correspond to samples from the function posterior) and scale equal to the + noise standard deviation. + + :param inputs: The inputs to the layer, which should be samples from the predictive + posterior of the latent function values. + """ + return tfp.distributions.Normal(inputs, tf.sqrt(self.variance)) + + @staticmethod + def _convert_to_tensor_fn(distribution: tfp.distributions.Distribution) -> tf.Tensor: + """ + Convert the sample-based predictive posterior distribution to samples of the latent + function, which is simply the mean of the distribution. + """ + if not isinstance(distribution, tfp.distributions.Normal): + raise ValueError("Distribution must be an instance of `tfp.distributions.Normal`") + + return distribution.loc class LikelihoodLayer(TrackableLayer): From aded07ae2b3b6f0ab058bcde65c5c9ea640b8296 Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Wed, 15 Sep 2021 14:41:28 +0100 Subject: [PATCH 11/19] Some cleaning --- gpflux/layers/gi_gp_layer.py | 2 +- gpflux/models/gi_deep_gp.py | 31 ++++++++++--------------------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/gpflux/layers/gi_gp_layer.py b/gpflux/layers/gi_gp_layer.py index d45c7a1a..cafbf7cd 100644 --- a/gpflux/layers/gi_gp_layer.py +++ b/gpflux/layers/gi_gp_layer.py @@ -92,11 +92,11 @@ def __init__( num_latent_gps: int, num_data: int, num_inducing: int, + *, inducing_targets: Optional[tf.Tensor] = None, prec_init: Optional[float] = 1., mean_function: Optional[MeanFunction] = None, kernel_variance_init: Optional[float] = 1., - *, name: Optional[str] = None, verbose: bool = True, ): diff --git a/gpflux/models/gi_deep_gp.py b/gpflux/models/gi_deep_gp.py index a7fe1a0a..5095fa94 100644 --- a/gpflux/models/gi_deep_gp.py +++ b/gpflux/models/gi_deep_gp.py @@ -13,21 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # -""" This module provides an implementation of global inducing points for deep GPs. """ +""" +This module provides an implementation of global inducing points for deep GPs. See Ober & +Aitchison (2021) https://arxiv.org/abs/2005.08140 + +Note: this is a sample-based implementation, so we return samples from the DGP instead of predictive +means and variances. +""" import itertools -from typing import List, Optional, Tuple, Type, Union +from typing import List, Optional, Tuple, Type import tensorflow as tf import tensorflow_probability as tfp -import gpflow from gpflow import Parameter, default_float from gpflow.base import Module, TensorType import gpflux -from gpflux.layers import LayerWithObservations, LikelihoodLayer, SampleBasedGaussianLikelihoodLayer -from gpflux.sampling.sample import Sample +from gpflux.layers import LayerWithObservations, SampleBasedGaussianLikelihoodLayer class GIDeepGP(Module): @@ -35,7 +39,7 @@ class GIDeepGP(Module): f_layers: List[tf.keras.layers.Layer] """ A list of all layers in this DeepGP (just :attr:`likelihood_layer` is separate). 
""" - likelihood_layer: gpflux.layers.LikelihoodLayer + likelihood_layer: gpflux.layers.SampleBasedGaussianLikelihoodLayer """ The likelihood layer. """ """ @@ -297,18 +301,3 @@ def sample(self, inputs: TensorType, num_samples: int, consistent: bool = False) features = layer(features, training=None, full_cov=consistent) return self._inducing_remove(features) - - -def sample_dgp(model: GIDeepGP) -> Sample: # TODO: should this be part of a [Vanilla]DeepGP class? - function_draws = [layer.sample() for layer in model.f_layers] - # TODO: error check that all layers implement .sample()? - - class ChainedSample(Sample): - """ This class chains samples from consecutive layers. """ - - def __call__(self, X: TensorType) -> tf.Tensor: - for f in function_draws: - X = f(X) - return X - - return ChainedSample() From df6b33bf0592b470daa95f02685f0a8642a3bba6 Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Wed, 15 Sep 2021 15:51:21 +0100 Subject: [PATCH 12/19] More cleaning, documentation --- gpflux/layers/gi_gp_layer.py | 12 --- gpflux/models/gi_deep_gp.py | 143 +++++++++++++++++++++++++++++++---- 2 files changed, 129 insertions(+), 26 deletions(-) diff --git a/gpflux/layers/gi_gp_layer.py b/gpflux/layers/gi_gp_layer.py index cafbf7cd..1342e545 100644 --- a/gpflux/layers/gi_gp_layer.py +++ b/gpflux/layers/gi_gp_layer.py @@ -370,15 +370,3 @@ def prior_kl( logpq = logP - logQ return -tf.reduce_mean(logpq) - - def sample(self, inputs: TensorType) -> tf.Tensor: - """ - Sample consistent functions from the layer. - - TODO: Note that this not follow the behavior of the :class:`Sample` in the rest of GPflux. - - :param inputs: [..., Lin] - :return: consistent samples, shape [..., Lout] - """ - - return self.call(inputs, kwargs={"training": None, "full_cov": True}) diff --git a/gpflux/models/gi_deep_gp.py b/gpflux/models/gi_deep_gp.py index 5095fa94..aedc1b4a 100644 --- a/gpflux/models/gi_deep_gp.py +++ b/gpflux/models/gi_deep_gp.py @@ -35,13 +35,30 @@ class GIDeepGP(Module): + """ + This class combines a sequential function model ``f(x) = fₙ(⋯ (f₂(f₁(x))))`` + and a likelihood ``p(y|f)``. The model uses the inference method described in Ober & Aitchison + (2021). + + Layers currently do NOT support inheriting from :class:`~gpflux.layers.LayerWithObservations`. + + .. note:: This class is **not** a `tf.keras.Model` subclass itself. To access + Keras features, call either :meth:`as_training_model` or :meth:`as_prediction_model` + (depending on the use-case) to create a `tf.keras.Model` instance. See the method docstrings + for more details. + """ + + inputs: tf.keras.Input + targets: tf.keras.Input f_layers: List[tf.keras.layers.Layer] """ A list of all layers in this DeepGP (just :attr:`likelihood_layer` is separate). """ likelihood_layer: gpflux.layers.SampleBasedGaussianLikelihoodLayer - """ The likelihood layer. """ + """ The likelihood layer. Currently restricted to + :class:`gpflux.layers.SampleBasedGaussianLikelihoodLayer`.""" + default_model_class: Type[tf.keras.Model] """ The default for the *model_class* argument of :meth:`as_training_model` and :meth:`as_prediction_model`. This must have the same semantics as `tf.keras.Model`, @@ -57,6 +74,28 @@ class GIDeepGP(Module): lower bound (:meth:`elbo`). """ + num_train_samples: int + """ + The number of samples from the posterior used for training. Usually lower than the number used + for testing. Cannot be changed during training once the model has been compiled. 
+ """ + + num_test_samples: int + """ + The number of samples from the posterior used for testing. Usually greater than the number used + for training. + """ + + num_inducing: int + """ + The number of inducing points to be used. This is the same throughout the model. + """ + + inducing_data: Parameter + """ + The inducing inputs for the model, propagated through the layers. + """ + def __init__( self, f_layers: List[tf.keras.layers.Layer], @@ -75,9 +114,18 @@ def __init__( """ :param f_layers: The layers ``[f₁, f₂, …, fₙ]`` describing the latent function ``f(x) = fₙ(⋯ (f₂(f₁(x))))``. - :param likelihood: The layer for the likelihood ``p(y|f)``. If this is a - GPflow likelihood, it will be wrapped in a :class:`~gpflux.layers.LikelihoodLayer`. - Alternatively, you can provide a :class:`~gpflux.layers.LikelihoodLayer` explicitly. + :param num_inducing: The number of inducing points used in the approximate posterior. This + must be the same for all layers of the model. If you do not specify a value for this + parameter explicitly, it is automatically detected from the + :attr:`~gpflux.layers.GIGPLayer.num_inducing` attribute in the GP layers. + :param likelihood_var: The variance for the Gaussian likelihood ``p(y|f)``. This will be + passed to a :class:`gpflux.layers.SampleBasedGaussianLikelihoodLayer` instance, which + will be the likelihood. + :param inducing_init: An initialization for the inducing inputs. Must have shape + [num_inducing, input_dim], and cannot be provided along with `inducing_shape`. + :param inducing_shape: The initialization shape of the inducing inputs. Used to initialize + the inputs from N(0, 1) if `inducing_init` is not provided. Must be [num_inducing, + input_dim]. Cannot be provided along with `inducing_init`. :param input_dim: The input dimensionality. :param target_dim: The target dimensionality. :param default_model_class: The default for the *model_class* argument of @@ -86,29 +134,68 @@ def __init__( :param num_data: The number of points in the training dataset; see the :attr:`num_data` attribute. If you do not specify a value for this parameter explicitly, it is automatically - detected from the :attr:`~gpflux.layers.GPLayer.num_data` attribute in the GP layers. + detected from the :attr:`~gpflux.layers.GIGPLayer.num_data` attribute in the GP layers. + :param num_train_samples: The number of samples from the posterior used for training. + :param num_test_samples: The number of samples from the posterior used for testing. 
""" self.inputs = tf.keras.Input((input_dim,), name="inputs") self.targets = tf.keras.Input((target_dim,), name="targets") self.f_layers = f_layers self.likelihood_layer = SampleBasedGaussianLikelihoodLayer(variance=likelihood_var) - self.num_inducing = num_inducing + self.num_inducing = self._validate_num_inducing(f_layers, num_inducing) self.default_model_class = default_model_class self.num_data = self._validate_num_data(f_layers, num_data) - assert (inducing_init is not None) != (inducing_shape is not None) + if (inducing_init is not None) != (inducing_shape is not None): + raise ValueError(f"One of `inducing_init` or `inducing_shape` must be exclusively" + f"provided.") if inducing_init is None: self.inducing_data = Parameter(tf.random.normal(inducing_shape), dtype=default_float()) else: self.inducing_data = Parameter(inducing_init, dtype=default_float()) - assert num_inducing == tf.shape(self.inducing_data)[0] + if num_inducing != tf.shape(self.inducing_data)[0]: + raise ValueError(f"The number of inducing inputs {self.inducing_data.shape[0]} must " + f"equal num_inducing {num_inducing}.") + if input_dim is not None and input_dim != tf.shape(self.inducing_data)[-1]: + raise ValueError(f"The dimension of the inducing inputs {self.inducing_data.shape[-1]}" + f"must equal input_dim {input_dim}") + self.rank = 1 + len(tf.shape(self.inducing_data)) + if self.rank != 3: + raise ValueError(f"Currently the model only supports data of rank 2 ([N, D]); received" + f"rank {len(self.inducing_data.shape)} instead.") + self.num_train_samples = num_train_samples self.num_test_samples = num_test_samples + @staticmethod + def _validate_num_inducing( + f_layers: List[tf.keras.layers.Layer], num_inducing: Optional[int] = None + ) -> int: + """ + Check that the :attr:`~gpflux.layers.gp_layer.GPLayer.num_inducing` + attributes of all layers in *f_layers* are consistent with each other + and with the (optional) *num_inducing* argument. + + :returns: The validated number of inducing points. + """ + for i, layer in enumerate(f_layers): + layer_num_inducing = getattr(layer, "num_inducing", None) + if num_inducing is None: + num_inducing = layer_num_inducing + else: + if layer_num_inducing is not None and num_inducing != layer_num_inducing: + raise ValueError( + f"f_layers[{i}].num_inducing is inconsistent with num_inducing=" + f"{num_inducing}" + ) + if num_inducing is None: + raise ValueError("Could not determine num_inducing; please provide explicitly") + return num_inducing + @staticmethod def _validate_num_data( f_layers: List[tf.keras.layers.Layer], num_data: Optional[int] = None @@ -134,6 +221,12 @@ def _validate_num_data( return num_data def _inducing_add(self, inputs: TensorType): + """ + Adds the inducing points to the data to propagate through the model. + + :param inputs: input data + :return: concatenated inducing points and datapoints + """ inducing_data = tf.tile( tf.expand_dims(self.inducing_data, 0), @@ -144,6 +237,12 @@ def _inducing_add(self, inputs: TensorType): return x def _inducing_remove(self, inputs: TensorType): + """ + Removes the inducing points from the combined inducing points and data tensor + + :param inputs: combined inducing point and data tensor + :return: data only + """ return inputs[:, self.num_inducing:] def _evaluate_deep_gp( @@ -153,12 +252,13 @@ def _evaluate_deep_gp( training: Optional[bool] = None, ) -> tf.Tensor: """ - Evaluate ``f(x) = fₙ(⋯ (f₂(f₁(x))))`` on the *inputs* argument. + Evaluate ``f(x) = fₙ(⋯ (f₂(f₁(x))))`` on the *inputs* argument. 
We must start by expanding + the data by copying it :attr:`self.num_train_samples` times, then adding inducing points. + These are then removed after the inducing points and data have been propagated through the + model. Layers that inherit from :class:`~gpflux.layers.LayerWithObservations` - are passed the additional keyword argument ``observations=[inputs, - targets]`` if *targets* contains a value, or ``observations=None`` when - *targets* is `None`. + are not yet supported. """ features = inputs @@ -166,11 +266,9 @@ def _evaluate_deep_gp( # symbolic graph needs to be constructed at "build" time (before either # fit() or predict() get called). if targets is not None: - observations = [inputs, targets] num_samples = self.num_train_samples else: # TODO would it be better to simply pass [inputs, None] in this case? - observations = None num_samples = self.num_test_samples features = tf.tile(tf.expand_dims(features, 0), @@ -290,6 +388,23 @@ def as_prediction_model( return model_class(self.inputs, outputs) def sample(self, inputs: TensorType, num_samples: int, consistent: bool = False) -> tf.Tensor: + """ + Sample `num_samples` from the posterior at `inputs`. If `consistent` is True, we return + consistent function samples. + + :param inputs: The input data at which we wish to obtain samples. Must have shape [N, D]. + :param num_samples: The number of samples to obtain. + :param consistent: Whether to sample consistent samples. + :return: samples with shape [S, N, L]. + """ + if inputs.shape[-1] != self.inducing_data.shape[-1]: + raise ValueError(f"The trailing dimension of `inputs` must match the trailing dimension" + f"the model's inducing data: received {inputs.shape[-1]} but expected" + f"{self.inducing_data.shape[-1]}.") + if len(inputs.shape) != 2: + raise ValueError(f"Currently only inputs of rank 2 are supported; received rank" + f"{len(inputs.shape)}") + features = tf.tile(tf.expand_dims(inputs, 0), [num_samples, *[1]*(self.rank-1)]) features = self._inducing_add(features) From e0eca46e1f4c860a742ac986e91026550291030a Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Wed, 15 Sep 2021 16:18:54 +0100 Subject: [PATCH 13/19] Minor bug fix --- gpflux/models/gi_deep_gp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gpflux/models/gi_deep_gp.py b/gpflux/models/gi_deep_gp.py index aedc1b4a..2f90a946 100644 --- a/gpflux/models/gi_deep_gp.py +++ b/gpflux/models/gi_deep_gp.py @@ -146,8 +146,8 @@ def __init__( self.default_model_class = default_model_class self.num_data = self._validate_num_data(f_layers, num_data) - if (inducing_init is not None) != (inducing_shape is not None): - raise ValueError(f"One of `inducing_init` or `inducing_shape` must be exclusively" + if (inducing_init is not None) == (inducing_shape is not None): + raise ValueError(f"One of `inducing_init` or `inducing_shape` must be exclusively " f"provided.") if inducing_init is None: @@ -158,7 +158,7 @@ def __init__( if num_inducing != tf.shape(self.inducing_data)[0]: raise ValueError(f"The number of inducing inputs {self.inducing_data.shape[0]} must " f"equal num_inducing {num_inducing}.") - if input_dim is not None and input_dim != tf.shape(self.inducing_data)[-1]: + if input_dim is not None and tf.shape(self.inducing_data)[-1] != input_dim: raise ValueError(f"The dimension of the inducing inputs {self.inducing_data.shape[-1]}" f"must equal input_dim {input_dim}") From 35f8ab6f35e55ad4d8b6d5d75851dd7f49f20c13 Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Thu, 16 Sep 2021 11:07:11 +0100 
Subject: [PATCH 14/19] Fixed negative KLs --- gpflux/layers/gi_gp_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gpflux/layers/gi_gp_layer.py b/gpflux/layers/gi_gp_layer.py index 1342e545..269b72dc 100644 --- a/gpflux/layers/gi_gp_layer.py +++ b/gpflux/layers/gi_gp_layer.py @@ -205,8 +205,8 @@ def mvnormal_log_prob(self, sigma_L: TensorType, X: TensorType) -> tf.Tensor: :param X: evaluation point for log_prob, shape [..., M, D, 1] :return: the log probability, shape [..., M] """ - in_features = self.input_dim - out_features = self.num_latent_gps + in_features = tf.cast(tf.shape(X)[-2], dtype=default_float()) + out_features = tf.cast(tf.shape(X)[-1], dtype=default_float()) trace_quad = tf.reduce_sum(tf.linalg.triangular_solve(sigma_L, X)**2, [-1, -2]) logdet_term = 2.0*tf.reduce_sum(tf.math.log(tf.linalg.diag_part(sigma_L)), -1) return -0.5*trace_quad - 0.5*out_features*(logdet_term + in_features*math.log(2*math.pi)) From f2da0b657c2cea5ef8347415a885413ebb42fa6d Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Mon, 27 Sep 2021 16:08:22 +0100 Subject: [PATCH 15/19] Removed redundant K_r2 --- gpflux/layers/gi_gp_layer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/gpflux/layers/gi_gp_layer.py b/gpflux/layers/gi_gp_layer.py index 269b72dc..154c9ec2 100644 --- a/gpflux/layers/gi_gp_layer.py +++ b/gpflux/layers/gi_gp_layer.py @@ -39,9 +39,6 @@ class BatchingSquaredExponential(SquaredExponential): shape [..., N, D], and X2 with shape [..., M, D], we return [..., N, M] instead of the current behavior, which returns [..., N, ..., M]""" - def K_r2(self, r2): - return self.variance * tf.exp(-0.5 * r2) - def scaled_squared_euclid_dist(self, X, X2=None): X = self.scale(X) X2 = self.scale(X2) From 7f11c341c33e5c7401027bdcd698660a11948ab5 Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Mon, 27 Sep 2021 17:50:09 +0100 Subject: [PATCH 16/19] Faster kernel --- gpflux/layers/gi_gp_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpflux/layers/gi_gp_layer.py b/gpflux/layers/gi_gp_layer.py index 154c9ec2..3f480ff7 100644 --- a/gpflux/layers/gi_gp_layer.py +++ b/gpflux/layers/gi_gp_layer.py @@ -48,7 +48,7 @@ def scaled_squared_euclid_dist(self, X, X2=None): Xs = tf.reduce_sum((X**2), -1)[..., :, None] X2s = tf.reduce_sum((X2**2), -1)[..., None, :] - return Xs + X2s - 2*X@tf.linalg.adjoint(X2) + return Xs + X2s - 2 * tf.linalg.matmul(X, X2, transpose_b=True) class GIGPLayer(tf.keras.layers.Layer): From 5354255c4d881cabd70c3e595df501f008e8eaa3 Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Thu, 14 Oct 2021 14:26:06 +0100 Subject: [PATCH 17/19] Refactor for GI layers to have predict function --- gpflux/layers/gi_gp_layer.py | 39 ++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/gpflux/layers/gi_gp_layer.py b/gpflux/layers/gi_gp_layer.py index 3f480ff7..992558de 100644 --- a/gpflux/layers/gi_gp_layer.py +++ b/gpflux/layers/gi_gp_layer.py @@ -20,7 +20,7 @@ """ import warnings -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import numpy as np import tensorflow as tf @@ -91,7 +91,7 @@ def __init__( num_inducing: int, *, inducing_targets: Optional[tf.Tensor] = None, - prec_init: Optional[float] = 1., + prec_init: Optional[float] = 10., mean_function: Optional[MeanFunction] = None, kernel_variance_init: Optional[float] = 1., name: Optional[str] = None, @@ -285,6 +285,30 @@ def call( return all_samples + mean_function + def predict( + self, + u: 
TensorType, + Kff: TensorType, + Kuf: TensorType, + chol_Kuu: TensorType, + inputs: Optional[TensorType] = None, + full_cov: bool = False + ) -> Tuple[tf.Tensor, tf.Tensor]: + Kfu_invKuu = tf.linalg.adjoint(tf.linalg.cholesky_solve(chol_Kuu, Kuf)) + Ef = tf.linalg.adjoint(tf.squeeze((Kfu_invKuu @ u), -1)) + + if full_cov: + assert inputs is not None + Kff = self.kernel(inputs[..., self.num_inducing:, :]) + Vf = Kff - tf.squeeze(Kfu_invKuu @ Kuf, 1) + Vf = Vf + default_jitter()*tf.eye(tf.shape(Vf)[-1], dtype=default_float()) + else: + Vf = Kff - tf.squeeze(tf.reduce_sum((Kfu_invKuu*tf.linalg.adjoint(Kuf)), -1), 1) + + Vf = Vf[..., None] + + return Ef, Vf + def sample_conditional( self, u: TensorType, @@ -307,8 +331,7 @@ def sample_conditional( samples if true :return: samples of f, shape [S, M, Lout] """ - Kfu_invKuu = tf.linalg.adjoint(tf.linalg.cholesky_solve(chol_Kuu, Kuf)) - Ef = tf.linalg.adjoint(tf.squeeze((Kfu_invKuu @ u), -1)) + Ef, Vf = self.predict(u, Kff, Kuf, chol_Kuu, inputs=inputs, full_cov=full_cov) eps_f = tf.random.normal( tf.shape(Ef), @@ -316,17 +339,11 @@ def sample_conditional( ) if full_cov: - assert inputs is not None - Kff = self.kernel(inputs[..., self.num_inducing:, :]) - Vf = Kff - tf.squeeze(Kfu_invKuu @ Kuf, 1) - Vf = Vf + default_jitter()*tf.eye(tf.shape(Vf)[-1], dtype=default_float()) chol_Vf = tf.linalg.cholesky(Vf) var_part = chol_Vf @ eps_f else: - Vf = Kff - tf.squeeze(tf.reduce_sum((Kfu_invKuu*tf.linalg.adjoint(Kuf)), -1), 1) - - var_part = tf.math.sqrt(Vf)[..., None]*eps_f + var_part = tf.math.sqrt(Vf)*eps_f return Ef + var_part From 5340744fd7d37c90af0dac28853b3d05ef553c6c Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Thu, 14 Oct 2021 15:06:56 +0100 Subject: [PATCH 18/19] Refactor to expose sample_u --- gpflux/layers/gi_gp_layer.py | 63 +++++++++++++++++++++--------------- 1 file changed, 37 insertions(+), 26 deletions(-) diff --git a/gpflux/layers/gi_gp_layer.py b/gpflux/layers/gi_gp_layer.py index 992558de..c71ed6a4 100644 --- a/gpflux/layers/gi_gp_layer.py +++ b/gpflux/layers/gi_gp_layer.py @@ -229,6 +229,42 @@ def call( S = tf.shape(Kuu)[0] + u, chol_lKlpI, chol_Kuu = self.sample_u(S, Kuu) + + if kwargs.get("training"): + loss_per_datapoint = self.prior_kl(tf.linalg.adjoint(self.L), chol_lKlpI, u) / self.num_data + else: + # TF quirk: add_loss must always add a tensor to compile + loss_per_datapoint = tf.constant(0.0, dtype=default_float()) + self.add_loss(loss_per_datapoint) + + # Metric names should be unique; otherwise they get overwritten if you + # have multiple with the same name + name = f"{self.name}_prior_kl" if self.name else "prior_kl" + self.add_metric(loss_per_datapoint, name=name, aggregation="mean") + + if kwargs.get("full_cov"): + f_samples = self.sample_conditional(u, Kff, Kuf, chol_Kuu, inputs=inputs, full_cov=True) + else: + f_samples = self.sample_conditional(u, Kff, Kuf, chol_Kuu) + + all_samples = tf.concat( + [ + tf.linalg.adjoint(tf.squeeze(u, -1)), + f_samples, + ], + axis=-2 + ) + + return all_samples + mean_function + + def sample_u( + self, + S: int, + Kuu: TensorType + ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + # Samples inducing locations u + Iuu = tf.eye(self.num_inducing, dtype=default_float()) L = self.L @@ -258,32 +294,7 @@ def call( prec_noise = inv_Kuu_noise + L_noise u = Sigma @ ((L @ LT) @ self.v + prec_noise) - if kwargs.get("training"): - loss_per_datapoint = self.prior_kl(LT, chol_lKlpI, u) / self.num_data - else: - # TF quirk: add_loss must always add a tensor to compile - loss_per_datapoint = 
tf.constant(0.0, dtype=default_float()) - self.add_loss(loss_per_datapoint) - - # Metric names should be unique; otherwise they get overwritten if you - # have multiple with the same name - name = f"{self.name}_prior_kl" if self.name else "prior_kl" - self.add_metric(loss_per_datapoint, name=name, aggregation="mean") - - if kwargs.get("full_cov"): - f_samples = self.sample_conditional(u, Kff, Kuf, chol_Kuu, inputs=inputs, full_cov=True) - else: - f_samples = self.sample_conditional(u, Kff, Kuf, chol_Kuu) - - all_samples = tf.concat( - [ - tf.linalg.adjoint(tf.squeeze(u, -1)), - f_samples, - ], - axis=-2 - ) - - return all_samples + mean_function + return u, chol_lKlpI, chol_Kuu def predict( self, From efbdff3b5bda52032f91f8dd0c768b4ce2efeb3f Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Fri, 15 Oct 2021 14:45:12 +0100 Subject: [PATCH 19/19] Refactor --- gpflux/layers/gi_gp_layer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/gpflux/layers/gi_gp_layer.py b/gpflux/layers/gi_gp_layer.py index c71ed6a4..d63a18d2 100644 --- a/gpflux/layers/gi_gp_layer.py +++ b/gpflux/layers/gi_gp_layer.py @@ -177,7 +177,7 @@ def __init__( ) # [num_latent_gps, num_inducing, num_inducing] self.L_scale = Parameter( - tf.sqrt(self.kernel.variance)*np.sqrt(prec_init)*np.ones((self.num_latent_gps, 1, 1)), + np.sqrt(self.kernel.variance.numpy()*prec_init)*np.ones((self.num_latent_gps, 1, 1)), transform=positive(), dtype=default_float(), name=f"{self.name}_L_scale" if self.name else "L_scale" @@ -227,9 +227,7 @@ def call( Kuu, Kuf, Kfu = tf.expand_dims(Kuu, 1), tf.expand_dims(Kuf, 1), tf.expand_dims(Kfu, 1) - S = tf.shape(Kuu)[0] - - u, chol_lKlpI, chol_Kuu = self.sample_u(S, Kuu) + u, chol_lKlpI, chol_Kuu = self.sample_u(Kuu) if kwargs.get("training"): loss_per_datapoint = self.prior_kl(tf.linalg.adjoint(self.L), chol_lKlpI, u) / self.num_data @@ -260,11 +258,12 @@ def call( def sample_u( self, - S: int, Kuu: TensorType ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: # Samples inducing locations u + S = tf.shape(Kuu)[0] + Iuu = tf.eye(self.num_inducing, dtype=default_float()) L = self.L
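
For orientation, the pieces introduced in this patch series can be put together roughly as
follows. This is an illustrative sketch only, not part of the patches above: it assumes the
constructor arguments documented in the GIGPLayer and GIDeepGP docstrings (num_latent_gps,
num_data, num_inducing, inducing_shape, num_train_samples, ...) and the usual GPflux Keras
workflow via as_training_model(); exact argument names, defaults, and import paths may differ
from the final code.

    import numpy as np
    import tensorflow as tf

    # Assumed import paths, matching the new modules added in these patches.
    from gpflux.layers.gi_gp_layer import GIGPLayer
    from gpflux.models.gi_deep_gp import GIDeepGP

    # Toy 1-D regression data.
    X = np.random.randn(500, 1)
    Y = np.sin(3 * X) + 0.1 * np.random.randn(500, 1)
    num_data, input_dim = X.shape
    num_inducing = 50

    # Two GP layers; every layer must share the same (global) number of inducing points.
    gp_layers = [
        GIGPLayer(num_latent_gps=1, num_data=num_data, num_inducing=num_inducing),
        GIGPLayer(num_latent_gps=1, num_data=num_data, num_inducing=num_inducing),
    ]

    model = GIDeepGP(
        gp_layers,
        num_inducing=num_inducing,
        # Initial inducing inputs are drawn from N(0, 1) with this shape.
        inducing_shape=[num_inducing, input_dim],
        input_dim=input_dim,
        target_dim=1,
        num_train_samples=10,
        num_test_samples=100,
    )

    # Train through the Keras interface; each layer's KL term is added via add_loss,
    # so the compiled loss plus layer losses form the (sample-based) ELBO.
    training_model = model.as_training_model()
    training_model.compile(optimizer=tf.keras.optimizers.Adam(0.01))
    training_model.fit({"inputs": X, "targets": Y}, epochs=100, verbose=0)

    # Sample-based prediction: consistent function samples of shape [25, N*, 1].
    X_test = np.linspace(-3, 3, 200).reshape(-1, 1)
    f_samples = model.sample(X_test, num_samples=25, consistent=True)

Because the inference is sample-based, predictions are obtained by drawing samples as above and
summarising them (e.g. taking their mean and variance), rather than by querying closed-form
predictive moments as in the existing GPLayer/DeepGP code path.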