diff --git a/src/probnum/backend/_dispatcher.py b/src/probnum/backend/_dispatcher.py
index 6e916883de..5d10006a3d 100644
--- a/src/probnum/backend/_dispatcher.py
+++ b/src/probnum/backend/_dispatcher.py
@@ -47,6 +47,10 @@ def torch(self, impl: Callable) -> Callable:
         return impl
 
     def __call__(self, *args, **kwargs):
+        if BACKEND not in self._impl:
+            raise NotImplementedError(
+                f"This function is not implemented for the backend `{BACKEND.name}`"
+            )
         return self._impl[BACKEND](*args, **kwargs)
 
     def __get__(self, obj, objtype=None):
diff --git a/src/probnum/randvars/_normal.py b/src/probnum/randvars/_normal.py
index 9ae8f0c177..d0ace4b1a4 100644
--- a/src/probnum/randvars/_normal.py
+++ b/src/probnum/randvars/_normal.py
@@ -501,10 +501,10 @@ def _logpdf(self, x: ArrayType) -> ArrayType:
 
         return res
 
-    def _cdf(self, x: ArrayType) -> ArrayType:
-        if backend.BACKEND is not backend.Backend.NUMPY:
-            raise NotImplementedError()
+    _cdf = backend.Dispatcher()
 
+    @_cdf.numpy
+    def _cdf_numpy(self, x: ArrayType) -> ArrayType:
         import scipy.stats  # pylint: disable=import-outside-toplevel
 
         return scipy.stats.multivariate_normal.cdf(
@@ -513,10 +513,10 @@ def _cdf(self, x: ArrayType) -> ArrayType:
             cov=self.dense_cov,
         )
 
-    def _logcdf(self, x: ArrayType) -> ArrayType:
-        if backend.BACKEND is not backend.Backend.NUMPY:
-            raise NotImplementedError()
+    _logcdf = backend.Dispatcher()
 
+    @_logcdf.numpy
+    def _logcdf_numpy(self, x: ArrayType) -> ArrayType:
         import scipy.stats  # pylint: disable=import-outside-toplevel
 
         return scipy.stats.multivariate_normal.logcdf(
diff --git a/src/probnum/randvars/_random_variable.py b/src/probnum/randvars/_random_variable.py
index 3dff0699f4..cbd1f4f5cc 100644
--- a/src/probnum/randvars/_random_variable.py
+++ b/src/probnum/randvars/_random_variable.py
@@ -162,7 +162,7 @@ def ndim(self) -> int:
     def size(self) -> int:
         """Size of realizations of the random variable, defined as the product over
         all components of :attr:`shape`."""
-        return functools.reduce(operator.mul, self.__shape, initial=1)
+        return functools.reduce(operator.mul, self.__shape, 1)
 
     @property
     def dtype(self) -> backend.dtype:
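
Reviewer note: for anyone unfamiliar with `backend.Dispatcher`, the sketch below illustrates the descriptor-based dispatch mechanism these hunks rely on, including the new guard in `__call__` that replaces an opaque `KeyError` with a descriptive `NotImplementedError`. It is a simplified stand-in under stated assumptions, not the actual probnum implementation: the `Backend` enum, the module-level `BACKEND` constant, and the demo `Normal` class are hypothetical names for illustration only.

```python
import enum
import functools
from typing import Callable


class Backend(enum.Enum):
    # Hypothetical backend enum; probnum's real selection logic differs.
    NUMPY = "numpy"
    TORCH = "torch"


# Normally chosen from the environment at import time; fixed here for the demo.
BACKEND = Backend.NUMPY


class Dispatcher:
    def __init__(self):
        # Maps each backend to its registered implementation.
        self._impl = {}

    def numpy(self, impl: Callable) -> Callable:
        # Used as a decorator: @my_dispatcher.numpy
        self._impl[Backend.NUMPY] = impl
        return impl

    def torch(self, impl: Callable) -> Callable:
        self._impl[Backend.TORCH] = impl
        return impl

    def __call__(self, *args, **kwargs):
        # The guard added in the first hunk: fail loudly and descriptively
        # when no implementation is registered for the active backend.
        if BACKEND not in self._impl:
            raise NotImplementedError(
                f"This function is not implemented for the backend `{BACKEND.name}`"
            )
        return self._impl[BACKEND](*args, **kwargs)

    def __get__(self, obj, objtype=None):
        # Bind like an ordinary method, so implementations receive `self`.
        return functools.partial(self.__call__, obj)


class Normal:
    # Mirrors the `_cdf = backend.Dispatcher()` pattern in _normal.py.
    _cdf = Dispatcher()

    @_cdf.numpy
    def _cdf_numpy(self, x):
        return f"NumPy CDF at {x}"


rv = Normal()
print(rv._cdf(0.5))  # -> NumPy CDF at 0.5

# With no torch implementation registered, switching backends now fails clearly:
BACKEND = Backend.TORCH
try:
    rv._cdf(0.5)
except NotImplementedError as exc:
    print(exc)  # -> This function is not implemented for the backend `TORCH`
```

One design consequence worth noting: because `Dispatcher.__call__` performs the membership check itself, `_cdf` and `_logcdf` in `_normal.py` no longer need the explicit `if backend.BACKEND is not backend.Backend.NUMPY: raise NotImplementedError()` boilerplate; registering only a NumPy implementation yields the same behavior for free on every other backend.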