Skip to content

Commit

Permalink
label value assignment
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato committed Nov 21, 2024
1 parent c964cc4 commit 4f3b80a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
1 change: 1 addition & 0 deletions graphstorm-processing/graphstorm_processing/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
MAX_VALUE = "MAX_VALUE"
COLUMN_NAME = "COLUMN_NAME"
VALUE_COUNTS = "VALUE_COUNTS"
# Name of a helper column: a label transformation's apply() fills it with
# F.monotonically_increasing_id() and orders its output DataFrame by it.
COLUMN_ORDER_FLAG = "label_property_order_id"

############## Spark-specific constants #####################
# NOTE(review): these look like regex metacharacters that need escaping
# before use in a pattern -- confirm against the call sites.
SPECIAL_CHARACTERS = {".", "+", "*", "?", "^", "$", "(", ")", "[", "]", "{", "}", "|", "\\"}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pyspark.sql import DataFrame, functions as F, SparkSession
from pyspark.ml.feature import StringIndexer

from graphstorm_processing.constants import COLUMN_ORDER_FLAG
from .base_dist_transformation import DistributedTransformation
from . import DistMultiCategoryTransformation
from ..spark_utils import safe_rename_column
Expand Down Expand Up @@ -47,6 +48,7 @@ def apply(self, input_df: DataFrame) -> DataFrame:
assert self.spark
processed_col_name = self.label_column + "_processed"

input_df = input_df.withColumn(COLUMN_ORDER_FLAG, F.monotonically_increasing_id())
str_indexer = StringIndexer(
inputCol=self.label_column,
outputCol=processed_col_name,
Expand All @@ -63,13 +65,16 @@ def apply(self, input_df: DataFrame) -> DataFrame:

# Labels that were missing and were assigned the value numLabels by the StringIndexer
# are converted to None
long_class_label = indexed_df.select(F.col(self.label_column).cast("long")).select(
long_class_label = indexed_df.select(
F.col(self.label_column).cast("long"), F.col(COLUMN_ORDER_FLAG)
).select(
F.when(
F.col(self.label_column) == len(str_indexer_model.labelsArray[0]), # type: ignore
F.lit(None),
)
.otherwise(F.col(self.label_column))
.alias(self.label_column)
.alias(self.label_column),
F.col(COLUMN_ORDER_FLAG),
)

# Get a mapping from original label to encoded value
Expand All @@ -85,6 +90,7 @@ def apply(self, input_df: DataFrame) -> DataFrame:
map_dict = json.loads(mapping_str)
self.value_map[map_dict[self.label_column]] = map_dict[processed_col_name]

long_class_label = long_class_label.orderBy(COLUMN_ORDER_FLAG)
return long_class_label

@staticmethod
Expand Down

0 comments on commit 4f3b80a

Please sign in to comment.