Skip to content

Commit 5ea8668

Browse files
Cleanup of probnum's type aliases
1 parent e1de948 commit 5ea8668

File tree

6 files changed

+37
-38
lines changed

6 files changed

+37
-38
lines changed

src/probnum/backend/_core/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from probnum import backend as _backend
2-
from probnum.typing import ArrayType, DTypeArgType, ScalarArgType
2+
from probnum.typing import ArrayType, DTypeArgType, ScalarLike
33

44
if _backend.BACKEND is _backend.Backend.NUMPY:
55
from . import _numpy as _core
@@ -79,7 +79,7 @@
7979
jit_method = _core.jit_method
8080

8181

82-
def as_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> ArrayType:
82+
def as_scalar(x: ScalarLike, dtype: DTypeArgType = None) -> ArrayType:
8383
"""Convert a scalar into a NumPy scalar.
8484
8585
Parameters

src/probnum/backend/random/_jax.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import jax
66
from jax import numpy as jnp
77

8-
from probnum.typing import DTypeArgType, FloatArgType, ShapeArgType
8+
from probnum.typing import DTypeArgType, FloatLike, ShapeLike
99

1010

1111
def seed(seed: Optional[int]) -> jnp.ndarray:
@@ -28,9 +28,9 @@ def standard_normal(seed: jnp.ndarray, shape=(), dtype=jnp.double):
2828

2929
def gamma(
3030
seed: jnp.ndarray,
31-
shape_param: FloatArgType,
32-
scale_param: FloatArgType = 1.0,
33-
shape: ShapeArgType = (),
31+
shape_param: FloatLike,
32+
scale_param: FloatLike = 1.0,
33+
shape: ShapeLike = (),
3434
dtype: DTypeArgType = jnp.double,
3535
):
3636
return (
@@ -43,7 +43,7 @@ def gamma(
4343
def uniform_so_group(
4444
seed: jnp.ndarray,
4545
n: int,
46-
shape: ShapeArgType = (),
46+
shape: ShapeLike = (),
4747
dtype: DTypeArgType = jnp.double,
4848
) -> jnp.ndarray:
4949
if n == 1:

src/probnum/backend/random/_numpy.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import numpy as np
55

6-
from probnum.typing import DTypeArgType, FloatArgType, ShapeArgType
6+
from probnum.typing import DTypeArgType, FloatLike, ShapeLike
77

88

99
def seed(seed: Optional[int]) -> np.random.SeedSequence:
@@ -21,17 +21,17 @@ def split(
2121

2222
def standard_normal(
2323
seed: np.random.SeedSequence,
24-
shape: ShapeArgType = (),
24+
shape: ShapeLike = (),
2525
dtype: DTypeArgType = np.double,
2626
) -> np.ndarray:
2727
return _make_rng(seed).standard_normal(size=shape, dtype=dtype)
2828

2929

3030
def gamma(
3131
seed: np.random.SeedSequence,
32-
shape_param: FloatArgType,
33-
scale_param: FloatArgType = 1.0,
34-
shape: ShapeArgType = (),
32+
shape_param: FloatLike,
33+
scale_param: FloatLike = 1.0,
34+
shape: ShapeLike = (),
3535
dtype: DTypeArgType = np.double,
3636
) -> np.ndarray:
3737
return (
@@ -43,7 +43,7 @@ def gamma(
4343
def uniform_so_group(
4444
seed: np.random.SeedSequence,
4545
n: int,
46-
shape: ShapeArgType = (),
46+
shape: ShapeLike = (),
4747
dtype: DTypeArgType = np.double,
4848
) -> np.ndarray:
4949
if n == 1:

src/probnum/backend/random/_torch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
from torch.distributions.utils import broadcast_all
66

7-
from probnum.typing import DTypeArgType, ShapeArgType
7+
from probnum.typing import DTypeArgType, ShapeLike
88

99
_RNG_STATE_SIZE = torch.Generator().get_state().shape[0]
1010

@@ -51,7 +51,7 @@ def gamma(
5151
def uniform_so_group(
5252
seed: np.random.SeedSequence,
5353
n: int,
54-
shape: ShapeArgType = (),
54+
shape: ShapeLike = (),
5555
dtype: DTypeArgType = torch.double,
5656
) -> torch.Tensor:
5757
if n == 1:

src/probnum/randvars/_random_variable.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from probnum import backend, utils as _utils
1010
from probnum.typing import (
1111
ArrayIndicesLike,
12-
ArrayLike,
1312
ArrayType,
1413
DTypeLike,
1514
SeedType,

src/probnum/typing.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,13 @@
2929
# Array Utilities
3030
ShapeType = Tuple[int, ...]
3131

32-
# Backend Types
33-
ArrayType = Union[np.ndarray, "jax.numpy.ndarray", "torch.Tensor"]
34-
ScalarType = Union[np.ndarray, "jax.numpy.ndarray", "torch.Tensor"]
32+
# Scalars, Arrays and Matrices
33+
ScalarType = "probnum.backend.ndarray"
34+
MatrixType = Union["probnum.backend.ndarray", "probnum.linops.LinearOperator"]
3535

36+
# Random Number Generation
3637
SeedType = Union[np.random.SeedSequence, "jax.random.PRNGKey"]
3738

38-
# ProbNum Types
39-
MatrixType = Union[ArrayType, "probnum.linops.LinearOperator"]
40-
41-
# Scalars, Arrays and Matrices
42-
ScalarType = np.number
43-
MatrixType = Union[np.ndarray, "probnum.linops.LinearOperator"]
4439

4540
########################################################################################
4641
# Argument Types
@@ -64,39 +59,39 @@
6459
"""Type of a public API argument for supplying a shape.
6560
6661
Values of this type should always be converted into :class:`ShapeType` using the
67-
function :func:`probnum.backend.as_scalar` before further internal processing."""
62+
function :func:`probnum.backend.as_shape` before further internal processing."""
6863

