Skip to content

Commit

Permalink
Add remove_keys_starting_with (#549)
Browse files Browse the repository at this point in the history
  • Loading branch information
durandtibo authored Apr 2, 2024
1 parent 90ccf82 commit b85a0c1
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 3 deletions.
7 changes: 6 additions & 1 deletion src/coola/nested/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
"convert_to_list_of_dicts",
"get_first_value",
"to_flat_dict",
"remove_keys_starting_with",
]

from coola.nested.conversion import convert_to_dict_of_lists, convert_to_list_of_dicts
from coola.nested.mapping import get_first_value, to_flat_dict
from coola.nested.mapping import (
get_first_value,
remove_keys_starting_with,
to_flat_dict,
)
33 changes: 32 additions & 1 deletion src/coola/nested/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

__all__ = ["get_first_value", "to_flat_dict"]
__all__ = ["get_first_value", "to_flat_dict", "remove_keys_starting_with"]

from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -119,3 +119,34 @@ def to_flat_dict(
else:
flat_dict[prefix] = data
return flat_dict


def remove_keys_starting_with(mapping: Mapping, prefix: str) -> dict:
r"""Remove the keys that start with a given prefix.
Args:
mapping: The original mapping.
prefix: The prefix used to filter the keys.
Returns:
A new dict without the removed keys.
Example usage:
```pycon
>>> from coola.nested import remove_keys_starting_with
>>> remove_keys_starting_with(
... {"key": 1, "key.abc": 2, "abc": 3, "abc.key": 4, 1: 5, (2, 3): 6},
... "key",
... )
{'abc': 3, 'abc.key': 4, 1: 5, (2, 3): 6}
```
"""
out = {}
for key, value in mapping.items():
if isinstance(key, str) and key.startswith(prefix):
continue
out[key] = value
return out
34 changes: 33 additions & 1 deletion tests/unit/nested/test_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from coola import objects_are_equal
from coola.nested import get_first_value, to_flat_dict
from coola.nested import get_first_value, remove_keys_starting_with, to_flat_dict
from coola.testing import numpy_available, torch_available
from coola.utils import is_numpy_available, is_torch_available

Expand Down Expand Up @@ -270,3 +270,35 @@ def test_to_flat_dict_tensor() -> None:
@numpy_available
def test_to_flat_dict_numpy_ndarray() -> None:
assert objects_are_equal(to_flat_dict(np.zeros((2, 3))), {None: np.zeros((2, 3))})


###############################################
# Tests for remove_keys_starting_with #
###############################################


def test_remove_keys_starting_with_empty() -> None:
assert remove_keys_starting_with({}, "key") == {}


def test_remove_keys_starting_with() -> None:
assert remove_keys_starting_with(
{"key": 1, "key.abc": 2, "abc": 3, "abc.key": 4, 1: 5, (2, 3): 6}, "key"
) == {
"abc": 3,
"abc.key": 4,
1: 5,
(2, 3): 6,
}


def test_remove_keys_starting_with_another_key() -> None:
assert remove_keys_starting_with(
{"key": 1, "key.abc": 2, "abc": 3, "abc.key": 4, 1: 5, (2, 3): 6}, "key."
) == {
"key": 1,
"abc": 3,
"abc.key": 4,
1: 5,
(2, 3): 6,
}

0 comments on commit b85a0c1

Please sign in to comment.