diff --git a/src/pie_datasets/core/__init__.py b/src/pie_datasets/core/__init__.py
index e81adf7b..22f4bd7a 100644
--- a/src/pie_datasets/core/__init__.py
+++ b/src/pie_datasets/core/__init__.py
@@ -1,5 +1,5 @@
 from .builder import ArrowBasedBuilder, GeneratorBasedBuilder
-from .dataset import Dataset, IterableDataset
+from .dataset import Dataset, IterableDataset, concatenate_datasets
 from .dataset_dict import DatasetDict, load_dataset
 
 __all__ = [
@@ -9,4 +9,5 @@
     "IterableDataset",
     "DatasetDict",
     "load_dataset",
+    "concatenate_datasets",
 ]
diff --git a/src/pie_datasets/core/dataset.py b/src/pie_datasets/core/dataset.py
index d0d81ab7..60ae5c34 100644
--- a/src/pie_datasets/core/dataset.py
+++ b/src/pie_datasets/core/dataset.py
@@ -662,3 +662,51 @@ def get_pie_dataset_type(
     raise TypeError(
         f"the dataset must be of type Dataset or IterableDataset, but is of type {type(hf_dataset)}"
     )
+
+
+def _add_dset_name_to_document(doc: Document, name: str) -> Document:
+    if not hasattr(doc, "metadata"):
+        raise ValueError(
+            f"Document does not have a metadata attribute, which is required to save the dataset name: {doc}"
+        )
+    if "dataset_name" in doc.metadata:
+        raise ValueError(
+            f"Document already has a dataset_name attribute: {doc.metadata['dataset_name']}"
+        )
+    doc.metadata["dataset_name"] = name
+    return doc
+
+
+def concatenate_datasets(
+    dsets: Union[
+        List[Dataset], List[IterableDataset], Dict[str, Dataset], Dict[str, IterableDataset]
+    ]
+) -> Union[Dataset, IterableDataset]:
+    """Concatenate multiple datasets into a single dataset. The datasets must have the same
+    document type.
+
+    Args:
+        dsets: A list of datasets or a dictionary with dataset names as keys and datasets as values. If
+            a dictionary is provided, the dataset names will be added to the documents as metadata.
+    Returns:
+        A new dataset that is the concatenation of the input datasets.
+ """ + + if isinstance(dsets, dict): + dsets = [ + dset.map(_add_dset_name_to_document, fn_kwargs={"name": name}) + for name, dset in dsets.items() + ] + + if len(dsets) == 0: + raise ValueError("No datasets to concatenate") + + document_type = dsets[0].document_type + for doc in dsets[1:]: + if not doc.document_type == document_type: + raise ValueError("All datasets must have the same document type to concatenate") + + result_hf = datasets.concatenate_datasets(dsets) + pie_dataset_type = get_pie_dataset_type(dsets[0]) + + return pie_dataset_type.from_hf_dataset(result_hf, document_type=document_type) diff --git a/tests/unit/core/test_dataset.py b/tests/unit/core/test_dataset.py index 789b5d3a..16532823 100644 --- a/tests/unit/core/test_dataset.py +++ b/tests/unit/core/test_dataset.py @@ -18,7 +18,11 @@ from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule from pie_datasets import Dataset, IterableDataset -from pie_datasets.core.dataset import get_pie_dataset_type +from pie_datasets.core.dataset import ( + _add_dset_name_to_document, + concatenate_datasets, + get_pie_dataset_type, +) from tests.conftest import TestDocument from tests.unit.core import TEST_PACKAGE @@ -479,3 +483,92 @@ def _empty_docs(): with pytest.raises(ValueError) as excinfo: dataset_class.from_documents(_empty_docs) assert str(excinfo.value) == "No documents to create dataset from" + + +@pytest.mark.parametrize("as_list", [False, True]) +def test_concatenate_datasets(maybe_iterable_dataset, dataset_with_converter_functions, as_list): + # Tests four different cases of concatenation of list/dict of Datasets/IterableDatasets + if as_list: + # Test concatenation of list of datasets + concatenated_dataset = concatenate_datasets( + [ + maybe_iterable_dataset["train"], + maybe_iterable_dataset["validation"], + maybe_iterable_dataset["test"], + ] + ) + else: + # Test concatenation of dictionary of datasets + concatenated_dataset = concatenate_datasets(maybe_iterable_dataset) + + # Check correct output type + if isinstance(maybe_iterable_dataset["train"], IterableDataset): + # if input is IterableDataset, output should be IterableDataset + assert isinstance(concatenated_dataset, IterableDataset) + elif isinstance(maybe_iterable_dataset["train"], Dataset): + # if input is Dataset, output should be Dataset + assert isinstance(concatenated_dataset, Dataset) + else: + raise ValueError("Unexpected input type") + + concatenated_dataset = list(concatenated_dataset) + + for doc in concatenated_dataset: + assert isinstance(doc, TextBasedDocument) + if not as_list: + # If input is dictionary, check that dataset_name is added to metadata + assert doc.metadata["dataset_name"] is not None + assert doc.metadata["dataset_name"] in ["test", "train", "validation"] + + assert len(concatenated_dataset) == 12 + + assert [concatenated_dataset[i].id for i in [0, 8, 10]] == [ + "train_doc1", + "val_doc1", + "test_doc1", + ] + assert [doc.id for doc in concatenated_dataset[7:11]] == [ + "train_doc8", + "val_doc1", + "val_doc2", + "test_doc1", + ] + + +def test_concatenate_datasets_errors(dataset_with_converter_functions): + # Test concatenation of empty datasets + empty_dataset = list[Dataset]() + with pytest.raises(ValueError) as excinfo: + concatenate_datasets(empty_dataset) + assert str(excinfo.value) == "No datasets to concatenate" + + # Test concatenation of datasets with different document types + dataset_with_converted_doc = dataset_with_converter_functions.to_document_type( + TestDocumentWithLabel + ) + with 
+        concatenate_datasets([dataset_with_converter_functions, dataset_with_converted_doc])
+    assert str(excinfo.value) == "All datasets must have the same document type to concatenate"
+
+
+def test_add_dset_name_to_document():
+    # Test a document that has no metadata attribute
+    doc = Document()
+    assert not hasattr(doc, "metadata")
+    with pytest.raises(ValueError) as excinfo:
+        _add_dset_name_to_document(doc, "test")
+    assert (
+        str(excinfo.value)
+        == "Document does not have a metadata attribute, which is required to save the dataset name: Document()"
+    )
+
+    # Test adding the dataset name to a document
+    doc.metadata = {}
+    assert hasattr(doc, "metadata")
+    _add_dset_name_to_document(doc, "test_dataset_name")
+    assert doc.metadata["dataset_name"] == "test_dataset_name"
+
+    # Test a document that already has dataset_name in its metadata
+    with pytest.raises(ValueError) as excinfo:
+        _add_dset_name_to_document(doc, "test")
+    assert str(excinfo.value) == "Document already has a dataset_name attribute: test_dataset_name"
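
For reviewers, a minimal usage sketch of the new helper (the "pie/example" dataset path and the split names below are placeholders for illustration, not part of this change):

    from pie_datasets import concatenate_datasets, load_dataset

    ds = load_dataset("pie/example")  # placeholder dataset path, returns a DatasetDict
    # A dict input records each dataset's name in doc.metadata["dataset_name"] ...
    combined = concatenate_datasets({"train": ds["train"], "validation": ds["validation"]})
    # ... while a list input concatenates without touching the metadata.
    combined = concatenate_datasets([ds["train"], ds["validation"]])
    print(combined.document_type)  # same document type as the inputs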