[GSProcessing] Add structure for saving transformation JSON files.
First implemented for the categorical transformation.
This commit only adds saving the categorical transformation as
a JSON representation.
thvasilo committed May 28, 2024
1 parent f2a3523 commit fa77d57
Showing 11 changed files with 370 additions and 130 deletions.
3 changes: 3 additions & 0 deletions graphstorm-processing/graphstorm_processing/constants.py
@@ -55,3 +55,6 @@
HUGGINGFACE_TRANFORM = "huggingface"
HUGGINGFACE_TOKENIZE = "tokenize_hf"
HUGGINGFACE_EMB = "embedding_hf"

########## Precomputed transformations ################
TRANSFORMATIONS_FILENAME = "precomputed_transformations.json"
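
As an illustration only (not part of this commit), a minimal sketch of how the new constant might be used to persist the collected JSON representations; the helper name, the transformations dict, and the output path are hypothetical:

import json
import os

from graphstorm_processing.constants import TRANSFORMATIONS_FILENAME

def save_precomputed_transformations(transformations: dict, output_path: str) -> None:
    # Write all per-feature JSON representations to the well-known filename,
    # so a later run can re-apply the same transformations.
    with open(os.path.join(output_path, TRANSFORMATIONS_FILENAME), "w", encoding="utf-8") as f:
        json.dump(transformations, f, indent=2)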
@@ -16,7 +16,7 @@

import logging

from pyspark.sql import DataFrame
from pyspark.sql import DataFrame, SparkSession

from graphstorm_processing.config.feature_config_base import FeatureConfig
from .dist_transformations import (
@@ -37,13 +37,19 @@ class DistFeatureTransformer(object):
which can then be applied through a call to apply_transformation.
"""

def __init__(self, feature_config: FeatureConfig):
def __init__(
self, feature_config: FeatureConfig, spark: SparkSession, json_representation: dict
):
feat_type = feature_config.feat_type
feat_name = feature_config.feat_name
args_dict = feature_config.transformation_kwargs
self.transformation: DistributedTransformation
# TODO: We will use this to re-apply transformations
self.json_representation = json_representation

default_kwargs = {"cols": feature_config.cols}
default_kwargs = {
"cols": feature_config.cols,
}
logging.info("Feature name: %s", feat_name)
logging.info("Transformation type: %s", feat_type)

@@ -56,7 +62,9 @@ def __init__(self, feature_config: FeatureConfig):
elif feat_type == "bucket-numerical":
self.transformation = DistBucketNumericalTransformation(**default_kwargs, **args_dict)
elif feat_type == "categorical":
self.transformation = DistCategoryTransformation(**default_kwargs, **args_dict)
self.transformation = DistCategoryTransformation(
**default_kwargs, **args_dict, spark=spark
)
elif feat_type == "multi-categorical":
self.transformation = DistMultiCategoryTransformation(**default_kwargs, **args_dict)
elif feat_type == "huggingface":
@@ -66,14 +74,24 @@ def __init__(self, feature_config: FeatureConfig):
f"Feature {feat_name} has type: {feat_type} that is not supported"
)

def apply_transformation(self, input_df: DataFrame) -> DataFrame:
def apply_transformation(self, input_df: DataFrame) -> tuple[DataFrame, dict]:
"""
Given an input dataframe, select only the relevant columns
Given an input DataFrame, select only the relevant columns
and apply the expected transformation to them.
Returns
-------
tuple[DataFrame, dict]
A tuple with two elements: the first is the transformed input DataFrame,
the second is a JSON representation of the transformation. This will
allow us to apply the same transformation to new data.
"""
input_df = input_df.select(self.transformation.cols) # type: ignore

return self.transformation.apply(input_df)
return (
self.transformation.apply(input_df),
self.transformation.get_json_representation(),
)
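
As a caller-side usage sketch (spark, feature_config, and nodes_df are assumed to exist; only the new signature above comes from this commit), the transformer now returns both the transformed DataFrame and the JSON representation, which can be collected for later saving:

# Hypothetical caller; an empty dict means no precomputed transformation is re-applied.
transformer = DistFeatureTransformer(feature_config, spark, json_representation={})
transformed_df, transformation_json = transformer.apply_transformation(nodes_df)

# Collect the representation keyed by feature name so it can later be written
# out under TRANSFORMATIONS_FILENAME once all features are processed.
precomputed_transformations = {feature_config.feat_name: transformation_json}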

def get_transformation_name(self) -> str:
"""
@@ -15,29 +15,44 @@
"""

from abc import ABC, abstractmethod
from typing import Sequence
from typing import Optional, Sequence

from pyspark.sql import DataFrame
from pyspark.sql import DataFrame, SparkSession


class DistributedTransformation(ABC):
"""
Base class for all distributed transformations.
"""

def __init__(self, cols: Sequence[str]) -> None:
def __init__(
self,
cols: Sequence[str],
spark: Optional[SparkSession] = None,
json_representation: Optional[dict] = None,
) -> None:
self.cols = cols
self.spark = spark
self.json_representation = json_representation

@abstractmethod
def apply(self, input_df: DataFrame) -> DataFrame:
"""
Applies the transformation to the input DataFrame.
The returned dataframe will only contain the columns specified during initialization.
Applies the transformation to the input DataFrame, and returns the modified
DataFrame.
The returned DataFrame will only contain the columns specified during initialization.
"""

def get_json_representation(self) -> dict:
"""Get a JSON representation of the transformation."""
# TODO: Should we try to guarantee apply() has run before this?
if self.json_representation:
return self.json_representation
else:
return {}

@staticmethod
@abstractmethod
def get_transformation_name() -> str:
"""
Get the name of the transformation
"""
"""Get the name of the transformation."""
@@ -14,11 +14,12 @@
limitations under the License.
"""

from typing import Dict, List, Optional, Sequence
from typing import List, Optional, Sequence

import numpy as np
import pandas as pd

from pyspark.sql import DataFrame, functions as F
from pyspark.sql import DataFrame, functions as F, SparkSession
from pyspark.sql.functions import when
from pyspark.sql.types import ArrayType, FloatType, StringType
from pyspark.ml.feature import StringIndexer, OneHotEncoder
@@ -40,18 +41,19 @@ class DistCategoryTransformation(DistributedTransformation):
Transforms categorical features into a vector of one-hot-encoded values.
"""

def __init__(self, cols: List[str]) -> None:
super().__init__(cols)
def __init__(self, cols: List[str], spark: SparkSession) -> None:
super().__init__(cols, spark)

@staticmethod
def get_transformation_name() -> str:
return "DistCategoryTransformation"

def apply(self, input_df: DataFrame) -> DataFrame:
processed_col_names = []
for col in self.cols:
processed_col_names.append(col + "_processed")
distinct_category_counts = input_df.groupBy(col).count() # type: DataFrame
top_categories_per_col: dict[str, list] = {}
for current_col in self.cols:
processed_col_names.append(current_col + "_processed")
distinct_category_counts = input_df.groupBy(current_col).count() # type: DataFrame
num_distinct_categories = distinct_category_counts.count()

# Conditionally replace rare categories with single placeholder
@@ -60,17 +62,23 @@ def apply(self, input_df: DataFrame) -> DataFrame:
MAX_CATEGORIES_PER_FEATURE - 1
)
top_categories_set = {row[0] for row in top_categories}
top_categories_per_col[current_col] = list(top_categories_set)
# TODO: Ideally we don't want to use withColumn in a loop
input_df = input_df.withColumn(
col,
when(input_df[col].isin(top_categories_set), input_df[col]).otherwise(
RARE_CATEGORY
),
current_col,
when(
input_df[current_col].isin(top_categories_set), input_df[current_col]
).otherwise(RARE_CATEGORY),
)
else:
top_categories_per_col[current_col] = [
x[current_col] for x in distinct_category_counts.select(current_col).collect()
]

# Replace empty string cols with None
input_df = input_df.withColumn(
col, when(input_df[col] == "", None).otherwise(input_df[col])
current_col,
when(input_df[current_col] == "", None).otherwise(input_df[current_col]),
)

# We first convert the strings to float indexes
@@ -105,8 +113,74 @@ def apply(self, input_df: DataFrame) -> DataFrame:
]
)

# Structure: {column_name: {category_string: index_value, ...}. ...}
per_col_label_to_one_hot_idx: dict[str, dict[str, int]] = {}

# To get the transformed values for each value in each col
# we need to create a DataFrame with the top categories for the current
# col, then fill in the rest of the values with placeholders
# and pass the generated DF through the one-hot encoder
for current_col, processed_col in zip(self.cols, processed_col_names):
other_cols = [x for x in self.cols if x != current_col]
top_str_categories_list = top_categories_per_col[current_col]
# Spark doesn't model missing values; the all-zeroes vector is used instead
top_str_categories_list.remove(None)
print(top_str_categories_list)
# Each col might have different number of top categories, we need one DF per col
num_current_col_cats = len(top_str_categories_list)
# We don't care about values for the other cols in this iteration,
# just fill with empty string
placeholder_vals = [""] * num_current_col_cats
placeholder_cols = [placeholder_vals for _ in range(len(self.cols) - 1)]
current_col_unique_vals = [list(top_str_categories_list)]
# We need to create a DF where all cols have num_rows == num_current_col_cats
# and the current col needs to be the first col in the DF.
vals_dict = dict(
zip([current_col] + other_cols, current_col_unique_vals + placeholder_cols)
)

# One hot encoder expects a DF with all cols that were used to train it
# so we use the top-MAX_CATEGORIES_PER_FEATURE values for the current col,
# and the placeholders for the rest
top_str_categories_df = self.spark.createDataFrame(pd.DataFrame(vals_dict))
top_indexed_categories_df = str_indexer_model.transform(top_str_categories_df)

# For the current col, get the one-hot index for each of its category strings
# by passing the top-k values DF through the one-hot encoder model
per_col_label_to_one_hot_idx[current_col] = {
x[current_col]: int(x[processed_col])
for x in one_hot_encoder_model.transform(top_indexed_categories_df).collect()
}

# see get_json_representation() docstring for structure
self.json_representation = {
"string_indexer_labels_array": str_indexer_model.labelsArray,
"cols": self.cols,
"per_col_label_to_one_hot_idx": per_col_label_to_one_hot_idx,
}

return dense_vector_features

def get_json_representation(self) -> dict:
"""Representation of the single-category transformation for one or more columns.
Returns
-------
dict
Structure:
string_indexer_labels_array:
tuple[tuple[str]], outer tuple has num_cols elements,
each inner tuple has num_cats elements, each str is a category string.
Spark uses this to represent the one-hot index for each category, its
position in the inner tuple is the one-hot-index position for the string.
Categories are sorted by their frequency in the data.
cols:
list[str], with num_cols elements
per_col_label_to_one_hot_idx:
dict[str, dict[str, int]], with num_cols elements, each with num_categories elements
"""
return self.json_representation
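
To make the documented structure concrete, a hypothetical value returned for a single "color" column with three observed categories could look like the following (illustrative values only, not produced by this commit):

example_representation = {
    # One inner tuple per column, with categories ordered by frequency (Spark's labelsArray).
    "string_indexer_labels_array": (("red", "blue", "green"),),
    "cols": ["color"],
    # For each column, the one-hot index assigned to each category string.
    "per_col_label_to_one_hot_idx": {"color": {"red": 0, "blue": 1, "green": 2}},
}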


class DistMultiCategoryTransformation(DistributedTransformation):
"""
@@ -63,9 +63,11 @@

from graphstorm_processing.graph_loaders.dist_heterogeneous_loader import (
DistHeterogeneousGraphLoader,
HeterogeneousLoaderConfig,
)
from graphstorm_processing.config.config_parser import create_config_objects
from graphstorm_processing.config.config_conversion import GConstructConfigConverter
from graphstorm_processing.constants import TRANSFORMATIONS_FILENAME
from graphstorm_processing.data_transformations import spark_utils, s3_utils
from graphstorm_processing.repartition_files import (
repartition_files,
@@ -183,6 +185,17 @@ def __init__(
with open(graph_conf, "r", encoding="utf-8") as f:
dataset_config_dict: Dict[str, Any] = json.load(f)

# Load the pre-computed transformations if the file exists
if os.path.exists(os.path.join(self.local_config_path, TRANSFORMATIONS_FILENAME)):
with open(
os.path.join(self.local_config_path, TRANSFORMATIONS_FILENAME),
"r",
encoding="utf-8",
) as f:
self.precomputed_transformations = json.load(f)
else:
self.precomputed_transformations = {}

if "version" in dataset_config_dict:
config_version = dataset_config_dict["version"]
if config_version == "gsprocessing-v1.0":
@@ -247,18 +260,22 @@ def run(self) -> None:
data_configs = create_config_objects(self.graph_config_dict)

t0 = time.time()
# Prefer explicit arguments for clarity
loader = DistHeterogeneousGraphLoader(
spark=self.spark,
local_input_path=self.local_config_path,
local_output_path=self.local_output_path,
data_configs=data_configs,
input_prefix=self.input_prefix,
output_prefix=self.output_prefix,
num_output_files=self.num_output_files,
loader_config = HeterogeneousLoaderConfig(
add_reverse_edges=self.add_reverse_edges,
data_configs=data_configs,
enable_assertions=False,
graph_name=self.graph_name,
input_prefix=self.input_prefix,
local_input_path=self.local_config_path,
local_output_path=self.local_output_path,
num_output_files=self.num_output_files,
output_prefix=self.output_prefix,
precomputed_transformations=self.precomputed_transformations,
)
# Prefer explicit arguments for clarity
loader = DistHeterogeneousGraphLoader(
self.spark,
loader_config,
)
graph_meta_dict = loader.load()

@@ -439,6 +456,15 @@ def main():
f"s3://{input_bucket}/{input_s3_prefix}/"
f"{gsprocessing_args.config_filename}"
) from e
# Try to download the pre-computed transformations file, if it exists
try:
s3.download_file(
input_bucket,
f"{input_s3_prefix}/{TRANSFORMATIONS_FILENAME}",
os.path.join(tempdir.name, TRANSFORMATIONS_FILENAME),
)
except botocore.exceptions.ClientError as _:
pass
local_config_path = tempdir.name

# local output location for metadata files
