Skip to content

Commit

Permalink
Remove as_scalar from utils
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinpfoertner committed Dec 7, 2021
1 parent 0e1c379 commit 6a2f4ec
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 40 deletions.
18 changes: 18 additions & 0 deletions src/probnum/backend/_core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from probnum import backend as _backend
from probnum.typing import ArrayType, DTypeArgType, ScalarArgType

if _backend.BACKEND is _backend.Backend.NUMPY:
from . import _numpy as _core
Expand Down Expand Up @@ -73,3 +74,20 @@
# Just-in-Time Compilation
jit = _core.jit
jit_method = _core.jit_method


def as_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> ArrayType:
"""Convert a scalar into a NumPy scalar.
Parameters
----------
x
Scalar value.
dtype
Data type of the scalar.
"""

if ndim(x) != 0:
raise ValueError("The given input is not a scalar.")

return asarray(x, dtype=dtype)[()]
4 changes: 2 additions & 2 deletions src/probnum/randprocs/kernels/_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Optional

from probnum import backend, utils
from probnum import backend
from probnum.typing import ArrayType, IntArgType, ScalarArgType

from ._kernel import Kernel
Expand Down Expand Up @@ -39,7 +39,7 @@ class Linear(Kernel):
"""

def __init__(self, input_dim: IntArgType, constant: ScalarArgType = 0.0):
self.constant = utils.as_scalar(constant)
self.constant = backend.as_scalar(constant)
super().__init__(input_dim=input_dim)

@backend.jit_method
Expand Down
6 changes: 3 additions & 3 deletions src/probnum/randprocs/kernels/_polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Optional

from probnum import backend, utils
from probnum import backend
from probnum.typing import ArrayType, IntArgType, ScalarArgType

from ._kernel import Kernel
Expand Down Expand Up @@ -46,8 +46,8 @@ def __init__(
constant: ScalarArgType = 0.0,
exponent: IntArgType = 1.0,
):
self.constant = utils.as_scalar(constant)
self.exponent = utils.as_scalar(exponent)
self.constant = backend.as_scalar(constant)
self.exponent = backend.as_scalar(exponent)
super().__init__(input_dim=input_dim)

@backend.jit_method
Expand Down
8 changes: 3 additions & 5 deletions src/probnum/randprocs/kernels/_rational_quadratic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

from typing import Optional

import numpy as np

from probnum import backend, utils
from probnum import backend
from probnum.typing import ArrayType, IntArgType, ScalarArgType

from ._kernel import IsotropicMixin, Kernel
Expand Down Expand Up @@ -62,8 +60,8 @@ def __init__(
lengthscale: ScalarArgType = 1.0,
alpha: ScalarArgType = 1.0,
):
self.lengthscale = utils.as_scalar(lengthscale)
self.alpha = utils.as_scalar(alpha)
self.lengthscale = backend.as_scalar(lengthscale)
self.alpha = backend.as_scalar(alpha)
if not self.alpha > 0:
raise ValueError(f"Scale mixture alpha={self.alpha} must be positive.")
super().__init__(input_dim=input_dim)
Expand Down
4 changes: 2 additions & 2 deletions src/probnum/randprocs/kernels/_white_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Optional

from probnum import backend, utils
from probnum import backend
from probnum.typing import ArrayType, IntArgType, ScalarArgType

from ._kernel import Kernel
Expand All @@ -25,7 +25,7 @@ class WhiteNoise(Kernel):
"""

def __init__(self, input_dim: IntArgType, sigma: ScalarArgType = 1.0):
self.sigma = utils.as_scalar(sigma)
self.sigma = backend.as_scalar(sigma)
self._sigma_sq = self.sigma ** 2
super().__init__(input_dim=input_dim)

Expand Down
2 changes: 1 addition & 1 deletion src/probnum/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
ScalarArgType = Union[int, float, complex, numbers.Number, np.number]
"""Type of a public API argument for supplying a scalar value. Values of this type
should always be converted into :class:`np.generic` using the function
:func:`probnum.utils.as_scalar` before further internal processing."""
:func:`probnum.backend.as_scalar` before further internal processing."""

LinearOperatorArgType = Union[
np.ndarray,
Expand Down
1 change: 0 additions & 1 deletion src/probnum/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
__all__ = [
"as_colvec",
"atleast_1d",
"as_scalar",
"as_numpy_scalar",
"as_shape",
]
28 changes: 2 additions & 26 deletions src/probnum/utils/argutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,9 @@

import numpy as np

from probnum import backend
from probnum.typing import (
ArrayType,
DTypeArgType,
ScalarArgType,
ShapeArgType,
ShapeType,
)
from probnum.typing import DTypeArgType, ScalarArgType, ShapeArgType, ShapeType

__all__ = ["as_shape", "as_numpy_scalar", "as_scalar"]
__all__ = ["as_shape", "as_numpy_scalar"]


def as_shape(x: ShapeArgType, ndim: Optional[numbers.Integral] = None) -> ShapeType:
Expand Down Expand Up @@ -64,20 +57,3 @@ def as_numpy_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> np.generic:
raise ValueError("The given input is not a scalar.")

return np.asarray(x, dtype=dtype)[()]


def as_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> ArrayType:
"""Convert a scalar into a NumPy scalar.
Parameters
----------
x
Scalar value.
dtype
Data type of the scalar.
"""

if backend.ndim(x) != 0:
raise ValueError("The given input is not a scalar.")

return backend.asarray(x, dtype=dtype)[()]

0 comments on commit 6a2f4ec

Please sign in to comment.