Skip to content

Commit

Permalink
Refactor Normal._cdf to use backend.Dispatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinpfoertner committed Nov 30, 2021
1 parent f9cdf13 commit 3c8203d
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
4 changes: 4 additions & 0 deletions src/probnum/backend/_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def torch(self, impl: Callable) -> Callable:
return impl

def __call__(self, *args, **kwargs):
if BACKEND not in self._impl:
raise NotImplementedError(
f"This function is not implemented for the backend `{BACKEND.name}`"
)
return self._impl[BACKEND](*args, **kwargs)

def __get__(self, obj, objtype=None):
Expand Down
12 changes: 6 additions & 6 deletions src/probnum/randvars/_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,10 +501,10 @@ def _logpdf(self, x: ArrayType) -> ArrayType:

return res

def _cdf(self, x: ArrayType) -> ArrayType:
if backend.BACKEND is not backend.Backend.NUMPY:
raise NotImplementedError()
_cdf = backend.Dispatcher()

@_cdf.numpy
def _cdf_numpy(self, x: ArrayType) -> ArrayType:
import scipy.stats # pylint: disable=import-outside-toplevel

return scipy.stats.multivariate_normal.cdf(
Expand All @@ -513,10 +513,10 @@ def _cdf(self, x: ArrayType) -> ArrayType:
cov=self.dense_cov,
)

def _logcdf(self, x: ArrayType) -> ArrayType:
if backend.BACKEND is not backend.Backend.NUMPY:
raise NotImplementedError()
_logcdf = backend.Dispatcher()

@_logcdf.numpy
def _logcdf_numpy(self, x: ArrayType) -> ArrayType:
import scipy.stats # pylint: disable=import-outside-toplevel

return scipy.stats.multivariate_normal.logcdf(
Expand Down
2 changes: 1 addition & 1 deletion src/probnum/randvars/_random_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def ndim(self) -> int:
def size(self) -> int:
"""Size of realizations of the random variable, defined as the product over all
components of :attr:`shape`."""
return functools.reduce(operator.mul, self.__shape, initial=1)
return functools.reduce(operator.mul, self.__shape, 1)

@property
def dtype(self) -> backend.dtype:
Expand Down

0 comments on commit 3c8203d

Please sign in to comment.