Add RFF for Sum kernel
khurram-ghani committed Sep 19, 2024
1 parent b3ad682 commit abbd3cd
Showing 3 changed files with 71 additions and 38 deletions.
22 changes: 14 additions & 8 deletions gpflux/layers/basis_functions/fourier_features/base.py
@@ -44,11 +44,17 @@ def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping):
         self.kernel = kernel
         self.n_components = n_components
         if isinstance(kernel, gpflow.kernels.MultioutputKernel):
-            self.is_multioutput = True
-            self.num_latent_gps = kernel.num_latent_gps
+            self.is_batched = True
+            self.batch_size = kernel.num_latent_gps
+            self.sub_kernels = kernel.latent_kernels
+        elif isinstance(kernel, gpflow.kernels.Combination):
+            self.is_batched = True
+            self.batch_size = len(kernel.kernels)
+            self.sub_kernels = kernel.kernels
         else:
-            self.is_multioutput = False
-            self.num_latent_gps = 1
+            self.is_batched = False
+            self.batch_size = 1
+            self.sub_kernels = []

         if kwargs.get("input_dim", None):
             self._input_dim = kwargs["input_dim"]
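The constructor now routes multioutput kernels, combination kernels, and single kernels through one `is_batched` path. A minimal sketch, not part of the commit, of what each branch assigns (assuming gpflow is installed):

    import gpflow

    # Multioutput kernel: batched over the latent GPs.
    k = gpflow.kernels.SeparateIndependent(
        [gpflow.kernels.SquaredExponential(), gpflow.kernels.Matern32()]
    )  # is_batched=True, batch_size=k.num_latent_gps, sub_kernels=k.latent_kernels

    # Combination kernel (new in this commit): batched over the summands.
    k = gpflow.kernels.Sum([gpflow.kernels.SquaredExponential(), gpflow.kernels.Matern52()])
    # is_batched=True, batch_size=len(k.kernels), sub_kernels=k.kernels

    # Plain kernel: no batch axis.
    k = gpflow.kernels.SquaredExponential()  # is_batched=False, batch_size=1, sub_kernels=[]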
@@ -64,8 +70,8 @@ def call(self, inputs: TensorType) -> tf.Tensor:
         :return: A tensor with the shape ``[N, M]``, or shape ``[P, N, M]`` in the multioutput case.
         """
-        if self.is_multioutput:
-            X = [tf.divide(inputs, k.lengthscales) for k in self.kernel.latent_kernels]
+        if self.is_batched:
+            X = [tf.divide(inputs, k.lengthscales) for k in self.sub_kernels]
             X = tf.stack(X, 0)  # [1, N, D] or [P, N, D]
         else:
             X = tf.divide(inputs, self.kernel.lengthscales)  # [N, D]
@@ -86,8 +92,8 @@ def compute_output_shape(self, input_shape: ShapeType) -> tf.TensorShape:
         tensor_shape = tf.TensorShape(input_shape).with_rank(2)
         output_dim = self._compute_output_dim(input_shape)
         trailing_shape = tensor_shape[:-1].concatenate(output_dim)
-        if self.is_multioutput:
-            return tf.TensorShape([self.num_latent_gps]).concatenate(trailing_shape)  # [P, N, M]
+        if self.is_batched:
+            return tf.TensorShape([self.batch_size]).concatenate(trailing_shape)  # [P, N, M]
         else:
             return trailing_shape  # [N, M]
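Downstream code sees only the new leading batch axis. A quick, illustrative shape check (it assumes the public RandomFourierFeatures layer, whose feature dimension is 2 * n_components):

    import gpflow
    import tensorflow as tf
    from gpflux.layers.basis_functions.fourier_features import RandomFourierFeatures

    kernel = gpflow.kernels.Sum([gpflow.kernels.SquaredExponential(), gpflow.kernels.Matern52()])
    ff = RandomFourierFeatures(kernel, n_components=100, dtype=tf.float64)
    print(ff.compute_output_shape((5, 3)))  # (2, 5, 200), i.e. [P, N, M] with P = len(kernel.kernels)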
45 changes: 29 additions & 16 deletions gpflux/layers/basis_functions/fourier_features/random/base.py
@@ -47,6 +47,8 @@
     gpflow.kernels.SharedIndependent,
 )

+RFF_SUPPORTED_COMBINED: Tuple[Type[gpflow.kernels.Combination], ...] = (gpflow.kernels.Sum,)
+

 def _sample_students_t(nu: float, shape: ShapeType, dtype: DType) -> TensorType:
     """
@@ -79,9 +81,15 @@ def _sample_students_t(nu: float, shape: ShapeType, dtype: DType) -> TensorType:

 class RandomFourierFeaturesBase(FourierFeaturesBase):
     def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping):
-        assert isinstance(kernel, (RFF_SUPPORTED_KERNELS, RFF_SUPPORTED_MULTIOUTPUTS)), (
-            f"Unsupported Kernel: only the following kernel types are supported: "
-            f"{[k.__name__ for k in RFF_SUPPORTED_MULTIOUTPUTS + RFF_SUPPORTED_KERNELS]}"
+        assert isinstance(
+            kernel, (RFF_SUPPORTED_KERNELS, RFF_SUPPORTED_MULTIOUTPUTS, RFF_SUPPORTED_COMBINED)
+        ), "Unsupported Kernel: only the following kernel types are supported: {}".format(
+            [
+                k.__name__
+                for k in (
+                    RFF_SUPPORTED_MULTIOUTPUTS + RFF_SUPPORTED_KERNELS + RFF_SUPPORTED_COMBINED
+                )
+            ]
         )
         if isinstance(kernel, RFF_SUPPORTED_MULTIOUTPUTS):
             for k in kernel.latent_kernels:
@@ -90,6 +98,12 @@ def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping):
                     f"kernel types are supported: "
                     f"{[k.__name__ for k in RFF_SUPPORTED_KERNELS]}"
                 )
+        elif isinstance(kernel, RFF_SUPPORTED_COMBINED):
+            assert all(isinstance(k, RFF_SUPPORTED_KERNELS) for k in kernel.kernels), (
+                f"Unsupported Kernel within the combination kernel; only the following "
+                f"kernel types are supported: "
+                f"{[k.__name__ for k in RFF_SUPPORTED_KERNELS]}"
+            )
         super(RandomFourierFeaturesBase, self).__init__(kernel, n_components, **kwargs)

     def build(self, input_shape: ShapeType) -> None:
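As with multioutput kernels, validation happens at construction time; mirroring the new test case below, a Sum containing an unsupported member is rejected up front. A hypothetical snippet (import path and message wording assumed):

    import gpflow
    import tensorflow as tf
    from gpflux.layers.basis_functions.fourier_features import RandomFourierFeatures

    kernel = gpflow.kernels.Sum([gpflow.kernels.SquaredExponential(), gpflow.kernels.Constant()])
    try:
        RandomFourierFeatures(kernel, n_components=100, dtype=tf.float64)
    except AssertionError as err:
        print(err)  # Unsupported Kernel within the combination kernel; ...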
@@ -103,8 +117,8 @@ def build(self, input_shape: ShapeType) -> None:
         super(RandomFourierFeaturesBase, self).build(input_shape)

     def _weights_build(self, input_dim: int, n_components: int) -> None:
-        if self.is_multioutput:
-            shape = (self.num_latent_gps, n_components, input_dim)  # [P, M, D]
+        if self.is_batched:
+            shape = (self.batch_size, n_components, input_dim)  # [P, M, D]
         else:
             shape = (n_components, input_dim)  # type: ignore
         self.W = self.add_weight(
@@ -129,16 +143,15 @@ def _weights_init_individual(
         return _sample_students_t(nu, shape, dtype)

     def _weights_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
-        if self.is_multioutput:
+        if self.is_batched:
             if isinstance(self.kernel, gpflow.kernels.SharedIndependent):
                 weights_list = [
-                    self._weights_init_individual(self.kernel.latent_kernels[0], shape[1:], dtype)
-                    for _ in range(self.num_latent_gps)
+                    self._weights_init_individual(self.sub_kernels[0], shape[1:], dtype)
+                    for _ in range(self.batch_size)
                 ]
             else:
                 weights_list = [
-                    self._weights_init_individual(k, shape[1:], dtype)
-                    for k in self.kernel.latent_kernels
+                    self._weights_init_individual(k, shape[1:], dtype) for k in self.sub_kernels
                 ]
             return tf.stack(weights_list, 0)  # [P, M, D]
         else:
@@ -203,10 +216,10 @@ def _compute_constant(self) -> tf.Tensor:
         :return: A tensor with the shape ``[]`` (i.e. a scalar).
         """
-        if self.is_multioutput:
+        if self.is_batched:
             constants = [
                 self.rff_constant(k.variance, output_dim=2 * self.n_components)
-                for k in self.kernel.latent_kernels
+                for k in self.sub_kernels
             ]
             return tf.stack(constants, 0)[:, None, None]  # [P, 1, 1]
         else:
@@ -253,8 +266,8 @@ def build(self, input_shape: ShapeType) -> None:
         super(RandomFourierFeaturesCosine, self).build(input_shape)

     def _bias_build(self, n_components: int) -> None:
-        if self.is_multioutput:
-            shape = (self.num_latent_gps, 1, n_components)
+        if self.is_batched:
+            shape = (self.batch_size, 1, n_components)
         else:
             shape = (1, n_components)  # type: ignore
         self.b = self.add_weight(
@@ -285,10 +298,10 @@ def _compute_constant(self) -> tf.Tensor:
         :return: A tensor with the shape ``[]`` (i.e. a scalar).
         """
-        if self.is_multioutput:
+        if self.is_batched:
             constants = [
                 self.rff_constant(k.variance, output_dim=self.n_components)
-                for k in self.kernel.latent_kernels
+                for k in self.sub_kernels
             ]
             return tf.stack(constants, 0)[:, None, None]  # [1, 1, 1] or [P, 1, 1]
         else:
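Taken together, each summand now gets its own weight matrix, bias, and scaling constant along a leading axis of size P, and each per-summand feature block approximates the corresponding sub-kernel. A usage sketch under the same assumptions as above (public import path; tolerances aside):

    import gpflow
    import tensorflow as tf
    from gpflux.layers.basis_functions.fourier_features import RandomFourierFeatures

    kernel = gpflow.kernels.Sum([gpflow.kernels.SquaredExponential(), gpflow.kernels.Matern52()])
    features = RandomFourierFeatures(kernel, n_components=10_000, dtype=tf.float64)

    x = tf.random.uniform((20, 3), dtype=tf.float64)
    u = features(x)  # [P, N, 2 * n_components], one feature block per summand
    # Summing the per-summand outer products over the leading axis approximates the Sum kernel:
    approx = tf.reduce_sum(u @ tf.linalg.matrix_transpose(u), axis=0)  # approximately kernel.K(x, x)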
42 changes: 28 additions & 14 deletions tests/gpflux/layers/basis_functions/fourier_features/test_random.py
@@ -62,7 +62,7 @@ def _kernel_cls_fixture(request):


 @pytest.fixture(
-    name="multioutput_kernel",
+    name="multi_kernel",
     params=[
         gpflow.kernels.SharedIndependent(gpflow.kernels.SquaredExponential(), output_dim=3),
         gpflow.kernels.SeparateIndependent(
@@ -71,9 +71,10 @@ def _kernel_cls_fixture(request):
                 gpflow.kernels.Matern32(lengthscales=0.1),
             ]
         ),
+        gpflow.kernels.Sum([gpflow.kernels.SquaredExponential(), gpflow.kernels.Matern52()]),
     ],
 )
-def _multioutput_kernel_cls_fixture(request):
+def _multi_kernel_cls_fixture(request):
     return request.param


@@ -105,6 +106,7 @@ def _basis_func_cls_fixture(request):
             kernels=[gpflow.kernels.SquaredExponential(), gpflow.kernels.SquaredExponential()],
             W=tf.ones([2, 1]),
         ),
+        gpflow.kernels.Sum([gpflow.kernels.SquaredExponential(), gpflow.kernels.Constant()]),
     ],
 )
 def test_throw_for_unsupported_kernel(basis_func_cls, kernel):
@@ -138,15 +140,15 @@ def test_random_fourier_features_can_approximate_kernel_multidim(
     np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2)


-def test_multioutput_random_fourier_features_can_approximate_kernel_multidim(
-    random_basis_func_cls, multioutput_kernel, n_dims
+def test_multi_random_fourier_features_can_approximate_kernel_multidim(
+    random_basis_func_cls, multi_kernel, n_dims
 ):
     n_components = 40000

     x_rows = 20
     y_rows = 30

-    fourier_features = random_basis_func_cls(multioutput_kernel, n_components, dtype=tf.float64)
+    fourier_features = random_basis_func_cls(multi_kernel, n_components, dtype=tf.float64)

     x = tf.random.uniform((x_rows, n_dims), dtype=tf.float64)
     y = tf.random.uniform((y_rows, n_dims), dtype=tf.float64)
@@ -155,7 +157,13 @@ def test_multioutput_random_fourier_features_can_approximate_kernel_multidim(

     v = fourier_features(y)
     approx_kernel_matrix = u @ tf.linalg.matrix_transpose(v)

-    actual_kernel_matrix = multioutput_kernel.K(x, y, full_output_cov=False)
+    if isinstance(multi_kernel, gpflow.kernels.Sum):
+        approx_kernel_matrix = tf.reduce_sum(approx_kernel_matrix, axis=0)
+
+    if isinstance(multi_kernel, gpflow.kernels.MultioutputKernel):
+        actual_kernel_matrix = multi_kernel.K(x, y, full_output_cov=False)
+    else:
+        actual_kernel_matrix = multi_kernel.K(x, y)

     np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2)
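The reduce_sum is the heart of the Sum-kernel support: RFF approximates each summand separately, and since k_sum(x, y) = Σᵢ kᵢ(x, y), collapsing the leading P axis of the stacked per-summand Gram blocks recovers the Sum kernel's Gram matrix. A self-contained shape illustration (random stand-in values, not real RFF output):

    import tensorflow as tf

    u = tf.random.uniform((2, 4, 6), dtype=tf.float64)  # stand-in for [P, N, M] features
    per_summand = u @ tf.linalg.matrix_transpose(u)  # [P, N, N], one Gram block per summand
    total = tf.reduce_sum(per_summand, axis=0)  # [N, N], approximates the Sum kernel's Gram matrix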

@@ -206,24 +214,30 @@ def test_random_fourier_feature_layer_compute_covariance_of_inducing_variables(
     np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2)


-def test_multioutput_random_fourier_feature_layer_compute_covariance_of_inducing_variables(
-    random_basis_func_cls, multioutput_kernel, batch_size
+def test_multi_random_fourier_feature_layer_compute_covariance_of_inducing_variables(
+    random_basis_func_cls, multi_kernel, batch_size
 ):
     """
     Ensure that the random fourier feature map can be used to approximate the covariance matrix
     between the inducing point vectors of a sparse GP, with the condition that the number of latent
-    GP models is greater than one. This test replicates the above, but for multioutput kernels.
+    GP models is greater than one. This test replicates the above, but for multi-kernels.
     """
     n_components = 10000

-    fourier_features = random_basis_func_cls(multioutput_kernel, n_components, dtype=tf.float64)
+    fourier_features = random_basis_func_cls(multi_kernel, n_components, dtype=tf.float64)

     x_new = tf.ones(shape=(2 * batch_size + 1, 1), dtype=tf.float64)

     u = fourier_features(x_new)
     approx_kernel_matrix = u @ tf.linalg.matrix_transpose(u)

-    actual_kernel_matrix = multioutput_kernel.K(x_new, x_new, full_output_cov=False)
+    if isinstance(multi_kernel, gpflow.kernels.Sum):
+        approx_kernel_matrix = tf.reduce_sum(approx_kernel_matrix, axis=0)
+
+    if isinstance(multi_kernel, gpflow.kernels.MultioutputKernel):
+        actual_kernel_matrix = multi_kernel.K(x_new, x_new, full_output_cov=False)
+    else:
+        actual_kernel_matrix = multi_kernel.K(x_new, x_new)

     np.testing.assert_allclose(approx_kernel_matrix, actual_kernel_matrix, atol=5e-2)

@@ -237,11 +251,11 @@ def test_fourier_features_shapes(basis_func_cls, n_components, n_dims, batch_size):
     np.testing.assert_equal(features.shape, output_shape)


-def test_multioutput_fourier_features_shapes(
-    random_basis_func_cls, multioutput_kernel, n_components, n_dims, batch_size
+def test_multi_fourier_features_shapes(
+    random_basis_func_cls, multi_kernel, n_components, n_dims, batch_size
 ):
     input_shape = (batch_size, n_dims)
-    feature_functions = random_basis_func_cls(multioutput_kernel, n_components, dtype=tf.float64)
+    feature_functions = random_basis_func_cls(multi_kernel, n_components, dtype=tf.float64)
     output_shape = feature_functions.compute_output_shape(input_shape)
     features = feature_functions(tf.ones(shape=input_shape))
     np.testing.assert_equal(features.shape, output_shape)
