diff --git a/trieste/models/gpflux/sampler.py b/trieste/models/gpflux/sampler.py index 4aef9937a..b724beaf1 100644 --- a/trieste/models/gpflux/sampler.py +++ b/trieste/models/gpflux/sampler.py @@ -15,6 +15,7 @@ from __future__ import annotations from abc import ABC +from itertools import cycle from typing import Callable, cast import gpflow.kernels @@ -439,19 +440,24 @@ def __init__(self, layer: GPLayer, n_components: int): dummy_X = inducing_points[0:1, :] self.__call__(dummy_X) - self.b: TensorType = tf.Variable(self.b) - self.W: TensorType = tf.Variable(self.W) def resample(self) -> None: """ Resample weights and biases. """ - if not hasattr(self, "_bias_init"): - self.b.assign(self._sample_bias(tf.shape(self.b), dtype=self._dtype)) - self.W.assign(self._sample_weights(tf.shape(self.W), dtype=self._dtype)) - else: + if isinstance(self.b, tf.Variable): self.b.assign(self._bias_init(tf.shape(self.b), dtype=self._dtype)) - self.W.assign(self._weights_init(tf.shape(self.W), dtype=self._dtype)) + else: + tf.debugging.Assert(isinstance(self.b, list), []) + for b in self.b: + b.assign(self._bias_init(tf.shape(b), dtype=self._dtype)) + + if isinstance(self.W, tf.Variable): + self.W.assign(self._weights_init(self.kernel)(tf.shape(self.W), dtype=self._dtype)) + else: + tf.debugging.Assert(isinstance(self.W, list), []) + for W, k in zip(self.W, cycle(self.sub_kernels)): + W.assign(self._weights_init(k)(tf.shape(W), dtype=self._dtype)) def __call__(self, x: TensorType) -> TensorType: # [N, D] -> [N, L + M] or [P, N, L + M] """