From 6489162c065040138a7870a9b6c485358e5148b0 Mon Sep 17 00:00:00 2001 From: Thibaut Durand Date: Mon, 1 Apr 2024 17:17:44 -0700 Subject: [PATCH] Add mapping manipulation functions (#548) --- src/coola/nested/__init__.py | 3 + src/coola/nested/conversion.py | 5 +- src/coola/nested/mapping.py | 121 +++++++++++++ tests/unit/nested/test_mapping.py | 272 ++++++++++++++++++++++++++++++ 4 files changed, 397 insertions(+), 4 deletions(-) create mode 100644 src/coola/nested/mapping.py create mode 100644 tests/unit/nested/test_mapping.py diff --git a/src/coola/nested/__init__.py b/src/coola/nested/__init__.py index 290e178d..3654f53b 100644 --- a/src/coola/nested/__init__.py +++ b/src/coola/nested/__init__.py @@ -5,6 +5,9 @@ __all__ = [ "convert_to_dict_of_lists", "convert_to_list_of_dicts", + "get_first_value", + "to_flat_dict", ] from coola.nested.conversion import convert_to_dict_of_lists, convert_to_list_of_dicts +from coola.nested.mapping import get_first_value, to_flat_dict diff --git a/src/coola/nested/conversion.py b/src/coola/nested/conversion.py index 5ef7058d..5f869966 100644 --- a/src/coola/nested/conversion.py +++ b/src/coola/nested/conversion.py @@ -2,10 +2,7 @@ from __future__ import annotations -__all__ = [ - "convert_to_dict_of_lists", - "convert_to_list_of_dicts", -] +__all__ = ["convert_to_dict_of_lists", "convert_to_list_of_dicts"] from typing import TYPE_CHECKING diff --git a/src/coola/nested/mapping.py b/src/coola/nested/mapping.py new file mode 100644 index 00000000..441b7820 --- /dev/null +++ b/src/coola/nested/mapping.py @@ -0,0 +1,121 @@ +r"""Contain some utility functions to manipulate mappings.""" + +from __future__ import annotations + +__all__ = ["get_first_value", "to_flat_dict"] + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Mapping + + +def get_first_value(data: Mapping) -> Any: + r"""Get the first value of a mapping. + + Args: + data: The input mapping. + + Returns: + The first value in the mapping. + + Raises: + ValueError: if the mapping is empty. + + Example usage: + + ```pycon + + >>> from coola.nested import get_first_value + >>> get_first_value({"key1": 1, "key2": 2}) + 1 + + ``` + """ + if not data: + msg = "First value cannot be returned because the mapping is empty" + raise ValueError(msg) + return data[next(iter(data))] + + +def to_flat_dict( + data: Any, + prefix: str | None = None, + separator: str = ".", + to_str: type[object] | tuple[type[object], ...] | None = None, +) -> dict[str, Any]: + r"""Return a flat representation of a nested dict with the dot + format. + + Args: + data: The nested dict to flat. + prefix: The prefix to use to generate the name of the key. + ``None`` means no prefix. + separator: The separator to concatenate keys of nested + collections. + to_str: The data types which will not be flattened out, + instead converted to string. + + Returns: + The flatted dictionary. + + Example usage: + + ```pycon + + >>> from coola.nested import to_flat_dict + >>> data = { + ... "str": "def", + ... "module": { + ... "component": { + ... "float": 3.5, + ... "int": 2, + ... }, + ... }, + ... } + >>> to_flat_dict(data) + {'str': 'def', 'module.component.float': 3.5, 'module.component.int': 2} + >>> # Example with lists (also works with tuple) + >>> data = { + ... "module": [[1, 2, 3], {"bool": True}], + ... "str": "abc", + ... } + >>> to_flat_dict(data) + {'module.0.0': 1, 'module.0.1': 2, 'module.0.2': 3, 'module.1.bool': True, 'str': 'abc'} + >>> # Example with lists with to_str=(list) (also works with tuple) + >>> data = { + ... "module": [[1, 2, 3], {"bool": True}], + ... "str": "abc", + ... } + >>> to_flat_dict(data) + {'module.0.0': 1, 'module.0.1': 2, 'module.0.2': 3, 'module.1.bool': True, 'str': 'abc'} + + ``` + """ + flat_dict = {} + to_str = to_str or () + if isinstance(data, to_str): + flat_dict[prefix] = str(data) + elif isinstance(data, dict): + for key, value in data.items(): + flat_dict.update( + to_flat_dict( + value, + prefix=f"{prefix}{separator}{key}" if prefix else key, + separator=separator, + to_str=to_str, + ) + ) + elif isinstance(data, (list, tuple)): + for i, value in enumerate(data): + flat_dict.update( + to_flat_dict( + value, + prefix=f"{prefix}{separator}{i}" if prefix else str(i), + separator=separator, + to_str=to_str, + ) + ) + else: + flat_dict[prefix] = data + return flat_dict diff --git a/tests/unit/nested/test_mapping.py b/tests/unit/nested/test_mapping.py new file mode 100644 index 00000000..070e2aa1 --- /dev/null +++ b/tests/unit/nested/test_mapping.py @@ -0,0 +1,272 @@ +from __future__ import annotations + +from unittest.mock import Mock + +import pytest + +from coola import objects_are_equal +from coola.nested import get_first_value, to_flat_dict +from coola.testing import numpy_available, torch_available +from coola.utils import is_numpy_available, is_torch_available + +if is_numpy_available(): + import numpy as np +else: # pragma: no cover + np = Mock() + +if is_torch_available(): + import torch +else: # pragma: no cover + torch = Mock() + +##################################### +# Tests for get_first_value # +##################################### + + +def test_get_first_value_empty() -> None: + with pytest.raises( + ValueError, match="First value cannot be returned because the mapping is empty" + ): + get_first_value({}) + + +def test_get_first_value() -> None: + assert get_first_value({"key1": 1, "key2": 2}) == 1 + + +################################## +# Tests for to_flat_dict # +################################## + + +def test_to_flat_dict_flat_dict() -> None: + flatten_dict = to_flat_dict( + { + "bool": False, + "float": 3.5, + "int": 2, + "str": "abc", + } + ) + assert flatten_dict == { + "bool": False, + "float": 3.5, + "int": 2, + "str": "abc", + } + + +def test_to_flat_dict_nested_dict_str() -> None: + flatten_dict = to_flat_dict({"a": "a", "b": {"c": "c"}, "d": {"e": {"f": "f"}}}) + assert flatten_dict == {"a": "a", "b.c": "c", "d.e.f": "f"} + + +def test_to_flat_dict_nested_dict_multiple_types() -> None: + flatten_dict = to_flat_dict( + { + "module": { + "bool": False, + "float": 3.5, + "int": 2, + }, + "str": "abc", + } + ) + assert flatten_dict == { + "module.bool": False, + "module.float": 3.5, + "module.int": 2, + "str": "abc", + } + + +def test_to_flat_dict_data_empty_key() -> None: + flatten_dict = to_flat_dict( + { + "module": {}, + "str": "abc", + } + ) + assert flatten_dict == {"str": "abc"} + + +def test_to_flat_dict_double_data() -> None: + flatten_dict = to_flat_dict( + { + "str": "def", + "module": { + "component": { + "float": 3.5, + "int": 2, + }, + }, + } + ) + assert flatten_dict == { + "module.component.float": 3.5, + "module.component.int": 2, + "str": "def", + } + + +def test_to_flat_dict_double_data_2() -> None: + flatten_dict = to_flat_dict( + { + "module": { + "component_a": { + "float": 3.5, + "int": 2, + }, + "component_b": { + "param_a": 1, + "param_b": 2, + }, + "str": "abc", + }, + } + ) + assert flatten_dict == { + "module.component_a.float": 3.5, + "module.component_a.int": 2, + "module.component_b.param_a": 1, + "module.component_b.param_b": 2, + "module.str": "abc", + } + + +def test_to_flat_dict_list() -> None: + flatten_dict = to_flat_dict([2, "abc", True, 3.5]) + assert flatten_dict == { + "0": 2, + "1": "abc", + "2": True, + "3": 3.5, + } + + +def test_to_flat_dict_dict_with_list() -> None: + flatten_dict = to_flat_dict( + { + "module": [2, "abc", True, 3.5], + "str": "abc", + } + ) + assert flatten_dict == { + "module.0": 2, + "module.1": "abc", + "module.2": True, + "module.3": 3.5, + "str": "abc", + } + + +def test_to_flat_dict_with_more_complex_list() -> None: + flatten_dict = to_flat_dict( + { + "module": [[1, 2, 3], {"bool": True}], + "str": "abc", + } + ) + assert flatten_dict == { + "module.0.0": 1, + "module.0.1": 2, + "module.0.2": 3, + "module.1.bool": True, + "str": "abc", + } + + +def test_to_flat_dict_tuple() -> None: + flatten_dict = to_flat_dict( + { + "module": (2, "abc", True, 3.5), + "str": "abc", + } + ) + assert flatten_dict == { + "module.0": 2, + "module.1": "abc", + "module.2": True, + "module.3": 3.5, + "str": "abc", + } + + +def test_to_flat_dict_with_complex_tuple() -> None: + flatten_dict = to_flat_dict( + { + "module": ([1, 2, 3], {"bool": True}), + "str": "abc", + } + ) + assert flatten_dict == { + "module.0.0": 1, + "module.0.1": 2, + "module.0.2": 3, + "module.1.bool": True, + "str": "abc", + } + + +@pytest.mark.parametrize("separator", [".", "/", "@", "[SEP]"]) +def test_to_flat_dict_separator(separator: str) -> None: + flatten_dict = to_flat_dict( + { + "str": "def", + "module": { + "component": { + "float": 3.5, + "int": 2, + }, + }, + }, + separator=separator, + ) + assert flatten_dict == { + f"module{separator}component{separator}float": 3.5, + f"module{separator}component{separator}int": 2, + "str": "def", + } + + +def test_to_flat_dict_to_str_tuple() -> None: + flatten_dict = to_flat_dict( + { + "module": (2, "abc", True, 3.5), + "str": "abc", + }, + to_str=tuple, + ) + assert flatten_dict == { + "module": "(2, 'abc', True, 3.5)", + "str": "abc", + } + + +def test_to_flat_dict_to_str_tuple_and_list() -> None: + flatten_dict = to_flat_dict( + { + "module1": (2, "abc", True, 3.5), + "module2": [1, 2, 3], + "str": "abc", + }, + to_str=(list, tuple), + ) + assert flatten_dict == { + "module1": "(2, 'abc', True, 3.5)", + "module2": "[1, 2, 3]", + "str": "abc", + } + + +@torch_available +def test_to_flat_dict_tensor() -> None: + assert objects_are_equal( + to_flat_dict({"tensor": torch.ones(2, 3)}), {"tensor": torch.ones(2, 3)} + ) + + +@numpy_available +def test_to_flat_dict_numpy_ndarray() -> None: + assert objects_are_equal(to_flat_dict(np.zeros((2, 3))), {None: np.zeros((2, 3))})