[GSProcessing] Add ability to re-apply pre-computed categorical transformation.
thvasilo committed Jun 10, 2024
1 parent 71838ba commit cbfaa62
Showing 9 changed files with 346 additions and 89 deletions.
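
In brief: when the GSProcessing configuration already carries the JSON representation of a categorical transformation from a previous run, that representation is passed down to the transformation and re-applied, instead of recomputing the category-to-index mapping from the data. A minimal sketch of the representation being re-applied (keys taken from the diff below; the column name and category values are illustrative):

json_representation = {
    # One labels array per column, in the order produced by Spark's StringIndexer
    "string_indexer_labels_arrays": [["red", "green", "blue"]],
    # The columns the transformation was originally computed for
    "cols": ["color"],
    # Per column: category string -> position of the 1 in its one-hot vector
    "per_col_label_to_one_hot_idx": {"color": {"red": 0, "green": 1, "blue": 2}},
    # Name reported by get_transformation_name()
    "transformation_name": "...",
}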
@@ -44,7 +44,7 @@ def __init__(
feat_name = feature_config.feat_name
args_dict = feature_config.transformation_kwargs
self.transformation: DistributedTransformation
# TODO: We will use this to re-apply transformations
# We use this to re-apply transformations
self.json_representation = json_representation

default_kwargs = {
@@ -63,7 +63,7 @@ def __init__(
self.transformation = DistBucketNumericalTransformation(**default_kwargs, **args_dict)
elif feat_type == "categorical":
self.transformation = DistCategoryTransformation(
**default_kwargs, **args_dict, spark=spark
**default_kwargs, **args_dict, spark=spark, json_representation=json_representation
)
elif feat_type == "multi-categorical":
self.transformation = DistMultiCategoryTransformation(**default_kwargs, **args_dict)
@@ -88,10 +88,17 @@ def apply_transformation(self, input_df: DataFrame) -> tuple[DataFrame, dict]:
"""
input_df = input_df.select(self.transformation.cols) # type: ignore

return (
self.transformation.apply(input_df),
self.transformation.get_json_representation(),
)
if self.json_representation:
logging.info("Applying precomputed transformation...")
return (
self.transformation.apply_precomputed_transformation(input_df),
self.json_representation,
)
else:
return (
self.transformation.apply(input_df),
self.transformation.get_json_representation(),
)

def get_transformation_name(self) -> str:
"""
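The tuple returned by apply_transformation above is what enables the round trip: the JSON representation produced on a first run can be stored and supplied on a later run so the same encoding is reproduced. A hedged usage sketch (the DistFeatureTransformer name and constructor keywords are assumptions based on this diff, not verified against the full module):

# First run: no pre-computed representation, so a new transformation is fitted.
transformed_df, representation = transformer.apply_transformation(input_df)

# Later run: hand the stored representation back to the transformer so the
# pre-computed category-to-one-hot mapping is re-applied instead of recomputed.
transformer_reapply = DistFeatureTransformer(
    feature_config, spark=spark, json_representation=representation
)
reapplied_df, same_representation = transformer_reapply.apply_transformation(input_df)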
@@ -16,13 +16,23 @@

from abc import ABC, abstractmethod
from typing import Optional, Sequence
import logging

from pyspark.sql import DataFrame, SparkSession


class DistributedTransformation(ABC):
"""
Base class for all distributed transformations.
Parameters
----------
cols : Sequence[str]
Column names to which we will apply the transformation
spark : Optional[SparkSession], optional
Optional SparkSession if needed by the underlying implementation, by default None
json_representation : Optional[dict], optional
Pre-computed transformation representation to use, by default None
"""

def __init__(
@@ -52,6 +62,28 @@ def get_json_representation(self) -> dict:
else:
return {}

def apply_precomputed_transformation(self, input_df: DataFrame) -> DataFrame:
"""Applies a transformation using pre-computed representation.
Parameters
----------
input_df : DataFrame
Input DataFrame to apply the transformation to.
Returns
-------
DataFrame
The input DataFrame, modified according to the pre-computed transformation values.
"""
logging.warning(
(
"Transformation %s does not support pre-existing transform"
", applying new transformation"
),
self.get_transformation_name(),
)
return self.apply(input_df)

@staticmethod
@abstractmethod
def get_transformation_name() -> str:
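A hypothetical subclass to illustrate the base-class contract above: a transformation that does not override apply_precomputed_transformation() falls back to apply() with a warning, while one that does override it can reuse its stored json_representation. The class below is a toy example, not part of the library:

from pyspark.sql import DataFrame

# Assumes the DistributedTransformation base class defined above is importable.
class DistNoopTransformation(DistributedTransformation):
    """Toy transformation that returns its input unchanged."""

    def apply(self, input_df: DataFrame) -> DataFrame:
        return input_df

    @staticmethod
    def get_transformation_name() -> str:
        return "DistNoopTransformation"

# noop = DistNoopTransformation(cols=["age"])
# noop.apply_precomputed_transformation(df)  # logs a warning, then calls apply(df)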
@@ -14,17 +14,20 @@
limitations under the License.
"""

from collections import defaultdict
from typing import List, Optional, Sequence
from functools import partial

import numpy as np
import pandas as pd

from pyspark.sql import DataFrame, functions as F, SparkSession
from pyspark.sql.functions import when
from pyspark.sql.types import ArrayType, FloatType, StringType
from pyspark.ml.feature import StringIndexer, OneHotEncoder
from pyspark.ml.functions import vector_to_array
from pyspark.ml.linalg import Vectors
from pyspark.sql import DataFrame, functions as F, SparkSession
from pyspark.sql.functions import when
from pyspark.sql.types import ArrayType, FloatType, StringType
from pyspark.sql.types import IntegerType

from graphstorm_processing.constants import (
MAX_CATEGORIES_PER_FEATURE,
@@ -41,8 +44,12 @@ class DistCategoryTransformation(DistributedTransformation):
Transforms categorical features into a vector of one-hot-encoded values.
"""

def __init__(self, cols: list[str], spark: SparkSession) -> None:
super().__init__(cols, spark)
def __init__(
self, cols: list[str], spark: SparkSession, json_representation: Optional[dict] = None
) -> None:
if not json_representation:
json_representation = {}
super().__init__(cols, spark, json_representation)

@staticmethod
def get_transformation_name() -> str:
@@ -51,6 +58,7 @@ def get_transformation_name() -> str:
def apply(self, input_df: DataFrame) -> DataFrame:
processed_col_names = []
top_categories_per_col: dict[str, list] = {}

for current_col in self.cols:
processed_col_names.append(current_col + "_processed")
distinct_category_counts = input_df.groupBy(current_col).count() # type: DataFrame
@@ -157,14 +165,90 @@ def apply(self, input_df: DataFrame) -> DataFrame:

# see get_json_representation() docstring for structure
self.json_representation = {
"string_indexer_labels_array": str_indexer_model.labelsArray,
"string_indexer_labels_arrays": str_indexer_model.labelsArray,
"cols": self.cols,
"per_col_label_to_one_hot_idx": per_col_label_to_one_hot_idx,
"transformation_name": self.get_transformation_name(),
}

return dense_vector_features

def apply_precomputed_transformation(self, input_df: DataFrame) -> DataFrame:

# Get JSON representation of categorical transformation
labels_arrays: list[list[str]] = self.json_representation["string_indexer_labels_arrays"]
per_col_label_to_one_hot_idx: dict[str, dict[str, int]] = self.json_representation[
"per_col_label_to_one_hot_idx"
]
precomputed_cols: list[str] = self.json_representation["cols"]

# Assertions to ensure correctness of representation
assert set(precomputed_cols) == set(self.cols), (
f"Mismatched columns in precomputed transformation: "
f"pre-computed cols: {sorted(precomputed_cols)}, "
f"columns in current config: {sorted(self.cols)}"
)
for col_labels, col in zip(labels_arrays, precomputed_cols):
for idx, label in enumerate(col_labels):
assert idx == per_col_label_to_one_hot_idx[col][label], (
"Mismatch between Spark labelsArray and pre-computed array index "
f"for col {col}, string: {label}, "
f"{idx} != {per_col_label_to_one_hot_idx[col][label]}"
)

# For each column in the transformation we build a defaultdict that maps
# each known value to its one-hot vector encoding; values not in the dict
# get the all-zeroes (missing) vector.
# We do this for every column and return the resulting DF.

# We need to define these outside the loop to avoid
# https://pylint.readthedocs.io/en/latest/user_guide/messages/warning/cell-var-from-loop.html
def replace_col_in_row(val: str, str_to_vec: dict):
return str_to_vec[val]

def create_zeroes_list(vec_size: int):
return [0] * vec_size

transformed_df = None
already_transformed_cols = []
remaining_cols = list(self.cols)

for col_idx, current_col in enumerate(precomputed_cols):
vector_size = len(labels_arrays[col_idx])
# Mapping from string to one-hot vector,
# with all-zeroes default for unknown/missing values
string_to_vector = defaultdict(partial(create_zeroes_list, vector_size))

string_to_one_hot_idx = per_col_label_to_one_hot_idx[current_col]

# Populate the one-hot vectors for known strings
for string_val, one_hot_idx in string_to_one_hot_idx.items():
one_hot_vec = [0] * vector_size
one_hot_vec[one_hot_idx] = 1
string_to_vector[string_val] = one_hot_vec

# UDF that replaces string values with their one-hot encoding (ohe)
replace_cur_col = partial(replace_col_in_row, str_to_vec=string_to_vector)
replace_cur_col_udf = F.udf(replace_cur_col, ArrayType(IntegerType()))

partial_df = transformed_df if transformed_df else input_df

transformed_col = f"{current_col}_ohe"
remaining_cols.remove(current_col)
# We maintain only the already transformed cols, and the ones yet to be transformed
transformed_df = partial_df.select(
replace_cur_col_udf(F.col(current_col)).alias(transformed_col),
*remaining_cols,
*already_transformed_cols,
).drop(current_col)
already_transformed_cols.append(transformed_col)

assert transformed_df
transformed_df = transformed_df.select(*already_transformed_cols).toDF(*self.cols)

return transformed_df

def get_json_representation(self) -> dict:
"""Representation of the single-category transformation for one or more columns.
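For reference, the per-column mapping that apply_precomputed_transformation builds above behaves like this plain-Python sketch: known strings map to their stored one-hot vector, while unseen or missing strings fall back to the all-zeroes vector (category values here are hypothetical):

from collections import defaultdict
from functools import partial

labels = ["red", "green", "blue"]                 # one string_indexer_labels_arrays entry
label_to_idx = {"red": 0, "green": 1, "blue": 2}  # one per_col_label_to_one_hot_idx entry

def create_zeroes_list(vec_size: int) -> list:
    return [0] * vec_size

string_to_vector = defaultdict(partial(create_zeroes_list, len(labels)))
for label, idx in label_to_idx.items():
    one_hot = [0] * len(labels)
    one_hot[idx] = 1
    string_to_vector[label] = one_hot

print(string_to_vector["green"])   # [0, 1, 0]
print(string_to_vector["purple"])  # [0, 0, 0] -> treated as missing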
48 changes: 26 additions & 22 deletions graphstorm-processing/graphstorm_processing/distributed_executor.py
@@ -242,6 +242,25 @@ def __init__(
# Create the Spark session for execution
self.spark = spark_utils.create_spark_session(self.execution_env, self.filesystem_type)

# Initialize the graph loader
data_configs = create_config_objects(self.gsp_config_dict)
loader_config = HeterogeneousLoaderConfig(
add_reverse_edges=self.add_reverse_edges,
data_configs=data_configs,
enable_assertions=False,
graph_name=self.graph_name,
input_prefix=self.input_prefix,
local_input_path=self.local_config_path,
local_metadata_output_path=self.local_metadata_output_path,
num_output_files=self.num_output_files,
output_prefix=self.output_prefix,
precomputed_transformations=self.precomputed_transformations,
)
self.loader = DistHeterogeneousGraphLoader(
self.spark,
loader_config,
)

def _upload_output_files(self, loader: DistHeterogeneousGraphLoader, force=False):
"""Upload output files to S3
@@ -273,27 +292,10 @@ def run(self) -> None:
Executes the Spark processing job.
"""
logging.info("Performing data processing with PySpark...")
data_configs = create_config_objects(self.gsp_config_dict)

t0 = time.time()
# Prefer explicit arguments for clarity
loader_config = HeterogeneousLoaderConfig(
add_reverse_edges=self.add_reverse_edges,
data_configs=data_configs,
enable_assertions=False,
graph_name=self.graph_name,
input_prefix=self.input_prefix,
local_input_path=self.local_config_path,
local_metadata_output_path=self.local_metadata_output_path,
num_output_files=self.num_output_files,
output_prefix=self.output_prefix,
precomputed_transformations=self.precomputed_transformations,
)
loader = DistHeterogeneousGraphLoader(
self.spark,
loader_config,
)
processed_representations: ProcessedGraphRepresentation = loader.load()

processed_representations: ProcessedGraphRepresentation = self.loader.load()
graph_meta_dict = processed_representations.processed_graph_metadata_dict

t1 = time.time()
@@ -343,7 +345,9 @@

# If any of the metadata modification took place, write an updated metadata file
if updated_metadata:
updated_meta_path = os.path.join(loader.output_path, "updated_row_counts_metadata.json")
updated_meta_path = os.path.join(
self.loader.output_path, "updated_row_counts_metadata.json"
)
with open(
updated_meta_path,
"w",
@@ -384,7 +388,7 @@ def run(self) -> None:
# since we can't rely on SageMaker to do it
if self.filesystem_type == FilesystemType.S3:
self._upload_output_files(
loader, force=(not self.execution_env == ExecutionEnv.SAGEMAKER)
self.loader, force=(not self.execution_env == ExecutionEnv.SAGEMAKER)
)

def _merge_config_with_transformations(
@@ -406,7 +410,7 @@
"node_features": {
"node_type1": {
"feature_name1": {
"transformation": # transformation type
"transformation_name": # transformation name, e.g. "numerical"
# feature1 representation goes here
},
"feature_name2": {}, ...
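A hypothetical, filled-in example of the structure sketched in the docstring above, using the categorical representation added in this commit (node type, feature name, and category values are made up for illustration):

precomputed_transformations = {
    "node_features": {
        "node_type1": {
            "feature_name1": {
                # Name returned by get_transformation_name()
                "transformation_name": "...",
                "cols": ["feature_name1"],
                "string_indexer_labels_arrays": [["red", "green", "blue"]],
                "per_col_label_to_one_hot_idx": {
                    "feature_name1": {"red": 0, "green": 1, "blue": 2}
                },
            },
        },
    },
}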