69-
DTypeLike = _NumPyDTypeLike
64+
DTypeLike = Union[_NumPyDTypeLike, "jax.numpy.dtype", "torch.dtype"]
7065
"""Type of a public API argument for supplying an array's dtype.
7166
72-
Values of this type should always be converted into :class:`np.dtype`\\ s before further
73-
internal processing."""
67+
Values of this type should always be converted into :class:`backend.dtype`\\ s using the
68+
function :func:`probnum.backend.as_dtype` before further internal processing."""
7469

7570
_ArrayIndexLike = Union[
7671
int,
7772
slice,
7873
type(Ellipsis),
7974
None,
80-
np.newaxis,
81-
np.ndarray,
75+
"probnum.backend.newaxis",
76+
"probnum.backend.ndarray",
8277
]
8378
ArrayIndicesLike = Union[_ArrayIndexLike, Tuple[_ArrayIndexLike, ...]]
8479
"""Type of the argument to the :meth:`__getitem__` method of a NumPy-like array type
85-
such as :class:`np.ndarray`, :class:`probnum.linops.LinearOperator` or
80+
such as :class:`probnum.backend.ndarray`, :class:`probnum.linops.LinearOperator` or
8681
:class:`probnum.randvars.RandomVariable`."""
8782

8883
# Scalars, Arrays and Matrices
89-
ScalarLike = Union[int, float, complex, numbers.Number, np.number]
84+
ScalarLike = Union[ScalarType, int, float, complex, numbers.Number, np.number]
9085
"""Type of a public API argument for supplying a scalar value.
9186
92-
Values of this type should always be converted into :class:`np.number`\\ s using the
93-
function :func:`probnum.utils.as_scalar` before further internal processing."""
87+
Values of this type should always be converted into :class:`ScalarType`\\ s using
88+
the function :func:`probnum.backend.as_scalar` before further internal processing."""
9489

9590
ArrayLike = Union[_NumPyArrayLike, "jax.numpy.ndarray", "torch.Tensor"]
9691
"""Type of a public API argument for supplying an array.
9792
98-
Values of this type should always be converted into :class:`np.ndarray`\\ s using
99-
the function :func:`np.asarray` before further internal processing."""
93+
Values of this type should always be converted into :class:`backend.ndarray`\\ s using
94+
the function :func:`probnum.backend.as_array` before further internal processing."""
10095

10196
LinearOperatorLike = Union[
10297
ArrayLike,
@@ -106,10 +101,15 @@
106101
"""Type of a public API argument for supplying a finite-dimensional linear operator.
107102
108103
Values of this type should always be converted into :class:`probnum.linops.\\
109-
LinearOperator`\\ s using the function :func:`probnum.linops.aslinop` before further
104+
LinearOperator`\\ s using the function :func:`probnum.linops.as_linop` before further
110105
internal processing."""
111106

107+
# Random Number Generation
112108
SeedLike = Optional[int]
109+
"""Type of a public API argument for supplying the seed of a random number generator.
110+
111+
Values of this type should always be converted to :class:`SeedType` using the function
112+
:func:`probnum.backend.random.seed` before further internal processing."""
113113

114114
########################################################################################
115115
# Other Types

0 commit comments

Comments
 (0)