Pile Optimization #17

Merged · 13 commits · Oct 11, 2023
2 changes: 1 addition & 1 deletion .github/workflows/python-package-conda.yml
@@ -23,4 +23,4 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude working_dirs
- name: Test cpu-only metrics
run: |
python calculate_metrics.py --datasets test --schemes duped --models 70m
python calculate_metrics.py --datasets pile_test --schemes duped --models 70m
270 changes: 188 additions & 82 deletions calculate_metrics.py

Large diffs are not rendered by default.

18 changes: 12 additions & 6 deletions filters/__init__.py
@@ -1,7 +1,13 @@
from .base import PIPELINE_SINGLETON as PIPELINE
# The import here determines the order of the pipeline
from .detokenize import detokenize
from .highly_duplicated_filter import sequence_duplicates_filter
from .token_frequency_statistics_filter import token_frequency_statistics_filter
from .pattern_incrementing import incrementing_sequences_filter
from .highly_repetitive import highly_repetitive_filter

_has_registered_all_filters = False

if not _has_registered_all_filters:
# The import here determines the order of the pipeline
from .detokenize import detokenize
from .highly_duplicated_filter import sequence_duplicates_filter
from .token_frequency_statistics_filter import token_frequency_statistics_filter
from .pattern_incrementing import incrementing_sequences_filter
from .highly_repetitive import highly_repetitive_filter

_has_registered_all_filters = True
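
A quick way to sanity-check the registration side effect (a sketch, assuming the repo's dependencies are installed and the package is importable from the repo root; the count of 5 reflects the five filter imports above):

```python
# Importing `filters` executes __init__.py exactly once; each imported filter module
# registers itself on the shared pipeline singleton via @PIPELINE_SINGLETON.register_filter().
# A repeated import hits Python's module cache, so nothing registers twice; the
# _has_registered_all_filters flag is an extra guard on top of that behavior.
import filters

print(len(filters.PIPELINE.filters))  # expected: 5, one per filter import above

import filters as filters_again  # cached module: __init__.py does not run again
print(len(filters_again.PIPELINE.filters))  # still 5
```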
40 changes: 39 additions & 1 deletion filters/base.py
@@ -5,7 +5,7 @@

from filters.constants import PrecomputedFeatureName
from utils import initialize_logger
from spark.constants import NUM_PARTITIONS, SPARK_CACHE_DIR
from spark.constants import NUM_OUTPUT_PARTITIONS, SPARK_CACHE_DIR

FilterFunc: TypeAlias = Callable[..., Any]
PrecomputedFeatures: TypeAlias = Dict[PrecomputedFeatureName, DataFrame]
@@ -15,11 +15,21 @@

class MetricFilterPipeline:
def __init__(self):
"""
Pipeline for applying filters to a dataset.
"""
self.filters: List[FilterFunc] = []
self.features: PrecomputedFeatures = {}
self.spark: SparkSession

def register_filter(self) -> FilterFunc:
"""
Decorator for registering a filter function to the pipeline.

Returns:
FilterFunc: Decorated filter function
"""

def decorator(filter_func: FilterFunc) -> FilterFunc:
def wrapper(*args, **kwargs) -> Any:
return filter_func(*args, **kwargs)
@@ -32,17 +42,45 @@ def wrapper(*args, **kwargs) -> Any:
return decorator

def register_features(self, features: PrecomputedFeatures) -> None:
"""
Register precomputed features to the pipeline.

Args:
features (PrecomputedFeatures): Precomputed features

Returns:
None
"""
LOGGER.info(f"Registering features {features.keys()}...")
self.features.update(features)

def register_spark_session(self, spark: SparkSession) -> None:
"""
Register Spark session to the pipeline.

Args:
spark (SparkSession): Spark session

Returns:
None
"""
self.spark = spark

def transform(self, original: DataFrame) -> DataFrame:
"""
Apply all filters to the dataset.

Args:
original (DataFrame): Original dataset

Returns:
DataFrame: Filtered dataset
"""
current_dataset = original

for filter_func in self.filters:
# Checkpointing each filter to side-step potential OOM issues
LOGGER.info(f"Running filter {filter_func.__name__}...")
current_dataset: DataFrame = filter_func(current_dataset, self.features).checkpoint()

return current_dataset
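
For orientation, a minimal end-to-end sketch of how these pieces fit together (toy data and a hypothetical `sequence_length_filter`; only `filters.base` is imported, so the repo's built-in filters, which rely on precomputed features, are not registered here):

```python
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import functions as F

from filters.base import PIPELINE_SINGLETON as PIPELINE
from filters.constants import PrecomputedFeatureName
from spark.constants import SPARK_CACHE_DIR


@PIPELINE.register_filter()
def sequence_length_filter(dataset: DataFrame, features) -> DataFrame:
    # Hypothetical filter: adds a `sequence_length` column. Real filters can look up
    # precomputed DataFrames via features[PrecomputedFeatureName.<NAME>].
    return dataset.withColumn("sequence_length", F.size("tokens"))


if __name__ == "__main__":
    spark = SparkSession.builder.master("local[*]").getOrCreate()
    spark.sparkContext.setCheckpointDir(SPARK_CACHE_DIR)  # transform() checkpoints after each filter
    PIPELINE.register_spark_session(spark)

    # Toy precomputed feature, keyed by the enum the filters use
    freqs = spark.createDataFrame([(0, 3), (1, 1)], ["sequence_id", "frequency"])
    PIPELINE.register_features({PrecomputedFeatureName.SEQUENCE_FREQUENCIES: freqs})

    toy = spark.createDataFrame([(0, [5, 6, 7]), (1, [8, 9])], ["sequence_id", "tokens"])
    PIPELINE.transform(toy).show()
```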
1 change: 1 addition & 0 deletions filters/constants.py
@@ -5,3 +5,4 @@ class PrecomputedFeatureName(Enum):
SEQUENCE_FREQUENCIES = "sequence_frequencies"
MEMORIZED_TOKEN_FREQUENCIES = "memorized_token_frequencies"
NON_MEMORIZED_TOKEN_FREQUENCIES = "non_memorized_token_frequencies"
EMBEDDINGS = "embeddings"
3 changes: 2 additions & 1 deletion filters/detokenize.py
@@ -9,7 +9,8 @@

@PIPELINE_SINGLETON.register_filter()
def detokenize(dataset: DataFrame, _) -> DataFrame:
"""Detokenizes tokens into text as a preprocessing step.
"""
Detokenizes tokens into text as a preprocessing step.

Args:
dataset (DataFrame): Dataset containing sequences of tokens
7 changes: 4 additions & 3 deletions filters/highly_duplicated_filter.py
@@ -7,11 +7,12 @@

@PIPELINE_SINGLETON.register_filter()
def sequence_duplicates_filter(dataset: DataFrame, features: PrecomputedFeatures) -> DataFrame:
"""Compute the number of duplicates (frequency) of a sequence.
"""
Compute the number of duplicates (frequency) of a sequence.

Args:
dataset (DataFrame): Dataset containing sequences of tokens
features (PrecomputedFeatures):
features (PrecomputedFeatures): Precomputed features

Returns:
DataFrame: Dataframe with additional columns of `sequence_duplicates`, number of times that
@@ -21,7 +22,7 @@ def sequence_duplicates_filter(dataset: DataFrame, features: PrecomputedFeatures
sequence_frequencies = features[PrecomputedFeatureName.SEQUENCE_FREQUENCIES].alias("sequence_frequencies")

# Join on `sequence_id` to extract the sequence frequency
final = main.join(sequence_frequencies, on="sequence_id", how="inner").select(
final = main.join(sequence_frequencies, on="sequence_id", how="left").select(
"main.*",
F.col("sequence_frequencies.frequency").alias("sequence_duplicates"),
)
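
The inner-to-left join switch is the behavioral change in this file; a toy illustration of the difference (hypothetical values, not repo data):

```python
# With how="inner", sequences missing from the frequency table were silently dropped;
# with how="left", they are kept and get a null frequency instead.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()

main = spark.createDataFrame([(0, "lorem"), (1, "ipsum")], ["sequence_id", "text"])
freqs = spark.createDataFrame([(0, 3)], ["sequence_id", "frequency"])

main.join(freqs, on="sequence_id", how="inner").show()  # sequence 1 disappears
main.join(freqs, on="sequence_id", how="left").show()   # sequence 1 kept, frequency is null
```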
80 changes: 55 additions & 25 deletions filters/highly_repetitive.py
@@ -1,19 +1,22 @@
from typing import List, Tuple, Union

from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql import types as T

from .base import PIPELINE_SINGLETON


def break_and_compare(ls: list, k: int) -> list:
def break_and_compare(ls: List, k: int) -> List:
"""
This function takes a list ls and an integer k as input and returns a list which is the first chunk of ls that is repeated k times. If no such chunk exists, it returns an empty list.

Parameters:

ls (list): The input list.
Args:
ls (List): The input list.
k (int): The integer value used for splitting and comparing the list.

Returns:
List: The first chunk of ls that is repeated k times. If no such chunk exists, it returns an empty list.
"""
n = len(ls)
while n % k != 0:
@@ -41,17 +44,19 @@ def break_and_compare(ls: list, k: int) -> list:
return []


def break_and_compare_wrapper(ls: list, start_k: int, end_k: int) -> list:
def break_and_compare_wrapper(ls: List, start_k: int, end_k: int) -> Tuple[List, int, int]:
"""
    This function serves as a wrapper for the `break_and_compare` function. It takes two additional integer parameters `start_k` and `end_k` to define a range of values for `k`.
It iterates over this range and calls `break_and_compare` for each value of `k` within the range.

Parameters:
- `ls` (list): The input list.
- `start_k` (int): The starting value of `k` for the range (inclusive).
- `end_k` (int): The ending value of `k` for the range (inclusive).
Args:
ls (List): The input list.
start_k (int): The starting value of `k` for the range (inclusive).
end_k (int): The ending value of `k` for the range (inclusive).

Returns:
        Tuple[List, int, int]: A tuple of the result of `break_and_compare`, the offset `i`, and the repeat count `k` for which the result was obtained; `([], 0, -1)` if no repeating chunk is found.
"""
# end_k is inclusive
ls = list(ls)
@@ -74,14 +79,23 @@ def break_and_compare_wrapper(ls: list, start_k: int, end_k: int) -> list:
return result, i, k
result = break_and_compare(ls[i:], k)
if result:
return result, k
return result, i, k
result = break_and_compare(ls, k)
if result:
return result, 0, k
return [], 0, -1


def find_smallest_repeating_unit(lst):
def find_smallest_repeating_unit(lst: List) -> Tuple[List, int]:
    """
    This function takes a list as input and returns the smallest repeating unit of the list together with the number of times it repeats. If no such unit exists, it returns the list itself and a count of 1.

    Args:
        lst (List): The input list.

    Returns:
        Tuple[List, int]: The smallest repeating unit and the number of times it repeats. If no such unit exists, the list itself and 1.
    """
if lst is None:
return []
n = len(lst)
@@ -94,15 +108,16 @@ def find_smallest_repeating_unit(lst):

# Check if the entire list can be formed by repeating the unit
if all(lst[i : i + unit_length] == unit for i in range(0, n, unit_length)):
return unit
return unit, n // unit_length

# If no repeating unit is found, the list itself is the smallest repeating unit
return lst
return lst, 1


@PIPELINE_SINGLETON.register_filter()
def highly_repetitive_filter(dataset: DataFrame, _) -> DataFrame:
"""Returns the repeating chunk and the number of times a sequence is repeating
"""
Returns the repeating chunk and the number of times a sequence is repeating.

Args:
dataset (DataFrame): Dataset containing sequences of tokens
@@ -111,32 +126,47 @@ def highly_repetitive_filter(dataset: DataFrame, _) -> DataFrame:
Outputs Include:
- `num_repeating`: Number of times a sequence is repeating
- `smallest_repeating_chunk`: Smallest repeating token sequence

Returns:
DataFrame: with additional column of `is_incrementing`
        DataFrame: with additional columns
            `smallest_repeating_chunk`: Smallest repeating token chunk
            `num_repeating`: Number of times the chunk is repeating
            `repeating_offset`: Offset at which the repetition starts
"""
main = dataset.alias("main")
repetitive_schema = T.StructType(
[
T.StructField("num_repeating", T.IntegerType()),
T.StructField("repeating_offset", T.IntegerType()),
T.StructField("repeating_chunk", T.ArrayType(T.LongType())),
T.StructField("repeating_offset", T.IntegerType()),
T.StructField("num_repeating", T.IntegerType())
]
)

start_k = 2
end_k = 5
repetitiveUDF = F.udf(lambda seq: break_and_compare_wrapper(seq, start_k, end_k), repetitive_schema)

smallest_repeating_chunk_schema = T.StructType(
[
T.StructField("smallest_repeating_chunk", T.ArrayType(T.LongType())),
T.StructField("num_times", T.IntegerType())
]
)
repetitiveUDF = F.udf(lambda seq: break_and_compare_wrapper(seq, 2, 5), repetitive_schema)
smallest_repeating_chunkUDF = F.udf(lambda seq: find_smallest_repeating_unit(seq), T.ArrayType(T.LongType()))
smallest_repeating_chunkUDF = F.udf(lambda seq: find_smallest_repeating_unit(seq), smallest_repeating_chunk_schema)

repetitive_counts = main.select("sequence_id", "text").withColumn("repetitive", repetitiveUDF("text"))
repetitive_counts = repetitive_counts.withColumn("smallest_repeating_chunk", smallest_repeating_chunkUDF("repetitive.repeating_chunk"))
repetitive_counts = main.select("sequence_id", "tokens").withColumn("repetitive", repetitiveUDF("tokens"))
repetitive_counts = repetitive_counts.withColumn("smallest_repeating", smallest_repeating_chunkUDF("repetitive.repeating_chunk"))

final = (
repetitive_counts.join(main, on="sequence_id", how="left")
main.join(repetitive_counts, on="sequence_id", how="left")
.drop(repetitive_counts.sequence_id)
.drop(repetitive_counts.text)
.drop(repetitive_counts.tokens)
.drop(repetitive_counts.repetitive.repeating_chunk)
.select(
"main.*",
"repetitive.*",
"smallest_repeating_chunk",
F.col("repetitive.repeating_offset").alias("repeating_offset"),
(F.col("repetitive.num_repeating")*F.col("smallest_repeating.num_times")).alias("num_repeating"),
F.col("smallest_repeating.smallest_repeating_chunk").alias("smallest_repeating_chunk")
)
)

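
A small check of the updated return shapes (a sketch, assuming `filters.highly_repetitive` imports standalone; the exact offset returned depends on parts of the wrapper not shown in this diff):

```python
from filters.highly_repetitive import break_and_compare_wrapper, find_smallest_repeating_unit

# The wrapper now always returns a 3-tuple: (repeating chunk, start offset, repeat count),
# or ([], 0, -1) when no chunk repeats between start_k and end_k times.
chunk, offset, k = break_and_compare_wrapper([1, 2, 1, 2, 1, 2], 2, 5)
print(chunk, offset, k)

# find_smallest_repeating_unit now returns the unit plus how many times it tiles the list;
# the filter multiplies that count into `num_repeating`.
unit, times = find_smallest_repeating_unit([1, 2, 1, 2, 1, 2])
print(unit, times)  # [1, 2] 3
```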
28 changes: 24 additions & 4 deletions filters/pattern_incrementing.py
@@ -1,13 +1,23 @@
import re
import unicodedata

from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql import types as T

from .base import PIPELINE_SINGLETON
import unicodedata
import re


def replace_non_numeric_with_whitespace(text: str) -> str:
"""
Replaces non-numeric characters with whitespace.

Args:
text (str): Sequence of tokens

Returns:
str: Sequence of tokens with non-numeric characters replaced with whitespace
"""
# Replace non-numeric characters with whitespace
# cleaned_text = re.sub(r'[^0-9]', ' ', text)
new_text = ""
@@ -48,6 +58,15 @@ def replace_non_numeric_with_whitespace(text: str) -> str:


def incrementing_sequences_filter_wrapper(text: str) -> bool:
"""
Returns if a sequence is incrementing.

Args:
text (str): Sequence of tokens

Returns:
bool: True if the sequence is incrementing, False otherwise
"""
# count number of numeric and non-numeric characters
num_numeric = 0
num_non_numeric = 0
@@ -317,7 +336,8 @@ def incrementing_sequences_filter_wrapper(text: str) -> bool:

@PIPELINE_SINGLETON.register_filter()
def incrementing_sequences_filter(dataset: DataFrame, _) -> DataFrame:
"""Returns if a sequence is incrementing
"""
Returns if a sequence is incrementing.

Args:
dataset (DataFrame): Dataset containing sequences of tokens
@@ -338,4 +358,4 @@ def incrementing_sequences_filter(dataset: DataFrame, _) -> DataFrame:
samp = r"""
"A.1 , A.2 , A.3 , A.4, B.1 , B.2, B.3, C.1"
"""
print(incrementing_sequences_filter(samp))
print(incrementing_sequences_filter_wrapper(samp))
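
Since the `__main__` block now calls the wrapper directly, it can be exercised without Spark; a short sketch (expected outputs are assumptions based on the docstring, not verified against the elided heuristics):

```python
from filters.pattern_incrementing import incrementing_sequences_filter_wrapper

# The wrapper takes detokenized text and returns a bool.
print(incrementing_sequences_filter_wrapper("A.1 , A.2 , A.3 , A.4, B.1 , B.2, B.3, C.1"))  # expected: True
print(incrementing_sequences_filter_wrapper("the quick brown fox jumps over the lazy dog"))  # expected: False
```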
10 changes: 5 additions & 5 deletions filters/token_frequency_statistics_filter.py
@@ -7,7 +7,8 @@

@PIPELINE_SINGLETON.register_filter()
def token_frequency_statistics_filter(dataset: DataFrame, features: PrecomputedFeatures) -> DataFrame:
"""Compute token frequency statistics of a list of token frequencies ordered by the token index (not ID) in the sequence.
"""
Compute token frequency statistics of a list of token frequencies ordered by the token index (not ID) in the sequence.

Statistics include:
- `max_frequency`: maximum frequency of token frequencies in the sequence
@@ -28,8 +29,7 @@ def token_frequency_statistics_filter(dataset: DataFrame, features: PrecomputedF
memorized_frequencies = features[PrecomputedFeatureName.MEMORIZED_TOKEN_FREQUENCIES].alias("memorized")
non_memorized_frequencies = features[PrecomputedFeatureName.NON_MEMORIZED_TOKEN_FREQUENCIES].alias("non_memorized")

# First, we expand the token indices, then join to extract the frequencies
# Note that we dropped the memorization score, we'll re-join it later.
# First, we expand the token indices, then join to extract the frequencies.
flattened_main = main.select("sequence_id", F.posexplode("tokens").alias("token_index", "token_id"))
token_frequencies = (
flattened_main.join(memorized_frequencies, on="token_id", how="left")
@@ -69,6 +69,6 @@ def token_frequency_statistics_filter(dataset: DataFrame, features: PrecomputedF
F.transform(F.col("frequencies"), lambda x: x.frequency).alias("frequencies"),
).alias("filtered")

# Finally, re-attach the memorization score from the original dataset
final = filtered_frequencies.join(main, on="sequence_id", how="left").drop(filtered_frequencies.sequence_id).select("main.*", "filtered.*")
final = main.join(filtered_frequencies, on="sequence_id", how="left").drop(filtered_frequencies.sequence_id).select("main.*", "filtered.*")

return final
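
A toy illustration of the `posexplode` expansion described in the comment above (illustrative values only):

```python
# Each (sequence_id, tokens) row fans out to one row per token, keeping the token's
# position in the sequence, so per-token frequencies can be joined back on token_id.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[*]").getOrCreate()

df = spark.createDataFrame([(0, [11, 12, 11])], ["sequence_id", "tokens"])
df.select("sequence_id", F.posexplode("tokens").alias("token_index", "token_id")).show()
# sequence_id | token_index | token_id
#           0 |           0 |       11
#           0 |           1 |       12
#           0 |           2 |       11
```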
4 changes: 3 additions & 1 deletion spark/constants.py
@@ -1,2 +1,4 @@
SPARK_CACHE_DIR = "spark_cache"
NUM_PARTITIONS = 1
NUM_CPU_COUNT = 64
NUM_OUTPUT_PARTITIONS = 1
NUM_SPARK_PARTITIONS = 4096
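
calculate_metrics.py is not rendered in this diff, so the following is only a hypothetical sketch of how the split constants could be consumed (the output path and config keys are assumptions, not repo code):

```python
# NUM_SPARK_PARTITIONS would control shuffle/processing parallelism, while
# NUM_OUTPUT_PARTITIONS coalesces the final write down to a few output files.
from pyspark.sql import SparkSession

from spark.constants import NUM_CPU_COUNT, NUM_OUTPUT_PARTITIONS, NUM_SPARK_PARTITIONS

spark = (
    SparkSession.builder
    .master(f"local[{NUM_CPU_COUNT}]")
    .config("spark.sql.shuffle.partitions", NUM_SPARK_PARTITIONS)
    .getOrCreate()
)

df = spark.range(100_000).repartition(NUM_SPARK_PARTITIONS)
df.coalesce(NUM_OUTPUT_PARTITIONS).write.mode("overwrite").parquet("working_dirs/example_output")
```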