This is a simple library for creating readable dataset pipelines and reusing best practices for issues such as imbalanced datasets. There are just two components to keep track of: Dataset and Datastream.
Dataset is a simple mapping between an index and an example. It provides pipelining of functions in a readable syntax originally adapted from tensorflow 2's tf.data.Dataset.
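Conceptually, a Dataset is an index-to-example mapping whose transforms are chained lazily and applied on access. The plain-Python sketch below illustrates that idea only; the class name and internals are illustrative, not pytorch-datastream's actual implementation.

```python
# Illustrative sketch only -- not pytorch-datastream's implementation.
# A "dataset" is an index -> example mapping; .map composes transforms
# lazily, so they run only when an item is accessed.
class SketchDataset:
    def __init__(self, examples, transform=lambda x: x):
        self.examples = examples
        self.transform = transform

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        return self.transform(self.examples[index])

    def map(self, fn):
        # Compose with the previous transform instead of materializing
        # a new list of examples.
        return SketchDataset(
            self.examples,
            lambda x, prev=self.transform: fn(prev(x)),
        )

fruits = SketchDataset(["apple", "banana", "cherry"]).map(str.upper)
print(fruits[1])  # BANANA
```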
Datastream combines a Dataset and a sampler into a stream of examples. It provides a simple solution to oversampling / stratification, weighted sampling, and finally converting to a torch.utils.data.DataLoader.
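At its core, weighted sampling and oversampling come down to drawing dataset indices with non-uniform probabilities. The library handles this through its sampler; the snippet below is only a plain-Python sketch of the underlying idea, with made-up example data.

```python
import random

# Illustrative sketch only: draw from an imbalanced dataset with
# per-example weights so the rare example is oversampled.
examples = ["common"] * 9 + ["rare"]
weights = [1.0] * 9 + [9.0]  # upweight the single rare example

random.seed(0)
batch = random.choices(examples, weights=weights, k=8)
print(batch)  # about half "rare" in expectation (weight 9 of 18 total)
```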
See the documentation for more information.
poetry add pytorch-datastream
Or, for the old-timers:
pip install pytorch-datastream
The list below is meant to showcase functions that are useful in most standard and non-standard cases. It is not meant to be exhaustive. See the documentation for a more extensive description of the API and its usage.
Dataset.from_subscriptable
Dataset.from_dataframe
Dataset
    .map
    .subset
    .split
    .cache
    .with_columns
Datastream.merge
Datastream.zip
Datastream
    .map
    .data_loader
    .zip_index
    .update_weights_
    .update_example_weight_
    .weight
    .state_dict
    .load_state_dict
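As one point of orientation, splitting in the style of Dataset.split amounts to partitioning the index set by proportions, usually with a fixed seed for reproducibility. The library's actual split has its own signature and extra features; this plain-Python sketch only shows the proportional-partition idea, and split_indices is a hypothetical helper name.

```python
import random

# Illustrative sketch only: partition indices into named splits by
# proportion, shuffled with a fixed seed for reproducibility.
def split_indices(n, proportions, seed=0):
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    splits, start = {}, 0
    for name, proportion in proportions.items():
        end = start + round(proportion * n)
        splits[name] = indices[start:end]
        start = end
    return splits

splits = split_indices(10, {"train": 0.8, "test": 0.2})
print(len(splits["train"]), len(splits["test"]))  # 8 2
```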
Here's a basic example of loading images from a directory:
from pathlib import Path

from PIL import Image
from datastream import Dataset

# Assuming images are in a directory structure like:
# images/
#     class1/
#         image1.jpg
#         image2.jpg
#     class2/
#         image3.jpg
#         image4.jpg
image_dir = Path("images")
image_paths = list(image_dir.glob("*/*.jpg"))

dataset = (
    Dataset.from_paths(
        image_paths,
        pattern=r".*/(?P<class_name>\w+)/(?P<image_name>\w+)\.jpg",
    )
    .map(lambda row: dict(
        image=Image.open(row["path"]),
        class_name=row["class_name"],
        image_name=row["image_name"],
    ))
)

# Access an item from the dataset
first_item = dataset[0]
print(f"Class: {first_item['class_name']}, Image name: {first_item['image_name']}")
Each of the fruit datastreams below repeatedly yields the string of its fruit type.
>>> datastream = Datastream.merge([
...     (apple_datastream, 2),
...     (pear_datastream, 1),
...     (banana_datastream, 1),
... ])
>>> next(iter(datastream.data_loader(batch_size=8)))
['apple', 'apple', 'pear', 'banana', 'apple', 'apple', 'pear', 'banana']
Each of the fruit datastreams below repeatedly yields the string of its fruit type.
>>> datastream = Datastream.zip([
...     apple_datastream,
...     Datastream.merge([pear_datastream, banana_datastream]),
... ])
>>> next(iter(datastream.data_loader(batch_size=4)))
[('apple', 'pear'), ('apple', 'banana'), ('apple', 'pear'), ('apple', 'banana')]
See the documentation for more usage examples.