diff --git a/graphstorm-processing/graphstorm_processing/constants.py b/graphstorm-processing/graphstorm_processing/constants.py
index a732306ab8..1c1ce5282f 100644
--- a/graphstorm-processing/graphstorm_processing/constants.py
+++ b/graphstorm-processing/graphstorm_processing/constants.py
@@ -31,6 +31,7 @@
 MAX_VALUE = "MAX_VALUE"
 COLUMN_NAME = "COLUMN_NAME"
 VALUE_COUNTS = "VALUE_COUNTS"
+COLUMN_ORDER_FLAG = "label_property_order_id"
 
 ############## Spark-specific constants #####################
 SPECIAL_CHARACTERS = {".", "+", "*", "?", "^", "$", "(", ")", "[", "]", "{", "}", "|", "\\"}
diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_label_transformation.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_label_transformation.py
index 840654dcd8..a69c105614 100644
--- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_label_transformation.py
+++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_label_transformation.py
@@ -20,6 +20,7 @@
 from pyspark.sql import DataFrame, functions as F, SparkSession
 from pyspark.ml.feature import StringIndexer
 
+from graphstorm_processing.constants import COLUMN_ORDER_FLAG
 from .base_dist_transformation import DistributedTransformation
 from . import DistMultiCategoryTransformation
 from ..spark_utils import safe_rename_column
@@ -47,6 +48,8 @@ def apply(self, input_df: DataFrame) -> DataFrame:
         assert self.spark
         processed_col_name = self.label_column + "_processed"
 
+        # Tag each row with an increasing id so the original row order can be restored later
+        input_df = input_df.withColumn(COLUMN_ORDER_FLAG, F.monotonically_increasing_id())
         str_indexer = StringIndexer(
             inputCol=self.label_column,
             outputCol=processed_col_name,
@@ -60,16 +63,21 @@ def apply(self, input_df: DataFrame) -> DataFrame:
             processed_col_name,
             self.label_column,
         )
+        # Restore the input's original row order and drop the helper column
+        input_df = input_df.orderBy(COLUMN_ORDER_FLAG).drop(COLUMN_ORDER_FLAG)
 
         # Labels that were missing and were assigned the value numLabels by the StringIndexer
         # are converted to None
-        long_class_label = indexed_df.select(F.col(self.label_column).cast("long")).select(
+        long_class_label = indexed_df.select(
+            F.col(self.label_column).cast("long"), F.col(COLUMN_ORDER_FLAG)
+        ).select(
             F.when(
                 F.col(self.label_column) == len(str_indexer_model.labelsArray[0]),  # type: ignore
                 F.lit(None),
             )
             .otherwise(F.col(self.label_column))
-            .alias(self.label_column)
+            .alias(self.label_column),
+            F.col(COLUMN_ORDER_FLAG),
         )
 
         # Get a mapping from original label to encoded value
@@ -85,6 +93,8 @@ def apply(self, input_df: DataFrame) -> DataFrame:
             map_dict = json.loads(mapping_str)
             self.value_map[map_dict[self.label_column]] = map_dict[processed_col_name]
 
+        # Return the labels sorted back to the original input row order
+        long_class_label = long_class_label.orderBy(COLUMN_ORDER_FLAG)
         return long_class_label
 
     @staticmethod
diff --git a/graphstorm-processing/tests/test_dist_label_loader.py b/graphstorm-processing/tests/test_dist_label_loader.py
index ce5e6b0fe7..4713474ecd 100644
--- a/graphstorm-processing/tests/test_dist_label_loader.py
+++ b/graphstorm-processing/tests/test_dist_label_loader.py
@@ -16,7 +16,7 @@
 
 import numpy as np
 
-from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql import DataFrame, SparkSession, Row
 from pyspark.sql.types import StructField, StructType, StringType
 
 
@@ -133,3 +133,33 @@ def test_dist_multilabel_classification(spark: SparkSession, check_df_schema):
             assert row_val[2 * i + 1] == 1.0
         else:
             assert i == 4, "Only the last row should be None/null"
+
+
+def test_dist_label_order(spark: SparkSession, check_df_schema):
+    label_col = "name"
+    classification_config = {
+        "column": "name",
+        "type": "classification",
+        "split_rate": {"train": 0.8, "val": 0.2, "test": 0.0},
+    }
+
+    # 5,000 rows labeled 0 followed by 5,000 rows labeled 1
+    data_zeros = [Row(value=0) for _ in range(5000)]
+    data_ones = [Row(value=1) for _ in range(5000)]
+    data = data_zeros + data_ones
+    names_df = spark.createDataFrame(data, schema=[label_col])
+
+    label_transformer = DistLabelLoader(LabelConfig(classification_config), spark)
+
+    transformed_labels = label_transformer.process_label(names_df)
+    label_map = label_transformer.label_map
+
+    assert set(label_map.keys()) == {"0", "1"}
+
+    check_df_schema(transformed_labels)
+
+    # If the original row order was preserved, the first half of the output
+    # must all be label 0 and the second half label 1
+    collected_labels = transformed_labels.collect()
+    assert all(row.name == 0 for row in collected_labels[:5000]), "First 5000 labels should be 0"
+    assert all(row.name == 1 for row in collected_labels[5000:]), "Last 5000 labels should be 1"