Skip to content

Commit

Permalink
Add RandomRandomSeedSetter (#525)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo authored Mar 21, 2024
1 parent a7d717d commit 3e51114
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/coola/random/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from __future__ import annotations

__all__ = ["BaseRandomSeedSetter", "TorchRandomSeedSetter"]
__all__ = ["BaseRandomSeedSetter", "RandomRandomSeedSetter", "TorchRandomSeedSetter"]

from coola.random.base import BaseRandomSeedSetter
from coola.random.random_ import RandomRandomSeedSetter
from coola.random.torch_ import TorchRandomSeedSetter
32 changes: 32 additions & 0 deletions src/coola/random/random_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
r"""Implement a random seed setter for the python standard library
``random``."""

from __future__ import annotations

__all__ = ["RandomRandomSeedSetter"]

import random

from coola.random.base import BaseRandomSeedSetter


class RandomRandomSeedSetter(BaseRandomSeedSetter):
r"""Implement a random seed setter for the python standard library
``random``.
Example usage:
```pycon
>>> from coola.random import RandomRandomSeedSetter
>>> setter = RandomRandomSeedSetter()
>>> setter.manual_seed(42)
```
"""

def __repr__(self) -> str:
return f"{self.__class__.__qualname__}()"

def manual_seed(self, seed: int) -> None:
random.seed(seed)
28 changes: 28 additions & 0 deletions tests/unit/random/test_random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations

import random

from coola.random import RandomRandomSeedSetter

############################################
# Tests for RandomRandomSeedSetter #
############################################


def test_random_random_seed_setter_repr() -> None:
assert repr(RandomRandomSeedSetter()).startswith("RandomRandomSeedSetter(")


def test_random_random_seed_setter_str() -> None:
assert str(RandomRandomSeedSetter()).startswith("RandomRandomSeedSetter(")


def test_random_random_seed_setter_manual_seed() -> None:
seed_setter = RandomRandomSeedSetter()
seed_setter.manual_seed(42)
x1 = random.uniform(0, 1) # noqa: S311
x2 = random.uniform(0, 1) # noqa: S311
seed_setter.manual_seed(42)
x3 = random.uniform(0, 1) # noqa: S311
assert x1 == x3
assert x1 != x2

0 comments on commit 3e51114

Please sign in to comment.