diff --git a/src/probnum/randvars/_normal.py b/src/probnum/randvars/_normal.py index cdc428a57..2281883bc 100644 --- a/src/probnum/randvars/_normal.py +++ b/src/probnum/randvars/_normal.py @@ -452,6 +452,26 @@ def _univariate_entropy(self: ValueType) -> np.float_: scipy.stats.norm.entropy(loc=self.mean, scale=self.std), dtype=np.float_, ) + + def _multivariate_sample( + self, + rng: np.random.Generator, + size: ShapeType = (), + ) -> Union[np.floating, np.ndarray]: + + if self.cov_cholesky is None: + raise ValueError('Cholesky factor of the covariance operator is not available.') + + if self.mean.ndim != 1: + raise ValueError('Mean must be a vector.') + + sample = scipy.stats.norm().rvs( + size=self.shape + _utils.as_shape(size), + random_state=rng, + ) + sample = self.cov_cholesky @ sample + sample += self.mean[..., np.newaxis] + return sample.T # Multi- and matrixvariate Gaussians def dense_cov_cholesky( diff --git a/tests/test_randvars/test_normal.py b/tests/test_randvars/test_normal.py index 3d6f99731..3fd062179 100644 --- a/tests/test_randvars/test_normal.py +++ b/tests/test_randvars/test_normal.py @@ -213,6 +213,23 @@ def test_symmetric_samples(self): ), ) + def test_multivariate_sample_zero_cov(self): + """Draw sample from distribution with zero kernels and check whether it equals the mean.""" + mean = np.random.rand(10) + cov = np.zeros((10, 10)) + rv = randvars.Normal(mean=mean, cov=0*cov, cov_cholesky=0*cov) + rv_sample = rv.sample(rng=self.rng, size=1) + self.assertAllClose(rv.mean, rv_sample) + + def test_multivariate_sample_shape(self): + """Test whether the shape of the sample is correct.""" + N, n_blocks, size = 10, 4, 36 + mean = np.random.rand(n_blocks*N) + cov = cov_sqrt = linops.BlockDiagonalMatrix(*[np.eye(N) for _ in range(n_blocks)]) + rv = randvars.Normal(mean=mean, cov=cov, cov_cholesky=cov_sqrt) + rv_sample = rv._multivariate_sample(rng=self.rng, size=size) + self.assertEqual((size, N*n_blocks), rv_sample.shape) + def test_indexing(self): """Indexing with Python integers yields a univariate normal distribution.""" for mean, cov in self.normal_params: