From 8fc0cc9e073b45c6ba2152822da88a34f8efb300 Mon Sep 17 00:00:00 2001
From: ArneBinder
Date: Mon, 30 Sep 2024 17:01:51 +0200
Subject: [PATCH] add parameter `set_batch_size_to_split_size` to `DatasetDict.map` (#155)

* implement set_batch_size_to_split_size

* improve docstring

* add test
---
 src/pie_datasets/core/dataset_dict.py |  6 ++++++
 tests/unit/core/test_dataset_dict.py  | 19 +++++++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/src/pie_datasets/core/dataset_dict.py b/src/pie_datasets/core/dataset_dict.py
index 608a3d7c..f3020aeb 100644
--- a/src/pie_datasets/core/dataset_dict.py
+++ b/src/pie_datasets/core/dataset_dict.py
@@ -348,6 +348,7 @@ def map(  # type: ignore
         self,
         function: Optional[Union[Callable, str]] = None,
         result_document_type: Optional[Union[str, Type[Document]]] = None,
+        set_batch_size_to_split_size: bool = False,
         **kwargs,
     ) -> "DatasetDict":
         """Applies a function to all documents in the dataset.
@@ -370,6 +371,9 @@
                 string that can be resolved to such a type. If not provided, it is tried to infer
                 it from the function signature. If this is not possible, the document type of the
                 input dataset is used.
+            set_batch_size_to_split_size: If enabled, set the batch_size to the size of the respective split
+                when calling map() on it. This is useful to transform whole splits when using it in
+                combination with batched=True.
             **kwargs: additional keyword arguments for `datasets.Dataset.map()`
         """

@@ -395,6 +399,8 @@ def identity(x):
         for split, dataset in self.items():
             if isinstance(func, EnterDatasetMixin):
                 func.enter_dataset(dataset=dataset, name=split)
+            if set_batch_size_to_split_size:
+                map_kwargs["batch_size"] = len(dataset)
             result_dict[split] = dataset.map(**map_kwargs)
             if isinstance(func, ExitDatasetMixin):
                 func.exit_dataset(dataset=result_dict[split], name=split)
diff --git a/tests/unit/core/test_dataset_dict.py b/tests/unit/core/test_dataset_dict.py
index 3c66cf42..319b721a 100644
--- a/tests/unit/core/test_dataset_dict.py
+++ b/tests/unit/core/test_dataset_dict.py
@@ -280,6 +280,25 @@ def exit_dataset_dict(self, dataset_dict: DatasetDict) -> None:
         assert doc1 == doc2


+def test_map_set_max_batch_size(dataset_dict):
+    def join_docs(docs):
+        return [TextBasedDocument(text=" ".join([doc.text for doc in docs]))]
+
+    dataset_dict_mapped = dataset_dict.map(
+        join_docs,
+        batched=True,
+        set_batch_size_to_split_size=True,
+        result_document_type=TextBasedDocument,
+    )
+    assert dataset_dict_mapped.document_type is TextBasedDocument
+    for split in dataset_dict:
+        assert len(dataset_dict_mapped[split]) == 1
+        new_doc = dataset_dict_mapped[split][0]
+        assert isinstance(new_doc, TextBasedDocument)
+        original_texts = [doc.text for doc in dataset_dict[split]]
+        assert new_doc.text == " ".join(original_texts)
+
+
 def test_select(dataset_dict):
     # select documents by index
     dataset_dict_selected = dataset_dict.select(
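
For reference (not part of the patch): a minimal usage sketch of the new flag, mirroring the added test. It assumes an existing `dataset_dict` (a `pie_datasets.DatasetDict` whose splits contain `TextBasedDocument`s); the `pytorch_ie.documents` import path and the `join_docs` helper are illustrative assumptions, not something this patch introduces.

    # Sketch: collapse each split into a single document by joining all texts.
    # Assumes `dataset_dict` already exists and that TextBasedDocument is
    # importable from pytorch_ie.documents (assumption, see note above).
    from pytorch_ie.documents import TextBasedDocument

    def join_docs(docs):
        # With batched=True and set_batch_size_to_split_size=True, `docs` holds
        # every document of the current split, so one document is returned per split.
        return [TextBasedDocument(text=" ".join(doc.text for doc in docs))]

    joined = dataset_dict.map(
        join_docs,
        batched=True,
        set_batch_size_to_split_size=True,
        result_document_type=TextBasedDocument,
    )
    # Each split of `joined` now contains exactly one TextBasedDocument.

Without `set_batch_size_to_split_size=True`, `datasets.Dataset.map()` would use its default `batch_size` and the function would be called once per batch rather than once per split.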