Validate datasets versions (#4347)
* Implemented _validate_versions method

Signed-off-by: Elena Khaustova <[email protected]>

* Added _validate_versions calls

Signed-off-by: Elena Khaustova <[email protected]>

* Updated error descriptions

Signed-off-by: Elena Khaustova <[email protected]>

* Added validation to the old catalog

Signed-off-by: Elena Khaustova <[email protected]>

* Fixed linter

Signed-off-by: Elena Khaustova <[email protected]>

* Implemented unit tests for KedroDataCatalog

Signed-off-by: Elena Khaustova <[email protected]>

* Removed odd comments

Signed-off-by: Elena Khaustova <[email protected]>

* Implemented tests for DataCatalog

Signed-off-by: Elena Khaustova <[email protected]>

* Added docstrings

Signed-off-by: Elena Khaustova <[email protected]>

* Added release notes

Signed-off-by: Elena Khaustova <[email protected]>

* Added CachedDataset case

Signed-off-by: Elena Khaustova <[email protected]>

* Updated release notes

Signed-off-by: Elena Khaustova <[email protected]>

* Added tests for CachedDataset use case

Signed-off-by: Elena Khaustova <[email protected]>

* Fixed typos

Signed-off-by: Elena Khaustova <[email protected]>

* Updated TODOs

Signed-off-by: Elena Khaustova <[email protected]>

---------

Signed-off-by: Elena Khaustova <[email protected]>
ElenaKhaustova authored Nov 28, 2024
1 parent b53d365 commit 7b24af7
Showing 7 changed files with 268 additions and 10 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
@@ -2,6 +2,7 @@

## Major features and improvements
## Bug fixes and other changes
* Added validation to ensure dataset version consistency across the catalog.
## Breaking changes to the API
## Documentation changes
## Community contributions
71 changes: 67 additions & 4 deletions kedro/io/core.py
@@ -71,16 +71,16 @@ class DatasetError(Exception):


class DatasetNotFoundError(DatasetError):
"""``DatasetNotFoundError`` raised by ``DataCatalog`` class in case of
trying to use a non-existing dataset.
"""``DatasetNotFoundError`` raised by ```DataCatalog`` and ``KedroDataCatalog``
classes in case of trying to use a non-existing dataset.
"""

pass


class DatasetAlreadyExistsError(DatasetError):
"""``DatasetAlreadyExistsError`` raised by ``DataCatalog`` class in case
of trying to add a dataset which already exists in the ``DataCatalog``.
"""``DatasetAlreadyExistsError`` raised by ```DataCatalog`` and ``KedroDataCatalog``
classes in case of trying to add a dataset which already exists in the ``DataCatalog``.
"""

pass
@@ -94,6 +94,15 @@ class VersionNotFoundError(DatasetError):
pass


class VersionAlreadyExistsError(DatasetError):
"""``VersionAlreadyExistsError`` raised by ``DataCatalog`` and ``KedroDataCatalog``
classes when attempting to add a dataset to a catalog with a save version
that conflicts with the save version already set for the catalog.
"""

pass


_DI = TypeVar("_DI")
_DO = TypeVar("_DO")

@@ -955,3 +964,57 @@ def confirm(self, name: str) -> None:
def shallow_copy(self, extra_dataset_patterns: Patterns | None = None) -> _C:
"""Returns a shallow copy of the current object."""
...


def _validate_versions(
datasets: dict[str, AbstractDataset] | None,
load_versions: dict[str, str],
save_version: str | None,
) -> tuple[dict[str, str], str | None]:
"""Validates and synchronises dataset versions for loading and saving.
Ensures consistency of dataset versions across a catalog, particularly
for versioned datasets. It updates load versions and validates that all
save versions are consistent.
Args:
datasets: A dictionary mapping dataset names to their instances.
If None, no validation occurs.
load_versions: A mapping between dataset names and versions
to load.
save_version: Version string to be used for ``save`` operations
by all datasets with versioning enabled.
Returns:
Updated ``load_versions`` with load versions specified in the ``datasets``
and resolved ``save_version``.
Raises:
VersionAlreadyExistsError: If a dataset's save version conflicts with
the catalog's save version.
"""
if not datasets:
return load_versions, save_version

cur_load_versions = load_versions.copy()
cur_save_version = save_version

for ds_name, ds in datasets.items():
# TODO: Move to kedro/io/kedro_data_catalog.py when removing DataCatalog
# TODO: Make it a protected static method for KedroDataCatalog
# TODO: Replace with isinstance(ds, CachedDataset) - current implementation avoids circular import
cur_ds = ds._dataset if ds.__class__.__name__ == "CachedDataset" else ds # type: ignore[attr-defined]

if isinstance(cur_ds, AbstractVersionedDataset) and cur_ds._version:
if cur_ds._version.load:
cur_load_versions[ds_name] = cur_ds._version.load
if cur_ds._version.save:
cur_save_version = cur_save_version or cur_ds._version.save
if cur_save_version != cur_ds._version.save:
raise VersionAlreadyExistsError(
f"Cannot add a dataset `{ds_name}` with `{cur_ds._version.save}` save version. "
f"Save version set for the catalog is `{cur_save_version}`"
f"All datasets in the catalog must have the same save version."
)

return cur_load_versions, cur_save_version
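
A minimal sketch (not part of the commit) of how `_validate_versions` resolves versions and when it raises `VersionAlreadyExistsError`; it assumes `kedro_datasets` is installed, and the file paths and version strings below are purely illustrative:

```python
from kedro.io.core import Version, VersionAlreadyExistsError, _validate_versions
from kedro_datasets.pandas import CSVDataset

# Two datasets agreeing on the save version: load versions are collected per
# dataset and the shared save version is resolved once.
ds_a = CSVDataset(filepath="a.csv", version=Version(load="v1.csv", save="v2.csv"))
ds_b = CSVDataset(filepath="b.csv", version=Version(load=None, save="v2.csv"))
load_versions, save_version = _validate_versions(
    {"a": ds_a, "b": ds_b}, load_versions={}, save_version=None
)
print(load_versions)  # {'a': 'v1.csv'}
print(save_version)   # v2.csv

# A dataset whose save version differs from the one already resolved is rejected.
ds_c = CSVDataset(filepath="c.csv", version=Version(load=None, save="v3.csv"))
try:
    _validate_versions({"c": ds_c}, load_versions, save_version)
except VersionAlreadyExistsError as exc:
    print(exc)
```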
12 changes: 8 additions & 4 deletions kedro/io/data_catalog.py
@@ -25,6 +25,7 @@
DatasetError,
DatasetNotFoundError,
Version,
_validate_versions,
generate_timestamp,
)
from kedro.io.memory_dataset import MemoryDataset
@@ -160,20 +161,20 @@ def __init__( # noqa: PLR0913
>>> catalog = DataCatalog(datasets={'cars': cars})
"""
self._config_resolver = config_resolver or CatalogConfigResolver()

# Kept to avoid breaking changes
if not config_resolver:
self._config_resolver._dataset_patterns = dataset_patterns or {}
self._config_resolver._default_pattern = default_pattern or {}

self._load_versions, self._save_version = _validate_versions(
datasets, load_versions or {}, save_version
)

self._datasets: dict[str, AbstractDataset] = {}
self.datasets: _FrozenDatasets | None = None

self.add_all(datasets or {})

self._load_versions = load_versions or {}
self._save_version = save_version

self._use_rich_markup = _has_rich_handler()

if feed_dict:
@@ -506,6 +507,9 @@ def add(
raise DatasetAlreadyExistsError(
f"Dataset '{dataset_name}' has already been registered"
)
self._load_versions, self._save_version = _validate_versions(
{dataset_name: dataset}, self._load_versions, self._save_version
)
self._datasets[dataset_name] = dataset
self.datasets = _FrozenDatasets(self.datasets, {dataset_name: dataset})

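A hedged sketch of the `DataCatalog` behaviour this wiring enables, mirroring the tests added below; the timestamp, dataset name, and file path are illustrative assumptions:

```python
from kedro.io import DataCatalog, Version
from kedro.io.core import VersionAlreadyExistsError
from kedro_datasets.pandas import CSVDataset

# The catalog already has a save version; adding a dataset pinned to a
# different save version is rejected by the new validation in add().
catalog = DataCatalog(save_version="2024-11-28T00.00.00.000Z")
boats = CSVDataset(
    filepath="boats.csv",
    version=Version(load=None, save="another_save_version"),
)
try:
    catalog.add("boats", boats)
except VersionAlreadyExistsError as exc:
    print(exc)
```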
9 changes: 7 additions & 2 deletions kedro/io/kedro_data_catalog.py
@@ -25,6 +25,7 @@
DatasetError,
DatasetNotFoundError,
Version,
_validate_versions,
generate_timestamp,
)
from kedro.io.memory_dataset import MemoryDataset
@@ -98,8 +99,9 @@ def __init__(
self._config_resolver = config_resolver or CatalogConfigResolver()
self._datasets = datasets or {}
self._lazy_datasets: dict[str, _LazyDataset] = {}
self._load_versions = load_versions or {}
self._save_version = save_version
self._load_versions, self._save_version = _validate_versions(
datasets, load_versions or {}, save_version
)

self._use_rich_markup = _has_rich_handler()

@@ -218,6 +220,9 @@ def __setitem__(self, key: str, value: Any) -> None:
if key in self._datasets:
self._logger.warning("Replacing dataset '%s'", key)
if isinstance(value, AbstractDataset):
self._load_versions, self._save_version = _validate_versions(
{key: value}, self._load_versions, self._save_version
)
self._datasets[key] = value
elif isinstance(value, _LazyDataset):
self._lazy_datasets[key] = value
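A short sketch of the `KedroDataCatalog.__setitem__` path above, including the `CachedDataset` unwrapping; the dataset name, path, and version strings are illustrative assumptions:

```python
from kedro.io import CachedDataset, KedroDataCatalog, Version
from kedro_datasets.pandas import CSVDataset

catalog = KedroDataCatalog()
# __setitem__ runs the same validation; CachedDataset is unwrapped so the
# inner dataset's versions are checked and propagated to the catalog.
catalog["cars"] = CachedDataset(
    dataset=CSVDataset(
        filepath="cars.csv",
        version=Version(load="2024-01-01T00.00.00.000Z", save="2024-01-02T00.00.00.000Z"),
    )
)
print(catalog._load_versions["cars"])  # 2024-01-01T00.00.00.000Z
print(catalog._save_version)           # 2024-01-02T00.00.00.000Z
```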
22 changes: 22 additions & 0 deletions tests/io/conftest.py
@@ -3,6 +3,8 @@
import pytest
from kedro_datasets.pandas import CSVDataset

from kedro.io import CachedDataset, Version


@pytest.fixture
def dummy_numpy_array():
@@ -34,6 +36,26 @@ def dataset(filepath):
return CSVDataset(filepath=filepath, save_args={"index": False})


@pytest.fixture
def dataset_versioned(filepath):
return CSVDataset(
filepath=filepath,
save_args={"index": False},
version=Version(load="test_load_version.csv", save="test_save_version.csv"),
)


@pytest.fixture
def cached_dataset_versioned(filepath):
return CachedDataset(
dataset=CSVDataset(
filepath=filepath,
save_args={"index": False},
version=Version(load="test_load_version.csv", save="test_save_version.csv"),
)
)


@pytest.fixture
def correct_config(filepath):
return {
77 changes: 77 additions & 0 deletions tests/io/test_data_catalog.py
@@ -23,6 +23,7 @@
_DEFAULT_PACKAGES,
VERSION_FORMAT,
Version,
VersionAlreadyExistsError,
generate_timestamp,
parse_dataset_definition,
)
@@ -753,6 +754,82 @@ def test_no_versions_with_cloud_protocol(self, monkeypatch):
with pytest.raises(DatasetError, match=pattern):
versioned_dataset.load()

def test_redefine_save_version_via_catalog(self, correct_config, dataset_versioned):
"""Test redefining save version when it is already set"""
# Version is set automatically for the catalog
catalog = DataCatalog.from_config(**correct_config)
with pytest.raises(VersionAlreadyExistsError):
catalog.add("ds_versioned", dataset_versioned)

# Version is set manually for the catalog
correct_config["catalog"]["boats"]["versioned"] = True
catalog = DataCatalog.from_config(**correct_config)
with pytest.raises(VersionAlreadyExistsError):
catalog.add("ds_versioned", dataset_versioned)

def test_set_load_and_save_versions(self, correct_config, dataset_versioned):
"""Test setting load and save versions for catalog based on dataset's versions provided"""
catalog = DataCatalog(datasets={"ds_versioned": dataset_versioned})

assert catalog._load_versions["ds_versioned"] == dataset_versioned._version.load
assert catalog._save_version == dataset_versioned._version.save

def test_set_same_versions(self, correct_config, dataset_versioned):
"""Test setting the same load and save versions for catalog based on dataset's versions provided"""
catalog = DataCatalog(datasets={"ds_versioned": dataset_versioned})
catalog.add("ds_same_versions", dataset_versioned)

assert catalog._load_versions["ds_versioned"] == dataset_versioned._version.load
assert catalog._save_version == dataset_versioned._version.save

def test_redefine_load_version(self, correct_config, dataset_versioned):
"""Test redefining save version when it is already set"""
catalog = DataCatalog(datasets={"ds_versioned": dataset_versioned})
dataset_versioned._version = Version(
load="another_load_version.csv",
save="test_save_version.csv",
)
catalog.add("ds_same_versions", dataset_versioned)

assert (
catalog._load_versions["ds_same_versions"]
== dataset_versioned._version.load
)
assert catalog._load_versions["ds_versioned"] == "test_load_version.csv"
assert catalog._save_version == dataset_versioned._version.save

def test_redefine_save_version(self, correct_config, dataset_versioned):
"""Test redefining save version when it is already set"""
catalog = DataCatalog(datasets={"ds_versioned": dataset_versioned})
dataset_versioned._version = Version(
load="another_load_version.csv",
save="another_save_version.csv",
)
with pytest.raises(VersionAlreadyExistsError):
catalog.add("ds_same_versions", dataset_versioned)

def test_redefine_save_version_with_cached_dataset(
self, correct_config, cached_dataset_versioned
):
"""Test redefining load and save version with CachedDataset"""
catalog = DataCatalog.from_config(**correct_config)

# Redefining save version fails
with pytest.raises(VersionAlreadyExistsError):
catalog.add("cached_dataset_versioned", cached_dataset_versioned)

# Redefining load version passes
cached_dataset_versioned._dataset._version = Version(
load="test_load_version.csv", save=None
)
catalog.add("cached_dataset_versioned", cached_dataset_versioned)

assert (
catalog._load_versions["cached_dataset_versioned"]
== "test_load_version.csv"
)
assert catalog._save_version


class TestDataCatalogDatasetFactories:
def test_match_added_to_datasets_on_get(self, config_with_dataset_factories):
86 changes: 86 additions & 0 deletions tests/io/test_kedro_data_catalog.py
@@ -20,6 +20,8 @@
from kedro.io.core import (
_DEFAULT_PACKAGES,
VERSION_FORMAT,
Version,
VersionAlreadyExistsError,
generate_timestamp,
parse_dataset_definition,
)
@@ -667,3 +669,87 @@ def test_load_version_on_unversioned_dataset(

with pytest.raises(DatasetError):
catalog.load("boats", version="first")

def test_redefine_save_version_via_catalog(
self, correct_config, dataset_versioned
):
"""Test redefining save version when it is already set"""
# Version is set automatically for the catalog
catalog = KedroDataCatalog.from_config(**correct_config)
with pytest.raises(VersionAlreadyExistsError):
catalog["ds_versioned"] = dataset_versioned

# Version is set manually for the catalog
correct_config["catalog"]["boats"]["versioned"] = True
catalog = KedroDataCatalog.from_config(**correct_config)
with pytest.raises(VersionAlreadyExistsError):
catalog["ds_versioned"] = dataset_versioned

def test_set_load_and_save_versions(self, correct_config, dataset_versioned):
"""Test setting load and save versions for catalog based on dataset's versions provided"""
catalog = KedroDataCatalog(datasets={"ds_versioned": dataset_versioned})

assert (
catalog._load_versions["ds_versioned"]
== dataset_versioned._version.load
)
assert catalog._save_version == dataset_versioned._version.save

def test_set_same_versions(self, correct_config, dataset_versioned):
"""Test setting the same load and save versions for catalog based on dataset's versions provided"""
catalog = KedroDataCatalog(datasets={"ds_versioned": dataset_versioned})
catalog["ds_same_versions"] = dataset_versioned

assert (
catalog._load_versions["ds_versioned"]
== dataset_versioned._version.load
)
assert catalog._save_version == dataset_versioned._version.save

def test_redefine_load_version(self, correct_config, dataset_versioned):
"""Test redefining save version when it is already set"""
catalog = KedroDataCatalog(datasets={"ds_versioned": dataset_versioned})
dataset_versioned._version = Version(
load="another_load_version.csv",
save="test_save_version.csv",
)
catalog["ds_same_versions"] = dataset_versioned

assert (
catalog._load_versions["ds_same_versions"]
== dataset_versioned._version.load
)
assert catalog._load_versions["ds_versioned"] == "test_load_version.csv"
assert catalog._save_version == dataset_versioned._version.save

def test_redefine_save_version(self, correct_config, dataset_versioned):
"""Test redefining save version when it is already set"""
catalog = KedroDataCatalog(datasets={"ds_versioned": dataset_versioned})
dataset_versioned._version = Version(
load="another_load_version.csv",
save="another_save_version.csv",
)
with pytest.raises(VersionAlreadyExistsError):
catalog["ds_same_versions"] = dataset_versioned

def test_redefine_save_version_with_cached_dataset(
self, correct_config, cached_dataset_versioned
):
"""Test redefining load and save version with CachedDataset"""
catalog = KedroDataCatalog.from_config(**correct_config)

# Redefining save version fails
with pytest.raises(VersionAlreadyExistsError):
catalog["cached_dataset_versioned"] = cached_dataset_versioned

# Redefining load version passes
cached_dataset_versioned._dataset._version = Version(
load="test_load_version.csv", save=None
)
catalog["cached_dataset_versioned"] = cached_dataset_versioned

assert (
catalog._load_versions["cached_dataset_versioned"]
== "test_load_version.csv"
)
assert catalog._save_version
