Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato committed Nov 21, 2024
1 parent 4f3b80a commit 1c794f7
Showing 1 changed file with 32 additions and 1 deletion.
33 changes: 32 additions & 1 deletion graphstorm-processing/tests/test_dist_label_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import numpy as np

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import DataFrame, SparkSession, Row
from pyspark.sql.types import StructField, StructType, StringType


Expand Down Expand Up @@ -133,3 +133,34 @@ def test_dist_multilabel_classification(spark: SparkSession, check_df_schema):
assert row_val[2 * i + 1] == 1.0
else:
assert i == 4, "Only the last row should be None/null"


def test_dist_label_order(spark: SparkSession, check_df_schema):
    """Check that classification label processing keeps the input row order.

    The input is a single-column DataFrame made of a block of 0 values
    followed by an equally sized block of 1 values. After running it through
    ``DistLabelLoader.process_label`` the test expects the label map to have
    exactly the keys ``"0"`` and ``"1"``, and the output rows to still appear
    as a block of 0 labels followed by a block of 1 labels.

    Args:
        spark: Spark session fixture used to build the input DataFrame.
        check_df_schema: Fixture that validates the schema of the transformed
            labels DataFrame.
    """
    label_col = "name"
    num_per_class = 5000
    classification_config = {
        "column": "name",
        "type": "classification",
        "split_rate": {"train": 0.8, "val": 0.2, "test": 0.0},
    }

    data_zeros = [Row(value=0) for _ in range(num_per_class)]
    data_ones = [Row(value=1) for _ in range(num_per_class)]
    data = data_zeros + data_ones
    names_df = spark.createDataFrame(data, schema=[label_col])

    label_transformer = DistLabelLoader(LabelConfig(classification_config), spark)

    transformed_labels = label_transformer.process_label(names_df)
    label_map = label_transformer.label_map

    assert set(label_map.keys()) == {"0", "1"}

    check_df_schema(transformed_labels)

    # Collect once and slice positionally. DataFrame.subtract() is a set
    # difference (EXCEPT DISTINCT): it would drop every row equal to a row in
    # the first block and de-duplicate the rest, so it cannot be used to pick
    # "rows 5000..9999" and could leave the second check vacuously true.
    rows = transformed_labels.collect()
    assert len(rows) == 2 * num_per_class, (
        f"Expected {2 * num_per_class} transformed rows, got {len(rows)}"
    )

    first_5000 = rows[:num_per_class]
    next_5000 = rows[num_per_class:]

    first_5000_check = all(row[label_col] == 0 for row in first_5000)
    next_5000_check = all(row[label_col] == 1 for row in next_5000)

    assert first_5000_check and next_5000_check, "The value assignment is in disorder"

0 comments on commit 1c794f7

Please sign in to comment.