Skip to content

Commit

Permalink
Add NumpyRandomSeedSetter (#526)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo authored Mar 21, 2024
1 parent 3e51114 commit 8b269c2
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/coola/random/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@

from __future__ import annotations

__all__ = ["BaseRandomSeedSetter", "RandomRandomSeedSetter", "TorchRandomSeedSetter"]
__all__ = [
"BaseRandomSeedSetter",
"NumpyRandomSeedSetter",
"RandomRandomSeedSetter",
"TorchRandomSeedSetter",
]

from coola.random.base import BaseRandomSeedSetter
from coola.random.numpy_ import NumpyRandomSeedSetter
from coola.random.random_ import RandomRandomSeedSetter
from coola.random.torch_ import TorchRandomSeedSetter
43 changes: 43 additions & 0 deletions src/coola/random/numpy_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
r"""Implement a random seed setter for NumPy."""

from __future__ import annotations

__all__ = ["NumpyRandomSeedSetter"]

from unittest.mock import Mock

from coola.random.base import BaseRandomSeedSetter
from coola.utils import check_numpy, is_numpy_available

if is_numpy_available():
import numpy as np
else: # pragma: no cover
np = Mock()


class NumpyRandomSeedSetter(BaseRandomSeedSetter):
r"""Implement a random seed setter for the library ``numpy``.
The seed must be between ``0`` and ``2**32 - 1``, so a modulo
operator to convert an integer to an integer between ``0`` and
``2**32 - 1``.
Example usage:
```pycon
>>> from coola.random import NumpyRandomSeedSetter
>>> setter = NumpyRandomSeedSetter()
>>> setter.manual_seed(42)
```
"""

def __init__(self) -> None:
check_numpy()

def __repr__(self) -> str:
return f"{self.__class__.__qualname__}()"

def manual_seed(self, seed: int) -> None:
np.random.seed(seed % 2**32)
46 changes: 46 additions & 0 deletions tests/unit/random/test_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from __future__ import annotations

from unittest.mock import patch

import pytest

from coola.random import NumpyRandomSeedSetter
from coola.testing import numpy_available
from coola.utils import is_numpy_available

if is_numpy_available():
import numpy as np

###########################################
# Tests for NumpyRandomSeedSetter #
###########################################


@numpy_available
def test_numpy_random_seed_setter_repr() -> None:
assert repr(NumpyRandomSeedSetter()).startswith("NumpyRandomSeedSetter(")


@numpy_available
def test_numpy_random_seed_setter_str() -> None:
assert str(NumpyRandomSeedSetter()).startswith("NumpyRandomSeedSetter(")


@numpy_available
def test_numpy_random_seed_setter_manual_seed() -> None:
seed_setter = NumpyRandomSeedSetter()
seed_setter.manual_seed(42)
x1 = np.random.randn(4, 6)
x2 = np.random.randn(4, 6)
seed_setter.manual_seed(42)
x3 = np.random.randn(4, 6)
assert np.array_equal(x1, x3)
assert not np.array_equal(x1, x2)


def test_numpy_random_seed_setter_no_numpy() -> None:
with (
patch("coola.utils.imports.is_numpy_available", lambda: False),
pytest.raises(RuntimeError, match="`numpy` package is required but not installed."),
):
NumpyRandomSeedSetter()

0 comments on commit 8b269c2

Please sign in to comment.