Skip to content

Commit 6a2f4ec

Browse files
Remove as_scalar from utils
1 parent 0e1c379 commit 6a2f4ec

File tree

8 files changed

+31
-40
lines changed

8 files changed

+31
-40
lines changed

src/probnum/backend/_core/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from probnum import backend as _backend
2+
from probnum.typing import ArrayType, DTypeArgType, ScalarArgType
23

34
if _backend.BACKEND is _backend.Backend.NUMPY:
45
from . import _numpy as _core
@@ -73,3 +74,20 @@
7374
# Just-in-Time Compilation
7475
jit = _core.jit
7576
jit_method = _core.jit_method
77+
78+
79+
def as_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> ArrayType:
80+
"""Convert a scalar into a NumPy scalar.
81+
82+
Parameters
83+
----------
84+
x
85+
Scalar value.
86+
dtype
87+
Data type of the scalar.
88+
"""
89+
90+
if ndim(x) != 0:
91+
raise ValueError("The given input is not a scalar.")
92+
93+
return asarray(x, dtype=dtype)[()]

src/probnum/randprocs/kernels/_linear.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Optional
44

5-
from probnum import backend, utils
5+
from probnum import backend
66
from probnum.typing import ArrayType, IntArgType, ScalarArgType
77

88
from ._kernel import Kernel
@@ -39,7 +39,7 @@ class Linear(Kernel):
3939
"""
4040

4141
def __init__(self, input_dim: IntArgType, constant: ScalarArgType = 0.0):
42-
self.constant = utils.as_scalar(constant)
42+
self.constant = backend.as_scalar(constant)
4343
super().__init__(input_dim=input_dim)
4444

4545
@backend.jit_method

src/probnum/randprocs/kernels/_polynomial.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Optional
44

5-
from probnum import backend, utils
5+
from probnum import backend
66
from probnum.typing import ArrayType, IntArgType, ScalarArgType
77

88
from ._kernel import Kernel
@@ -46,8 +46,8 @@ def __init__(
4646
constant: ScalarArgType = 0.0,
4747
exponent: IntArgType = 1.0,
4848
):
49-
self.constant = utils.as_scalar(constant)
50-
self.exponent = utils.as_scalar(exponent)
49+
self.constant = backend.as_scalar(constant)
50+
self.exponent = backend.as_scalar(exponent)
5151
super().__init__(input_dim=input_dim)
5252

5353
@backend.jit_method

src/probnum/randprocs/kernels/_rational_quadratic.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22

33
from typing import Optional
44

5-
import numpy as np
6-
7-
from probnum import backend, utils
5+
from probnum import backend
86
from probnum.typing import ArrayType, IntArgType, ScalarArgType
97

108
from ._kernel import IsotropicMixin, Kernel
@@ -62,8 +60,8 @@ def __init__(
6260
lengthscale: ScalarArgType = 1.0,
6361
alpha: ScalarArgType = 1.0,
6462
):
65-
self.lengthscale = utils.as_scalar(lengthscale)
66-
self.alpha = utils.as_scalar(alpha)
63+
self.lengthscale = backend.as_scalar(lengthscale)
64+
self.alpha = backend.as_scalar(alpha)
6765
if not self.alpha > 0:
6866
raise ValueError(f"Scale mixture alpha={self.alpha} must be positive.")
6967
super().__init__(input_dim=input_dim)

src/probnum/randprocs/kernels/_white_noise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Optional
44

5-
from probnum import backend, utils
5+
from probnum import backend
66
from probnum.typing import ArrayType, IntArgType, ScalarArgType
77

88
from ._kernel import Kernel
@@ -25,7 +25,7 @@ class WhiteNoise(Kernel):
2525
"""
2626

2727
def __init__(self, input_dim: IntArgType, sigma: ScalarArgType = 1.0):
28-
self.sigma = utils.as_scalar(sigma)
28+
self.sigma = backend.as_scalar(sigma)
2929
self._sigma_sq = self.sigma ** 2
3030
super().__init__(input_dim=input_dim)
3131

src/probnum/typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
ScalarArgType = Union[int, float, complex, numbers.Number, np.number]
6060
"""Type of a public API argument for supplying a scalar value. Values of this type
6161
should always be converted into :class:`np.generic` using the function
62-
:func:`probnum.utils.as_scalar` before further internal processing."""
62+
:func:`probnum.backend.as_scalar` before further internal processing."""
6363

6464
LinearOperatorArgType = Union[
6565
np.ndarray,

src/probnum/utils/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
__all__ = [
88
"as_colvec",
99
"atleast_1d",
10-
"as_scalar",
1110
"as_numpy_scalar",
1211
"as_shape",
1312
]

src/probnum/utils/argutils.py

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,9 @@
55

66
import numpy as np
77

8-
from probnum import backend
9-
from probnum.typing import (
10-
ArrayType,
11-
DTypeArgType,
12-
ScalarArgType,
13-
ShapeArgType,
14-
ShapeType,
15-
)
8+
from probnum.typing import DTypeArgType, ScalarArgType, ShapeArgType, ShapeType
169

17-
__all__ = ["as_shape", "as_numpy_scalar", "as_scalar"]
10+
__all__ = ["as_shape", "as_numpy_scalar"]
1811

1912

2013
def as_shape(x: ShapeArgType, ndim: Optional[numbers.Integral] = None) -> ShapeType:
@@ -64,20 +57,3 @@ def as_numpy_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> np.generic:
6457
raise ValueError("The given input is not a scalar.")
6558

6659
return np.asarray(x, dtype=dtype)[()]
67-
68-
69-
def as_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> ArrayType:
70-
"""Convert a scalar into a NumPy scalar.
71-
72-
Parameters
73-
----------
74-
x
75-
Scalar value.
76-
dtype
77-
Data type of the scalar.
78-
"""
79-
80-
if backend.ndim(x) != 0:
81-
raise ValueError("The given input is not a scalar.")
82-
83-
return backend.asarray(x, dtype=dtype)[()]

0 commit comments

Comments
 (0)