Simplify ReadFromCroissant by removing the pipeline argument and making it a PCollection. (#780)

marcenacp authored Dec 4, 2024
1 parent 39210cc commit be6daee
Showing 3 changed files with 8 additions and 52 deletions.
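For orientation, usage after this change presumably still looks like applying the returned transform to a pipeline with |, which is what yields the PCollection of records. A rough sketch based on the diff below, not the project's documented example; the JSON-LD URL and record-set name are placeholders, and mlcroissant re-exporting ReadFromCroissant as mlc.ReadFromCroissant is assumed:

    import apache_beam as beam
    import mlcroissant as mlc

    jsonld = "https://example.org/croissant.json"  # placeholder Croissant JSON-LD
    with beam.Pipeline() as pipeline:
        # No pipeline argument anymore: the function returns a composed transform,
        # and applying it to the pipeline produces the PCollection of records.
        records = pipeline | mlc.ReadFromCroissant(
            jsonld=jsonld,
            record_set="my_record_set",  # placeholder record set name
        )

The user-visible call site changes little compared with the previous @beam.ptransform_fn version; the simplification is mostly internal, since the pipeline is no longer threaded through ReadFromCroissant, beam_reader, and execute_operations_in_beam.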
34 changes: 3 additions & 31 deletions python/mlcroissant/mlcroissant/_src/beam.py
@@ -3,46 +3,22 @@
from __future__ import annotations

from collections.abc import Mapping
import functools
import typing
from typing import Any, Callable
from typing import Any

from etils import epath

from mlcroissant._src.datasets import Dataset
from mlcroissant._src.datasets import Filters

if typing.TYPE_CHECKING:
import apache_beam as beam


def _beam_ptransform_fn(fn: Callable[..., Any]) -> Callable[..., Any]:
"""Lazy version of `@beam.ptransform_fn` in case Beam is not installed."""
lazy_decorated_fn = None

@functools.wraps(fn)
def decorated(*args, **kwargs):
nonlocal lazy_decorated_fn
# Actually decorate the function only the first time it is called
if lazy_decorated_fn is None:
import apache_beam as beam

lazy_decorated_fn = beam.ptransform_fn(fn)
return lazy_decorated_fn(*args, **kwargs)

return decorated


@_beam_ptransform_fn
def ReadFromCroissant(
pipeline: beam.Pipeline,
*,
jsonld: epath.PathLike | Mapping[str, Any],
record_set: str,
mapping: Mapping[str, epath.PathLike] | None = None,
filters: Filters | None = None,
):
"""Returns an Apache Beam reader to generate the dataset using e.g. Spark.
"""Returns an Apache Beam PCollection to generate the dataset using e.g. Spark.
Example of usage:
@@ -65,7 +41,6 @@ def ReadFromCroissant(
Face datasets, so it raises an error if the dataset is not a Hugging Face dataset.
Args:
pipeline: A Beam pipeline (automatically set).
jsonld: A JSON object or a path to a Croissant file (URL, str or pathlib.Path).
record_set: The name of the record set to generate.
mapping: Mapping filename->filepath as a Python dict[str, str] to handle manual
@@ -85,7 +60,4 @@ def ReadFromCroissant(
A ValueError if the dataset is not streamable.
"""
dataset = Dataset(jsonld=jsonld, mapping=mapping)
return dataset.records(record_set, filters=filters).beam_reader(
pipeline,
filters=filters,
)
return dataset.records(record_set, filters=filters).beam_reader()
11 changes: 2 additions & 9 deletions python/mlcroissant/mlcroissant/_src/datasets.py
@@ -4,7 +4,6 @@

from collections.abc import Mapping
import dataclasses
import typing
from typing import Any

from absl import logging
@@ -29,9 +28,6 @@
from mlcroissant._src.structure_graph.nodes.metadata import Metadata
from mlcroissant._src.structure_graph.nodes.source import FileProperty

if typing.TYPE_CHECKING:
import apache_beam as beam

Filters = Mapping[str, Any]


@@ -176,17 +172,14 @@ def __iter__(self):
record_set=self.record_set, operations=operations
)

def beam_reader(
self, pipeline: beam.Pipeline, filters: Mapping[str, Any] | None = None
):
def beam_reader(self):
"""See ReadFromCroissant docstring."""
operations = self._filter_interesting_operations(self.filters)
execute_downloads(operations)
return execute_operations_in_beam(
pipeline=pipeline,
record_set=self.record_set,
operations=operations,
filters=filters or self.filters,
filters=self.filters,
)

def _filter_interesting_operations(self, filters: Filters | None) -> Operations:
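With this change, filters are fixed when the records object is created, and beam_reader() reads them from self.filters instead of accepting an override. A rough sketch of the resulting lower-level call path; the JSON-LD path, record-set name, and filter key/value are placeholders, and the exact filter format is assumed:

    from mlcroissant._src.datasets import Dataset

    dataset = Dataset(jsonld="path/to/croissant.json")  # placeholder path
    records = dataset.records("my_record_set", filters={"split": "train"})  # hypothetical filter
    transform = records.beam_reader()  # no pipeline or filters arguments anymore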
15 changes: 3 additions & 12 deletions python/mlcroissant/mlcroissant/_src/operation_graph/execute.py
@@ -7,7 +7,6 @@
import functools
import json
import sys
import typing
from typing import Any, Generator

from absl import logging
@@ -21,9 +20,6 @@
from mlcroissant._src.operation_graph.operations.download import Download
from mlcroissant._src.operation_graph.operations.read import Read

if typing.TYPE_CHECKING:
import apache_beam as beam

ElementWithIndex = tuple[int, Any]


@@ -129,7 +125,6 @@ def read_all_files():


def execute_operations_in_beam(
pipeline: beam.Pipeline,
record_set: str,
operations: Operations,
filters: Mapping[str, Any] | None = None,
@@ -181,19 +176,15 @@ def execute_operations_in_beam(
for operation in operations_in_memory:
# If there is no FilterFiles, we return the PCollection without parallelization.
if operation == target:
return (
pipeline
| beam.Create([(0, *operation.inputs)])
| _beam_operation_with_index(operation, sys.maxsize, stage_prefix)
return beam.Create([(0, *operation.inputs)]) | _beam_operation_with_index(
operation, sys.maxsize, stage_prefix
)
else:
operation(set_output_in_memory=True)

files = filter_files.output # even for large datasets, this can be handled in RAM.
# We first shard by file and assign a shard_index.
pipeline = pipeline | f"{stage_prefix} Shard by files with index" >> beam.Create(
enumerate(files)
)
pipeline = beam.Create(enumerate(files))
num_shards = len(files)
if not num_shards:
raise ValueError(
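What makes dropping the pipeline parameter possible is that Beam transforms can be composed with | before any pipeline exists; the composed transform only produces a PCollection once it is applied to a pipeline. A minimal standalone sketch of that pattern in plain Apache Beam (not project code, assuming apache_beam is installed):

    import apache_beam as beam

    # Composing with | and no pipeline builds a chained PTransform, not a PCollection.
    read_and_double = beam.Create([1, 2, 3]) | beam.Map(lambda x: x * 2)

    # Applying the chained transform to a pipeline is what produces the PCollection.
    with beam.Pipeline() as pipeline:
        doubled = pipeline | "ReadAndDouble" >> read_and_double
        doubled | beam.Map(print)

execute_operations_in_beam now returns such a composed transform (starting from beam.Create), so the caller decides which pipeline to attach it to.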
