Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update NumpyRandomManager #528

Merged
merged 1 commit into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/coola/random/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

__all__ = [
"BaseRandomSeedSetter",
"NumpyRandomSeedSetter",
"NumpyRandomManager",
"RandomRandomSeedSetter",
"TorchRandomManager",
]

from coola.random.base import BaseRandomSeedSetter
from coola.random.numpy_ import NumpyRandomSeedSetter
from coola.random.numpy_ import NumpyRandomManager
from coola.random.random_ import RandomRandomSeedSetter
from coola.random.torch_ import TorchRandomManager
14 changes: 10 additions & 4 deletions src/coola/random/numpy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

__all__ = ["NumpyRandomSeedSetter"]
__all__ = ["NumpyRandomManager"]

from unittest.mock import Mock

Expand All @@ -15,7 +15,7 @@
np = Mock()


class NumpyRandomSeedSetter(BaseRandomSeedSetter):
class NumpyRandomManager(BaseRandomSeedSetter):
r"""Implement a random seed setter for the library ``numpy``.

The seed must be between ``0`` and ``2**32 - 1``, so a modulo
Expand All @@ -26,8 +26,8 @@ class NumpyRandomSeedSetter(BaseRandomSeedSetter):

```pycon

>>> from coola.random import NumpyRandomSeedSetter
>>> setter = NumpyRandomSeedSetter()
>>> from coola.random import NumpyRandomManager
>>> setter = NumpyRandomManager()
>>> setter.manual_seed(42)

```
Expand All @@ -39,5 +39,11 @@ def __init__(self) -> None:
def __repr__(self) -> str:
return f"{self.__class__.__qualname__}()"

def get_rng_state(self) -> dict | tuple:
return np.random.get_state()

def manual_seed(self, seed: int) -> None:
np.random.seed(seed % 2**32)

def set_rng_state(self, state: dict | tuple) -> None:
np.random.set_state(state)
43 changes: 31 additions & 12 deletions tests/unit/random/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,38 @@

import pytest

from coola.random import NumpyRandomSeedSetter
from coola.random import NumpyRandomManager
from coola.testing import numpy_available
from coola.utils import is_numpy_available

if is_numpy_available():
import numpy as np

###########################################
# Tests for NumpyRandomSeedSetter #
###########################################
########################################
# Tests for NumpyRandomManager #
########################################


@numpy_available
def test_numpy_random_seed_setter_repr() -> None:
assert repr(NumpyRandomSeedSetter()).startswith("NumpyRandomSeedSetter(")
def test_numpy_random_manager_repr() -> None:
assert repr(NumpyRandomManager()).startswith("NumpyRandomManager(")


@numpy_available
def test_numpy_random_seed_setter_str() -> None:
assert str(NumpyRandomSeedSetter()).startswith("NumpyRandomSeedSetter(")
def test_numpy_random_manager_str() -> None:
assert str(NumpyRandomManager()).startswith("NumpyRandomManager(")


@numpy_available
def test_numpy_random_seed_setter_manual_seed() -> None:
seed_setter = NumpyRandomSeedSetter()
def test_numpy_random_manager_get_rng_state() -> None:
rng = NumpyRandomManager()
state = rng.get_rng_state()
assert isinstance(state, (tuple, dict))


@numpy_available
def test_numpy_random_manager_manual_seed() -> None:
seed_setter = NumpyRandomManager()
seed_setter.manual_seed(42)
x1 = np.random.randn(4, 6)
x2 = np.random.randn(4, 6)
Expand All @@ -38,9 +45,21 @@ def test_numpy_random_seed_setter_manual_seed() -> None:
assert not np.array_equal(x1, x2)


def test_numpy_random_seed_setter_no_numpy() -> None:
@numpy_available
def test_numpy_random_manager_set_rng_state() -> None:
seed_setter = NumpyRandomManager()
state = seed_setter.get_rng_state()
x1 = np.random.randn(4, 6)
x2 = np.random.randn(4, 6)
seed_setter.set_rng_state(state)
x3 = np.random.randn(4, 6)
assert np.array_equal(x1, x3)
assert not np.array_equal(x1, x2)


def test_numpy_random_manager_no_numpy() -> None:
with (
patch("coola.utils.imports.is_numpy_available", lambda: False),
pytest.raises(RuntimeError, match="`numpy` package is required but not installed."),
):
NumpyRandomSeedSetter()
NumpyRandomManager()
Loading