Skip to content

Commit

Permalink
hot fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato committed Nov 21, 2024
1 parent c964cc4 commit 3954920
Showing 1 changed file with 8 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from pyspark.sql import DataFrame, functions as F, SparkSession
from pyspark.ml.feature import StringIndexer
from pyspark.sql.functions import monotonically_increasing_id

from .base_dist_transformation import DistributedTransformation
from . import DistMultiCategoryTransformation
Expand Down Expand Up @@ -47,6 +48,8 @@ def apply(self, input_df: DataFrame) -> DataFrame:
assert self.spark
processed_col_name = self.label_column + "_processed"

input_df = input_df.withColumn("unique_id", monotonically_increasing_id())

str_indexer = StringIndexer(
inputCol=self.label_column,
outputCol=processed_col_name,
Expand All @@ -60,16 +63,17 @@ def apply(self, input_df: DataFrame) -> DataFrame:
processed_col_name,
self.label_column,
)

# The StringIndexer assigns missing labels the index numLabels (the number of
# known labels); convert those entries to None
long_class_label = indexed_df.select(F.col(self.label_column).cast("long")).select(
long_class_label = indexed_df.select(F.col(self.label_column).cast("long"),
F.col("unique_id")).select(
F.when(
F.col(self.label_column) == len(str_indexer_model.labelsArray[0]), # type: ignore
F.lit(None),
)
.otherwise(F.col(self.label_column))
.alias(self.label_column)
.alias(self.label_column),
F.col("unique_id")
)

# Get a mapping from original label to encoded value
Expand All @@ -85,6 +89,7 @@ def apply(self, input_df: DataFrame) -> DataFrame:
map_dict = json.loads(mapping_str)
self.value_map[map_dict[self.label_column]] = map_dict[processed_col_name]

long_class_label = long_class_label.orderBy("unique_id").drop("unique_id")
return long_class_label

@staticmethod
Expand Down

0 comments on commit 3954920

Please sign in to comment.