Skip to content

Commit

Permalink
Address review comments, include transformation name in JSON rep
Browse files Browse the repository at this point in the history
  • Loading branch information
thvasilo committed May 29, 2024
1 parent 2b6f7fd commit 42c3163
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class DistCategoryTransformation(DistributedTransformation):
Transforms categorical features into a vector of one-hot-encoded values.
"""

def __init__(self, cols: List[str], spark: SparkSession) -> None:
def __init__(self, cols: list[str], spark: SparkSession) -> None:
super().__init__(cols, spark)

@staticmethod
Expand Down Expand Up @@ -160,6 +160,7 @@ def apply(self, input_df: DataFrame) -> DataFrame:
"string_indexer_labels_array": str_indexer_model.labelsArray,
"cols": self.cols,
"per_col_label_to_one_hot_idx": per_col_label_to_one_hot_idx,
"transformation_name": self.get_transformation_name(),
}

return dense_vector_features
Expand All @@ -181,6 +182,8 @@ def get_json_representation(self) -> dict:
list[str], with num_cols elements
per_col_label_to_one_hot_idx:
dict[str, dict[str, int]], with num_cols elements, each with num_categories elements
transformation_name:
str, will be 'DistCategoryTransformation'
"""
return self.json_representation

Expand Down Expand Up @@ -212,7 +215,7 @@ def __init__(self, cols: Sequence[str], separator: str) -> None:
if self.separator in SPECIAL_CHARACTERS:
self.separator = f"\\{self.separator}"

self.value_map = {} # type: Dict[str, int]
self.value_map: dict[str, int] = {}

@staticmethod
def get_transformation_name() -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ def test_multiple_single_cat_cols_json(user_df, spark):
labels_array = multi_cols_rep["string_indexer_labels_array"]
one_hot_index_for_string = multi_cols_rep["per_col_label_to_one_hot_idx"]
cols = multi_cols_rep["cols"]
name = multi_cols_rep["transformation_name"]

assert name == "DistCategoryTransformation"

# The Spark-generated and our own one-hot-index mappings should match
for col_labels, col in zip(labels_array, cols):
Expand Down

0 comments on commit 42c3163

Please sign in to comment.