Skip to content

Commit

Permalink
add parameter set_batch_size_to_split_size to DatasetDict.map (#155)
Browse files Browse the repository at this point in the history
* implement set_batch_size_to_split_size

* improve docstring

* add test
  • Loading branch information
ArneBinder authored Sep 30, 2024
1 parent ef8f2d7 commit 8fc0cc9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/pie_datasets/core/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ def map( # type: ignore
self,
function: Optional[Union[Callable, str]] = None,
result_document_type: Optional[Union[str, Type[Document]]] = None,
set_batch_size_to_split_size: bool = False,
**kwargs,
) -> "DatasetDict":
"""Applies a function to all documents in the dataset.
Expand All @@ -370,6 +371,9 @@ def map( # type: ignore
string that can be resolved to such a type. If not provided, it is tried to infer it from the
function signature. If this is not possible, the document type of the input dataset
is used.
set_batch_size_to_split_size: If enabled, set the batch_size to the size of the respective split
when calling map() on it. This is useful to transform whole splits when using it in
combination with batched=True.
**kwargs: additional keyword arguments for `datasets.Dataset.map()`
"""

Expand All @@ -395,6 +399,8 @@ def identity(x):
for split, dataset in self.items():
if isinstance(func, EnterDatasetMixin):
func.enter_dataset(dataset=dataset, name=split)
if set_batch_size_to_split_size:
map_kwargs["batch_size"] = len(dataset)
result_dict[split] = dataset.map(**map_kwargs)
if isinstance(func, ExitDatasetMixin):
func.exit_dataset(dataset=result_dict[split], name=split)
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/core/test_dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,25 @@ def exit_dataset_dict(self, dataset_dict: DatasetDict) -> None:
assert doc1 == doc2


def test_map_set_max_batch_size(dataset_dict):
    """With set_batch_size_to_split_size=True, a batched map receives each whole
    split as a single batch, so a join-all function collapses every split into
    exactly one document."""

    def merge_into_single_doc(batch):
        # batch_size == len(split) here, so this sees all documents of a split at once.
        combined_text = " ".join(doc.text for doc in batch)
        return [TextBasedDocument(text=combined_text)]

    mapped = dataset_dict.map(
        merge_into_single_doc,
        batched=True,
        set_batch_size_to_split_size=True,
        result_document_type=TextBasedDocument,
    )
    assert mapped.document_type is TextBasedDocument
    for split_name, split in dataset_dict.items():
        mapped_split = mapped[split_name]
        # Each split was collapsed into exactly one document ...
        assert len(mapped_split) == 1
        merged = mapped_split[0]
        assert isinstance(merged, TextBasedDocument)
        # ... whose text is the space-joined concatenation of all original texts, in order.
        assert merged.text == " ".join(doc.text for doc in split)


def test_select(dataset_dict):
# select documents by index
dataset_dict_selected = dataset_dict.select(
Expand Down

0 comments on commit 8fc0cc9

Please sign in to comment.