diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_label_transformation.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_label_transformation.py index 840654dcd8..6481f6c36c 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_label_transformation.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_label_transformation.py @@ -19,6 +19,7 @@ from pyspark.sql import DataFrame, functions as F, SparkSession from pyspark.ml.feature import StringIndexer +from pyspark.sql.functions import monotonically_increasing_id from .base_dist_transformation import DistributedTransformation from . import DistMultiCategoryTransformation @@ -47,6 +48,8 @@ def apply(self, input_df: DataFrame) -> DataFrame: assert self.spark processed_col_name = self.label_column + "_processed" + input_df = input_df.withColumn("unique_id", monotonically_increasing_id()) + str_indexer = StringIndexer( inputCol=self.label_column, outputCol=processed_col_name, @@ -60,16 +63,17 @@ def apply(self, input_df: DataFrame) -> DataFrame: processed_col_name, self.label_column, ) - # Labels that were missing and were assigned the value numLabels by the StringIndexer # are converted to None - long_class_label = indexed_df.select(F.col(self.label_column).cast("long")).select( + long_class_label = indexed_df.select(F.col(self.label_column).cast("long"), + F.col("unique_id")).select( F.when( F.col(self.label_column) == len(str_indexer_model.labelsArray[0]), # type: ignore F.lit(None), ) .otherwise(F.col(self.label_column)) - .alias(self.label_column) + .alias(self.label_column), + F.col("unique_id") ) # Get a mapping from original label to encoded value @@ -85,6 +89,7 @@ def apply(self, input_df: DataFrame) -> DataFrame: map_dict = json.loads(mapping_str) self.value_map[map_dict[self.label_column]] = map_dict[processed_col_name] + long_class_label = long_class_label.orderBy("unique_id").drop("unique_id") return long_class_label @staticmethod