Skip to content

Commit

Permalink
add concatenate datasets method (#148)
Browse files Browse the repository at this point in the history
* Added concatenate_datasets method in src/dataset.py

* enables concatenation of multiple pie-datasets

* tests still missing

* Added tests for concatenate_datasets method

* testing the core functionality with list/dict of Dataset/IterableDatasets
 through test_concatenate_datasets()

* testing occurring errors during concatenation through test_concatenate_datasets_errors()

* testing helper function _add_dset_name_to_document

* fix test method name

* improve docstring

---------

Co-authored-by: Arne Binder <[email protected]>
  • Loading branch information
kai-car and ArneBinder authored Aug 20, 2024
1 parent 8605432 commit b18e37f
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/pie_datasets/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .builder import ArrowBasedBuilder, GeneratorBasedBuilder
from .dataset import Dataset, IterableDataset
from .dataset import Dataset, IterableDataset, concatenate_datasets
from .dataset_dict import DatasetDict, load_dataset

__all__ = [
Expand All @@ -9,4 +9,5 @@
"IterableDataset",
"DatasetDict",
"load_dataset",
"concatenate_datasets",
]
48 changes: 48 additions & 0 deletions src/pie_datasets/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,3 +662,51 @@ def get_pie_dataset_type(
raise TypeError(
f"the dataset must be of type Dataset or IterableDataset, but is of type {type(hf_dataset)}"
)


def _add_dset_name_to_document(doc: Document, name: str) -> Document:
    """Record ``name`` under the ``dataset_name`` key of ``doc.metadata``.

    The document is modified in place and returned. A ``ValueError`` is raised
    when the document has no ``metadata`` attribute, or when a dataset name has
    already been stored on it (we never silently overwrite an existing name).
    """
    if not hasattr(doc, "metadata"):
        raise ValueError(
            "Document does not have metadata attribute which required to save the dataset name: "
            f"{doc}"
        )
    if "dataset_name" in doc.metadata:
        previous = doc.metadata["dataset_name"]
        raise ValueError(f"Document already has a dataset_name attribute: {previous}")
    doc.metadata["dataset_name"] = name
    return doc


def concatenate_datasets(
    dsets: Union[
        List[Dataset], List[IterableDataset], Dict[str, Dataset], Dict[str, IterableDataset]
    ]
) -> Union[Dataset, IterableDataset]:
    """Concatenate multiple datasets into a single dataset. The datasets must have the same
    document type.

    Args:
        dsets: A list of datasets or a dictionary with dataset names as keys and datasets as
            values. If a dictionary is provided, the dataset names will be added to the
            documents as metadata.

    Returns:
        A new dataset that is the concatenation of the input datasets.

    Raises:
        ValueError: If no datasets are provided, or if the datasets do not all have the same
            document type.
    """

    if isinstance(dsets, dict):
        # Tag every document with the name of the dataset it came from, so the
        # origin is still identifiable after concatenation.
        dsets = [
            dset.map(_add_dset_name_to_document, fn_kwargs={"name": name})
            for name, dset in dsets.items()
        ]

    if len(dsets) == 0:
        raise ValueError("No datasets to concatenate")

    # All inputs must share one document type; the first dataset is the reference.
    document_type = dsets[0].document_type
    for dset in dsets[1:]:
        if dset.document_type != document_type:
            raise ValueError("All datasets must have the same document type to concatenate")

    result_hf = datasets.concatenate_datasets(dsets)
    pie_dataset_type = get_pie_dataset_type(dsets[0])

    return pie_dataset_type.from_hf_dataset(result_hf, document_type=document_type)
95 changes: 94 additions & 1 deletion tests/unit/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule

from pie_datasets import Dataset, IterableDataset
from pie_datasets.core.dataset import get_pie_dataset_type
from pie_datasets.core.dataset import (
_add_dset_name_to_document,
concatenate_datasets,
get_pie_dataset_type,
)
from tests.conftest import TestDocument
from tests.unit.core import TEST_PACKAGE

Expand Down Expand Up @@ -479,3 +483,92 @@ def _empty_docs():
with pytest.raises(ValueError) as excinfo:
dataset_class.from_documents(_empty_docs)
assert str(excinfo.value) == "No documents to create dataset from"


@pytest.mark.parametrize("as_list", [False, True])
def test_concatenate_datasets(maybe_iterable_dataset, as_list):
    """Concatenating a list or dict of (Iterable)Datasets yields one combined dataset.

    Covers four cases: list/dict input crossed with Dataset/IterableDataset (the latter
    via the parametrized ``maybe_iterable_dataset`` fixture). The previously requested
    ``dataset_with_converter_functions`` fixture was unused and has been dropped.
    """
    if as_list:
        # Concatenate an explicit list of the three splits.
        concatenated_dataset = concatenate_datasets(
            [
                maybe_iterable_dataset["train"],
                maybe_iterable_dataset["validation"],
                maybe_iterable_dataset["test"],
            ]
        )
    else:
        # Concatenate the split dict directly; split names become metadata.
        concatenated_dataset = concatenate_datasets(maybe_iterable_dataset)

    # The output type must mirror the input type.
    if isinstance(maybe_iterable_dataset["train"], IterableDataset):
        assert isinstance(concatenated_dataset, IterableDataset)
    elif isinstance(maybe_iterable_dataset["train"], Dataset):
        assert isinstance(concatenated_dataset, Dataset)
    else:
        raise ValueError("Unexpected input type")

    # Materialize so we can index and take len() uniformly below.
    concatenated_dataset = list(concatenated_dataset)

    for doc in concatenated_dataset:
        assert isinstance(doc, TextBasedDocument)
        if not as_list:
            # Dict input: each document must carry its originating split name.
            assert doc.metadata["dataset_name"] is not None
            assert doc.metadata["dataset_name"] in ["test", "train", "validation"]

    assert len(concatenated_dataset) == 12

    # Spot-check document order: train (8 docs), then validation (2), then test (2).
    assert [concatenated_dataset[i].id for i in [0, 8, 10]] == [
        "train_doc1",
        "val_doc1",
        "test_doc1",
    ]
    assert [doc.id for doc in concatenated_dataset[7:11]] == [
        "train_doc8",
        "val_doc1",
        "val_doc2",
        "test_doc1",
    ]


def test_concatenate_datasets_errors(dataset_with_converter_functions):
    """Concatenation must reject empty input and mismatched document types."""
    # An empty collection of datasets cannot be concatenated.
    no_datasets: list[Dataset] = []
    with pytest.raises(ValueError) as excinfo:
        concatenate_datasets(no_datasets)
    assert str(excinfo.value) == "No datasets to concatenate"

    # Converting one dataset to another document type must trip the
    # type-consistency check in concatenate_datasets.
    converted = dataset_with_converter_functions.to_document_type(TestDocumentWithLabel)
    with pytest.raises(ValueError) as excinfo:
        concatenate_datasets([dataset_with_converter_functions, converted])
    assert str(excinfo.value) == "All datasets must have the same document type to concatenate"


def test_add_dset_name_to_document():
    """Exercise _add_dset_name_to_document on documents with and without metadata."""
    doc = Document()
    # A bare Document has no metadata attribute, so tagging must fail.
    assert not hasattr(doc, "metadata")
    with pytest.raises(ValueError) as excinfo:
        _add_dset_name_to_document(doc, "test")
    expected = (
        "Document does not have metadata attribute which required to save the dataset name: "
        "Document()"
    )
    assert str(excinfo.value) == expected

    # Once a metadata dict exists, the name is stored under "dataset_name".
    doc.metadata = {}
    assert hasattr(doc, "metadata")
    _add_dset_name_to_document(doc, "test_dataset_name")
    assert doc.metadata["dataset_name"] == "test_dataset_name"

    # A second tagging attempt on the same document must be rejected.
    with pytest.raises(ValueError) as excinfo:
        _add_dset_name_to_document(doc, "test")
    assert str(excinfo.value) == "Document already has a dataset_name attribute: test_dataset_name"

0 comments on commit b18e37f

Please sign in to comment.