Skip to content

Commit

Permalink
Added tests for concatenate_datasets method
Browse files Browse the repository at this point in the history
* testing the core functionality with a list/dict of Datasets/IterableDatasets
  through test_concatenate_datasets()

* testing errors raised during concatenation through test_concatenate_datasets_errors()

* testing helper function _add_dset_name_to_document
  • Loading branch information
kai-car committed Aug 9, 2024
1 parent dd873a8 commit 8bd4735
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/pie_datasets/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def _add_dset_name_to_document(doc: Document, name: str) -> Document:
)
if "dataset_name" in doc.metadata:
raise ValueError(
f"Document already has a dataset_name attribute: {doc.metadata['dataset']}"
f"Document already has a dataset_name attribute: {doc.metadata['dataset_name']}"
)
doc.metadata["dataset_name"] = name
return doc
Expand Down
96 changes: 95 additions & 1 deletion tests/unit/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy
import pytest
import torch
from pytorch_ie import Document
from pytorch_ie.annotations import BinaryRelation, Label, LabeledSpan, Span
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.core.taskmodule import (
Expand All @@ -17,7 +18,11 @@
from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule

from pie_datasets import Dataset, IterableDataset
from pie_datasets.core.dataset import get_pie_dataset_type
from pie_datasets.core.dataset import (
_add_dset_name_to_document,
concatenate_datasets,
get_pie_dataset_type,
)
from tests.conftest import TestDocument
from tests.unit.core import TEST_PACKAGE

Expand Down Expand Up @@ -431,3 +436,92 @@ def test_dataset_with_taskmodule(

for document in train_dataset:
assert not document["entities"].predictions


@pytest.mark.parametrize("as_list", [False, True])
def test_concatenate_datasets(maybe_iterable_dataset, as_list):
    """Concatenate the train/validation/test splits and verify output type,
    size, document order, and (for dict input) the dataset_name metadata.

    Covers four cases: list vs. dict input crossed with Dataset vs.
    IterableDataset (the fixture is parametrized over both dataset types).

    NOTE(review): the original signature also requested the
    dataset_with_converter_functions fixture but never used it; it was
    removed so the fixture is not built needlessly for every run.
    """
    if as_list:
        # Concatenating a plain list: no split names are available.
        concatenated_dataset = concatenate_datasets(
            [
                maybe_iterable_dataset["train"],
                maybe_iterable_dataset["validation"],
                maybe_iterable_dataset["test"],
            ]
        )
    else:
        # Concatenating a dict: the split name is recorded per document.
        concatenated_dataset = concatenate_datasets(maybe_iterable_dataset)

    # The output type must mirror the input type.
    if isinstance(maybe_iterable_dataset["train"], IterableDataset):
        # if input is IterableDataset, output should be IterableDataset
        assert isinstance(concatenated_dataset, IterableDataset)
    elif isinstance(maybe_iterable_dataset["train"], Dataset):
        # if input is Dataset, output should be Dataset
        assert isinstance(concatenated_dataset, Dataset)
    else:
        raise ValueError("Unexpected input type")

    # Materialize so IterableDataset results can be sized and indexed below.
    concatenated_dataset = list(concatenated_dataset)

    for doc in concatenated_dataset:
        assert isinstance(doc, TextBasedDocument)
        if not as_list:
            # Dict input: every document must carry its originating split name.
            assert doc.metadata["dataset_name"] is not None
            assert doc.metadata["dataset_name"] in ["test", "train", "validation"]

    # 8 train + 2 validation + 2 test documents in total.
    assert len(concatenated_dataset) == 12

    # Splits are concatenated in order: train (0-7), validation (8-9), test (10-11).
    assert [concatenated_dataset[i].id for i in [0, 8, 10]] == [
        "train_doc1",
        "val_doc1",
        "test_doc1",
    ]
    assert [doc.id for doc in concatenated_dataset[7:11]] == [
        "train_doc8",
        "val_doc1",
        "val_doc2",
        "test_doc1",
    ]


def test_concatenate_datasets_errors(dataset_with_converter_functions):
    """Check the error conditions of concatenate_datasets.

    Two failure modes are exercised: an empty collection of datasets, and a
    pair of datasets whose document types differ.
    """
    # An empty input must be rejected outright.
    no_datasets: list[Dataset] = []
    with pytest.raises(ValueError) as excinfo:
        concatenate_datasets(no_datasets)
    assert str(excinfo.value) == "No datasets to concatenate"

    # Converting one copy to another document type makes the pair incompatible.
    converted = dataset_with_converter_functions.to_document_type(TestDocumentWithLabel)
    with pytest.raises(ValueError) as excinfo:
        concatenate_datasets([dataset_with_converter_functions, converted])
    assert str(excinfo.value) == "All datasets must have the same document type to concatenate"


def test_add_set_name_to_document():
    """Exercise _add_dset_name_to_document on a bare Document.

    Verifies the error for a document without a metadata attribute, the happy
    path that stores the name, and the error when a name is already present.
    """
    doc = Document()
    assert not hasattr(doc, "metadata")

    # Without a metadata attribute the helper must refuse to store the name.
    with pytest.raises(ValueError) as excinfo:
        _add_dset_name_to_document(doc, "test")
    assert (
        str(excinfo.value)
        == "Document does not have metadata attribute which required to save the dataset name: Document()"
    )

    # Once metadata exists, the name is stored under "dataset_name".
    doc.metadata = {}
    assert hasattr(doc, "metadata")
    _add_dset_name_to_document(doc, "test_dataset_name")
    assert doc.metadata["dataset_name"] == "test_dataset_name"

    # A second attempt must fail because a name is already recorded.
    with pytest.raises(ValueError) as excinfo:
        _add_dset_name_to_document(doc, "test")
    assert str(excinfo.value) == "Document already has a dataset_name attribute: test_dataset_name"

0 comments on commit 8bd4735

Please sign in to comment.