diff --git a/gpflux/layers/basis_functions/fourier_features/base.py b/gpflux/layers/basis_functions/fourier_features/base.py
index 0efce2b9..73507ae5 100644
--- a/gpflux/layers/basis_functions/fourier_features/base.py
+++ b/gpflux/layers/basis_functions/fourier_features/base.py
@@ -44,11 +44,17 @@ def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: M
         self.kernel = kernel
         self.n_components = n_components
         if isinstance(kernel, gpflow.kernels.MultioutputKernel):
-            self.is_multioutput = True
-            self.num_latent_gps = kernel.num_latent_gps
+            self.is_batched = True
+            self.batch_size = kernel.num_latent_gps
+            self.sub_kernels = kernel.latent_kernels
+        elif isinstance(kernel, gpflow.kernels.Combination):
+            self.is_batched = True
+            self.batch_size = len(kernel.kernels)
+            self.sub_kernels = kernel.kernels
         else:
-            self.is_multioutput = False
-            self.num_latent_gps = 1
+            self.is_batched = False
+            self.batch_size = 1
+            self.sub_kernels = []
 
         if kwargs.get("input_dim", None):
             self._input_dim = kwargs["input_dim"]
@@ -64,8 +70,8 @@ def call(self, inputs: TensorType) -> tf.Tensor:
         :return: A tensor with the shape ``[N, M]``, or shape ``[P, N, M]'' in the
             multioutput case.
         """
-        if self.is_multioutput:
-            X = [tf.divide(inputs, k.lengthscales) for k in self.kernel.latent_kernels]
+        if self.is_batched:
+            X = [tf.divide(inputs, k.lengthscales) for k in self.sub_kernels]
             X = tf.stack(X, 0)  # [1, N, D] or [P, N, D]
         else:
             X = tf.divide(inputs, self.kernel.lengthscales)  # [N, D]
@@ -86,8 +92,8 @@ def compute_output_shape(self, input_shape: ShapeType) -> tf.TensorShape:
         tensor_shape = tf.TensorShape(input_shape).with_rank(2)
         output_dim = self._compute_output_dim(input_shape)
         trailing_shape = tensor_shape[:-1].concatenate(output_dim)
-        if self.is_multioutput:
-            return tf.TensorShape([self.num_latent_gps]).concatenate(trailing_shape)  # [P, N, M]
+        if self.is_batched:
+            return tf.TensorShape([self.batch_size]).concatenate(trailing_shape)  # [P, N, M]
         else:
             return trailing_shape  # [N, M]
 
diff --git a/gpflux/layers/basis_functions/fourier_features/random/base.py b/gpflux/layers/basis_functions/fourier_features/random/base.py
index bc1d17bd..221b3602 100644
--- a/gpflux/layers/basis_functions/fourier_features/random/base.py
+++ b/gpflux/layers/basis_functions/fourier_features/random/base.py
@@ -47,6 +47,8 @@
     gpflow.kernels.SharedIndependent,
 )
 
+RFF_SUPPORTED_COMBINED: Tuple[Type[gpflow.kernels.Combination], ...] = (gpflow.kernels.Sum,)
+
 
 def _sample_students_t(nu: float, shape: ShapeType, dtype: DType) -> TensorType:
     """
@@ -79,9 +81,15 @@ def _sample_students_t(nu: float, shape: ShapeType, dtype: DType) -> TensorType:
 
 class RandomFourierFeaturesBase(FourierFeaturesBase):
     def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping):
-        assert isinstance(kernel, (RFF_SUPPORTED_KERNELS, RFF_SUPPORTED_MULTIOUTPUTS)), (
-            f"Unsupported Kernel: only the following kernel types are supported: "
-            f"{[k.__name__ for k in RFF_SUPPORTED_MULTIOUTPUTS + RFF_SUPPORTED_KERNELS]}"
+        assert isinstance(
+            kernel, (RFF_SUPPORTED_KERNELS, RFF_SUPPORTED_MULTIOUTPUTS, RFF_SUPPORTED_COMBINED)
+        ), "Unsupported Kernel: only the following kernel types are supported: {}".format(
+            [
+                k.__name__
+                for k in (
+                    RFF_SUPPORTED_MULTIOUTPUTS + RFF_SUPPORTED_KERNELS + RFF_SUPPORTED_COMBINED
+                )
+            ]
         )
         if isinstance(kernel, RFF_SUPPORTED_MULTIOUTPUTS):
             for k in kernel.latent_kernels:
@@ -90,6 +98,12 @@ def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: M
                     f"kernel types are supported: "
                     f"{[k.__name__ for k in RFF_SUPPORTED_KERNELS]}"
                 )
+        elif isinstance(kernel, RFF_SUPPORTED_COMBINED):
+            assert all(isinstance(k, RFF_SUPPORTED_KERNELS) for k in kernel.kernels), (
+                f"Unsupported Kernel within the combination kernel; only the following "
+                f"kernel types are supported: "
+                f"{[k.__name__ for k in RFF_SUPPORTED_KERNELS]}"
+            )
         super(RandomFourierFeaturesBase, self).__init__(kernel, n_components, **kwargs)
 
     def build(self, input_shape: ShapeType) -> None:
@@ -103,8 +117,8 @@ def build(self, input_shape: ShapeType) -> None:
         super(RandomFourierFeaturesBase, self).build(input_shape)
 
     def _weights_build(self, input_dim: int, n_components: int) -> None:
-        if self.is_multioutput:
-            shape = (self.num_latent_gps, n_components, input_dim)  # [P, M, D]
+        if self.is_batched:
+            shape = (self.batch_size, n_components, input_dim)  # [P, M, D]
         else:
             shape = (n_components, input_dim)  # type: ignore
         self.W = self.add_weight(
@@ -129,16 +143,15 @@ def _weights_init_individual(
         return _sample_students_t(nu, shape, dtype)
 
     def _weights_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
-        if self.is_multioutput:
+        if self.is_batched:
             if isinstance(self.kernel, gpflow.kernels.SharedIndependent):
                 weights_list = [
-                    self._weights_init_individual(self.kernel.latent_kernels[0], shape[1:], dtype)
-                    for _ in range(self.num_latent_gps)
+                    self._weights_init_individual(self.sub_kernels[0], shape[1:], dtype)
+                    for _ in range(self.batch_size)
                 ]
             else:
                 weights_list = [
-                    self._weights_init_individual(k, shape[1:], dtype)
-                    for k in self.kernel.latent_kernels
+                    self._weights_init_individual(k, shape[1:], dtype) for k in self.sub_kernels
                 ]
             return tf.stack(weights_list, 0)  # [P, M, D]
         else:
@@ -203,10 +216,10 @@ def _compute_constant(self) -> tf.Tensor:
 
         :return: A tensor with the shape ``[]`` (i.e. a scalar).
         """
-        if self.is_multioutput:
+        if self.is_batched:
             constants = [
                 self.rff_constant(k.variance, output_dim=2 * self.n_components)
-                for k in self.kernel.latent_kernels
+                for k in self.sub_kernels
             ]
             return tf.stack(constants, 0)[:, None, None]  # [P, 1, 1]
         else:
@@ -253,8 +266,8 @@ def build(self, input_shape: ShapeType) -> None:
         super(RandomFourierFeaturesCosine, self).build(input_shape)
 
     def _bias_build(self, n_components: int) -> None:
-        if self.is_multioutput:
-            shape = (self.num_latent_gps, 1, n_components)
+        if self.is_batched:
+            shape = (self.batch_size, 1, n_components)
         else:
             shape = (1, n_components)  # type: ignore
         self.b = self.add_weight(
@@ -285,10 +298,10 @@ def _compute_constant(self) -> tf.Tensor:
 
         :return: A tensor with the shape ``[]`` (i.e. a scalar).
         """
-        if self.is_multioutput:
+        if self.is_batched:
             constants = [
                 self.rff_constant(k.variance, output_dim=self.n_components)
-                for k in self.kernel.latent_kernels
+                for k in self.sub_kernels
             ]
             return tf.stack(constants, 0)[:, None, None]  # [1, 1, 1] or [P, 1, 1]
         else:
diff --git a/tests/gpflux/layers/basis_functions/fourier_features/test_random.py b/tests/gpflux/layers/basis_functions/fourier_features/test_random.py
index 741b371a..ac1f38a1 100644
--- a/tests/gpflux/layers/basis_functions/fourier_features/test_random.py
+++ b/tests/gpflux/layers/basis_functions/fourier_features/test_random.py
@@ -62,7 +62,7 @@ def _kernel_cls_fixture(request):
 
 
 @pytest.fixture(
-    name="multioutput_kernel",
+    name="multi_kernel",
     params=[
         gpflow.kernels.SharedIndependent(gpflow.kernels.SquaredExponential(), output_dim=3),
         gpflow.kernels.SeparateIndependent(
@@ -71,9 +71,10 @@ def _kernel_cls_fixture(request):
                 gpflow.kernels.SquaredExponential(),
                 gpflow.kernels.Matern32(lengthscales=0.1),
             ]
         ),
+        gpflow.kernels.Sum([gpflow.kernels.SquaredExponential(), gpflow.kernels.Matern52()]),
     ],
 )
-def _multioutput_kernel_cls_fixture(request):
+def _multi_kernel_cls_fixture(request):
     return request.param
 
@@ -105,6 +106,7 @@ def _basis_func_cls_fixture(request):
         kernels=[gpflow.kernels.SquaredExponential(), gpflow.kernels.SquaredExponential()],
         W=tf.ones([2, 1]),
     ),
+    gpflow.kernels.Sum([gpflow.kernels.SquaredExponential(), gpflow.kernels.Constant()]),
 ],
 )
 def test_throw_for_unsupported_kernel(basis_func_cls, kernel):
@@ -138,15 +140,15 @@
     np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2)
 
 
-def test_multioutput_random_fourier_features_can_approximate_kernel_multidim(
-    random_basis_func_cls, multioutput_kernel, n_dims
+def test_multi_random_fourier_features_can_approximate_kernel_multidim(
+    random_basis_func_cls, multi_kernel, n_dims
 ):
     n_components = 40000
 
     x_rows = 20
     y_rows = 30
 
-    fourier_features = random_basis_func_cls(multioutput_kernel, n_components, dtype=tf.float64)
+    fourier_features = random_basis_func_cls(multi_kernel, n_components, dtype=tf.float64)
     x = tf.random.uniform((x_rows, n_dims), dtype=tf.float64)
     y = tf.random.uniform((y_rows, n_dims), dtype=tf.float64)
 
@@ -155,7 +157,13 @@
     v = fourier_features(y)
     approx_kernel_matrix = u @ tf.linalg.matrix_transpose(v)
 
-    actual_kernel_matrix = multioutput_kernel.K(x, y, full_output_cov=False)
+    if isinstance(multi_kernel, gpflow.kernels.Sum):
+        approx_kernel_matrix = tf.reduce_sum(approx_kernel_matrix, axis=0)
+
+    if isinstance(multi_kernel, gpflow.kernels.MultioutputKernel):
+        actual_kernel_matrix = multi_kernel.K(x, y, full_output_cov=False)
+    else:
+        actual_kernel_matrix = multi_kernel.K(x, y)
 
     np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2)
 
@@ -206,24 +214,30 @@
 
 
-def test_multioutput_random_fourier_feature_layer_compute_covariance_of_inducing_variables(
-    random_basis_func_cls, multioutput_kernel, batch_size
+def test_multi_random_fourier_feature_layer_compute_covariance_of_inducing_variables(
+    random_basis_func_cls, multi_kernel, batch_size
 ):
     """
     Ensure that the random fourier feature map can be used to approximate the covariance matrix
     between the inducing point vectors of a sparse GP, with the condition that the number of latent
-    GP models is greater than one. This test replicates the above, but for multioutput kernels.
+    GP models is greater than one. This test replicates the above, but for multi-kernels.
     """
     n_components = 10000
 
-    fourier_features = random_basis_func_cls(multioutput_kernel, n_components, dtype=tf.float64)
+    fourier_features = random_basis_func_cls(multi_kernel, n_components, dtype=tf.float64)
 
     x_new = tf.ones(shape=(2 * batch_size + 1, 1), dtype=tf.float64)
 
     u = fourier_features(x_new)
     approx_kernel_matrix = u @ tf.linalg.matrix_transpose(u)
 
-    actual_kernel_matrix = multioutput_kernel.K(x_new, x_new, full_output_cov=False)
+    if isinstance(multi_kernel, gpflow.kernels.Sum):
+        approx_kernel_matrix = tf.reduce_sum(approx_kernel_matrix, axis=0)
+
+    if isinstance(multi_kernel, gpflow.kernels.MultioutputKernel):
+        actual_kernel_matrix = multi_kernel.K(x_new, x_new, full_output_cov=False)
+    else:
+        actual_kernel_matrix = multi_kernel.K(x_new, x_new)
 
     np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2)
 
@@ -237,11 +251,11 @@ def test_fourier_features_shapes(basis_func_cls, n_components, n_dims, batch_siz
     np.testing.assert_equal(features.shape, output_shape)
 
 
-def test_multioutput_fourier_features_shapes(
-    random_basis_func_cls, multioutput_kernel, n_components, n_dims, batch_size
+def test_multi_fourier_features_shapes(
+    random_basis_func_cls, multi_kernel, n_components, n_dims, batch_size
 ):
     input_shape = (batch_size, n_dims)
-    feature_functions = random_basis_func_cls(multioutput_kernel, n_components, dtype=tf.float64)
+    feature_functions = random_basis_func_cls(multi_kernel, n_components, dtype=tf.float64)
     output_shape = feature_functions.compute_output_shape(input_shape)
     features = feature_functions(tf.ones(shape=input_shape))
     np.testing.assert_equal(features.shape, output_shape)
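
Usage sketch (not part of the patch): a minimal example of how the new combination-kernel support is exercised, mirroring the updated tests above. The gpflow.kernels.Sum kernel and the RandomFourierFeatures layer are the existing APIs touched by this diff; the particular sub-kernels, sizes, and variable names below are illustrative assumptions only.

    # Minimal sketch, assuming gpflux exposes RandomFourierFeatures from this subpackage.
    import gpflow
    import tensorflow as tf
    from gpflux.layers.basis_functions.fourier_features import RandomFourierFeatures

    # A Sum kernel is now treated as a "batched" kernel: one random feature map per sub-kernel.
    kernel = gpflow.kernels.Sum([gpflow.kernels.SquaredExponential(), gpflow.kernels.Matern52()])
    features = RandomFourierFeatures(kernel, n_components=10000, dtype=tf.float64)

    X = tf.random.uniform((20, 1), dtype=tf.float64)
    u = features(X)  # [P, N, 2 * n_components], with P = number of sub-kernels (here 2)

    # Summing the per-sub-kernel outer products approximates the Sum kernel's Gram matrix,
    # exactly as the updated tests do via tf.reduce_sum(..., axis=0).
    approx_K = tf.reduce_sum(u @ tf.linalg.matrix_transpose(u), axis=0)  # [N, N]
    exact_K = kernel.K(X)  # agrees with approx_K up to Monte Carlo error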