[GSProcessing] Add pre-computed categorical transformation loading (#870)

*Issue #, if available:*

*Description of changes:*

* Follow-up to #857 
* Allow us to re-apply a previously saved categorical transformation to
new data. See below for design details.

To be able to re-apply the categorical transformations that we create
using the code in #857, we first create a mapping from each original string
to its one-hot representation, read from the saved JSON file, then
use a UDF to apply the mapping(s) to the column(s).
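
As a rough illustration of the mechanism (a minimal sketch, not the exact code in this PR), the snippet below re-applies a saved string-to-one-hot mapping to a single column with a PySpark UDF. The mapping and column name are hypothetical; unseen values fall back to an all-zeroes vector, as in the real implementation.

```python
# Minimal sketch (hypothetical mapping and column) of re-applying a saved
# string -> one-hot mapping with a PySpark UDF; the actual logic lives in
# DistCategoryTransformation.apply_precomputed_transformation.
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import ArrayType, IntegerType

spark = SparkSession.builder.getOrCreate()

# Hypothetical mapping loaded from precomputed_transformations.json:
# original string -> index of the "hot" position
label_to_idx = {"red": 0, "blue": 1, "green": 2}
vec_size = len(label_to_idx)

# Pre-build the one-hot vector for every known string
string_to_vec = {}
for label, idx in label_to_idx.items():
    one_hot = [0] * vec_size
    one_hot[idx] = 1
    string_to_vec[label] = one_hot

# Unknown/missing strings map to the all-zeroes ("missing") vector
to_one_hot = F.udf(
    lambda val: string_to_vec.get(val, [0] * vec_size), ArrayType(IntegerType())
)

df = spark.createDataFrame([("red",), ("blue",), ("never-seen",)], ["color"])
df.select(to_one_hot("color").alias("color")).show(truncate=False)
```

The real implementation builds one such mapping per configured column and chains the UDF applications across the columns.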

The `DistributedTransformation` class from which all transformation
implementations inherit, gains a new function,
`apply_precomputed_transformation`. When a pre-computed transformation
JSON file exists in the input, and the feature is one of those listed in
that file, we use this function to re-apply the existing transformation
instead of creating a new one.

The default implementation for `apply_precomputed_transformation` is to
log a warning and apply a new transformation.

When we implement a pre-computed transform for a new transformation
(e.g. numerical) we need to (see the illustrative sketch after this list):
* Ensure the transformation's `self.json_representation` is
populated during the call to `apply()`. This ensures the transformation
info will be saved in the output JSON.
* Override the `apply_precomputed_transformation` function (as we did
for `DistCategoryTransformation` here), so that it uses the dict loaded
from the JSON file to re-apply the transformation to the new data.
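
The sketch below shows what this could look like for a hypothetical min-max numerical transformation. It is not part of this PR, does not inherit from the real `DistributedTransformation` base class, and the `json_representation` field names are made up; it only illustrates the pattern of populating the representation in `apply()` and consuming it in `apply_precomputed_transformation()`.

```python
# Hypothetical sketch: pre-computed support for a min-max numerical transformation.
# Field names inside json_representation are illustrative, not GSProcessing's schema.
from typing import Optional, Sequence

from pyspark.sql import DataFrame, functions as F


class SketchMinMaxTransformation:
    def __init__(
        self, cols: Sequence[str], json_representation: Optional[dict] = None
    ) -> None:
        self.cols = list(cols)
        self.json_representation = json_representation or {}

    def apply(self, input_df: DataFrame) -> DataFrame:
        """Scale each column to [0, 1] and record the bounds for later re-use."""
        per_col_bounds = {}
        for col in self.cols:
            row = input_df.agg(F.min(col).alias("lo"), F.max(col).alias("hi")).first()
            per_col_bounds[col] = {"lo": row["lo"], "hi": row["hi"]}
            input_df = input_df.withColumn(
                col, (F.col(col) - row["lo"]) / (row["hi"] - row["lo"])
            )
        # Populate the representation so it ends up in the output JSON
        self.json_representation = {
            "per_col_bounds": per_col_bounds,
            "transformation_name": "min_max_numerical",
        }
        return input_df

    def apply_precomputed_transformation(self, input_df: DataFrame) -> DataFrame:
        """Re-apply the transformation using the saved bounds instead of recomputing."""
        for col, bounds in self.json_representation["per_col_bounds"].items():
            input_df = input_df.withColumn(
                col, (F.col(col) - bounds["lo"]) / (bounds["hi"] - bounds["lo"])
            )
        return input_df
```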

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
thvasilo authored Jun 17, 2024
1 parent dbf82fa commit 5199149
Showing 12 changed files with 414 additions and 98 deletions.
16 changes: 16 additions & 0 deletions docs/source/gs-processing/gs-processing-getting-started.rst
@@ -161,6 +161,22 @@ GSProcessing supports both the GConstruct JSON configuration format,
as well as its own GSProcessing config. You can learn about the
GSProcessing JSON configuration in :doc:`developer/input-configuration`.

Re-applying feature transformations to new data
-----------------------------------------------

Often you will process your data at training time and run inference at later
dates. If your data changes in the meantime, e.g. new values appear in a
categorical feature, you'll need to ensure no new values appear in the transformed
data, as the trained model relies on pre-existing values only.

To achieve that, GSProcessing creates an additional file in the output,
named ``precomputed_transformations.json``. To apply the same transformations
to new data, copy this file to the top-level path of your
new input data, and GSProcessing will pick up any transformations listed there,
ensuring the produced data match those used to train your model.

Currently, we only support re-applying transformations for categorical features.
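
For example, the copy step can be as simple as the following sketch (paths are
hypothetical; on S3 the same layout applies under your input prefix):

.. code-block:: python

    # Place the saved transformations next to the new raw data and its
    # GSProcessing/GConstruct configuration JSON (paths are hypothetical)
    import shutil

    shutil.copy(
        "/tmp/gsprocessing-example/precomputed_transformations.json",
        "./new-input-data/precomputed_transformations.json",
    )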


Developer guide
---------------
36 changes: 27 additions & 9 deletions docs/source/gs-processing/usage/example.rst
@@ -173,7 +173,7 @@ we can use the following command to run the processing job locally:

.. code-block:: bash
gs-processing --config-filename gconstruct-config.json \
gs-processing --config-filename gsprocessing-config.json \
--input-prefix ./tests/resources/small_heterogeneous_graph \
--output-prefix /tmp/gsprocessing-example/ \
--do-repartition True
@@ -211,26 +211,44 @@ and can be used downstream to create a partitioned graph.
.. code-block:: bash
$ cd /tmp/gsprocessing-example
$ ls
edges/ launch_arguments.json metadata.json node_data/
raw_id_mappings/ perf_counters.json updated_row_counts_metadata.json
$ ls -l
edges/
gsprocessing-config_with_transformations.json
launch_arguments.json
metadata.json
node_data/
perf_counters.json
precomputed_transformations.json
raw_id_mappings/
updated_row_counts_metadata.json
We have a few JSON files and the data directories containing
the graph structure, features, and labels. In more detail:

* ``gsprocessing-config_with_transformations.json``: This is the input configuration
we used, modified to include representations of any supported transformations
we applied. This file can be used to re-apply the transformations on new data.
* ``launch_arguments.json``: Contains the arguments that were used
to launch the processing job, allowing you to check the parameters after the
job finishes.
* ``updated_row_counts_metadata.json``:
This file is meant to be used as the input configuration for the
distributed partitioning pipeline. ``gs-repartition`` produces
this file using the original ``metadata.json`` file as input.
* ``metadata.json``: Created by ``gs-processing`` and used as input
for ``gs-repartition``; it can be removed after the ``gs-repartition`` step.
* ``perf_counters.json``: A JSON file that contains runtime measurements
for the various components of GSProcessing. Can be used to profile the
application and discover bottlenecks.
* ``precomputed_transformations.json``: A JSON file that contains representations
of supported transformations. To re-use these transformations on another dataset,
place this file in the top level of another set of raw data, at the same level
as the input GSProcessing/GConstruct configuration JSON file.
GSProcessing will use the transformation values listed here
instead of creating new ones, ensuring that models trained with the original
data can still be used with the newly transformed data. Currently only
categorical transformations can be re-applied.
* ``updated_row_counts_metadata.json``:
This file is meant to be used as the input configuration for the
distributed partitioning pipeline. ``gs-repartition`` produces
this file using the original ``metadata.json`` file as input.

The directories created contain:

@@ -44,7 +44,7 @@ def __init__(
feat_name = feature_config.feat_name
args_dict = feature_config.transformation_kwargs
self.transformation: DistributedTransformation
# TODO: We will use this to re-apply transformations
# We use this to re-apply transformations
self.json_representation = json_representation

default_kwargs = {
@@ -63,7 +63,7 @@ def __init__(
self.transformation = DistBucketNumericalTransformation(**default_kwargs, **args_dict)
elif feat_type == "categorical":
self.transformation = DistCategoryTransformation(
**default_kwargs, **args_dict, spark=spark
**default_kwargs, **args_dict, spark=spark, json_representation=json_representation
)
elif feat_type == "multi-categorical":
self.transformation = DistMultiCategoryTransformation(**default_kwargs, **args_dict)
@@ -88,10 +88,17 @@ def apply_transformation(self, input_df: DataFrame) -> tuple[DataFrame, dict]:
"""
input_df = input_df.select(self.transformation.cols) # type: ignore

return (
self.transformation.apply(input_df),
self.transformation.get_json_representation(),
)
if self.json_representation:
logging.info("Applying precomputed transformation...")
return (
self.transformation.apply_precomputed_transformation(input_df),
self.json_representation,
)
else:
return (
self.transformation.apply(input_df),
self.transformation.get_json_representation(),
)

def get_transformation_name(self) -> str:
"""
@@ -23,6 +23,15 @@
class DistributedTransformation(ABC):
"""
Base class for all distributed transformations.
Parameters
----------
cols : Sequence[str]
Column names to which we will apply the transformation
spark : Optional[SparkSession], optional
Optional SparkSession if needed by the underlying implementation, by default None
json_representation : Optional[dict], optional
Pre-computed transformation representation to use, by default None
"""

def __init__(
@@ -52,6 +61,27 @@ def get_json_representation(self) -> dict:
else:
return {}

def apply_precomputed_transformation(self, input_df: DataFrame) -> DataFrame:
"""Applies a transformation using pre-computed representation.
Parameters
----------
input_df : DataFrame
Input DataFrame to apply the transformation to.
Returns
-------
DataFrame
The input DataFrame, modified according to the pre-computed transformation values.
Raises
------
NotImplementedError
If the underlying transformation does not support re-applying using JSON input.
"""
raise NotImplementedError(
f"Pre-computed transformation not available for {self.get_transformation_name()}"
)

@staticmethod
@abstractmethod
def get_transformation_name() -> str:
@@ -14,17 +14,20 @@
limitations under the License.
"""

from collections import defaultdict
from typing import List, Optional, Sequence
from functools import partial

import numpy as np
import pandas as pd

from pyspark.sql import DataFrame, functions as F, SparkSession
from pyspark.sql.functions import when
from pyspark.sql.types import ArrayType, FloatType, StringType
from pyspark.ml.feature import StringIndexer, OneHotEncoder
from pyspark.ml.functions import vector_to_array
from pyspark.ml.linalg import Vectors
from pyspark.sql import DataFrame, functions as F, SparkSession
from pyspark.sql.functions import when
from pyspark.sql.types import ArrayType, FloatType, StringType
from pyspark.sql.types import IntegerType

from graphstorm_processing.constants import (
MAX_CATEGORIES_PER_FEATURE,
@@ -41,8 +44,12 @@ class DistCategoryTransformation(DistributedTransformation):
Transforms categorical features into a vector of one-hot-encoded values.
"""

def __init__(self, cols: list[str], spark: SparkSession) -> None:
super().__init__(cols, spark)
def __init__(
self, cols: list[str], spark: SparkSession, json_representation: Optional[dict] = None
) -> None:
if not json_representation:
json_representation = {}
super().__init__(cols, spark, json_representation)

@staticmethod
def get_transformation_name() -> str:
@@ -51,6 +58,7 @@ def get_transformation_name() -> str:
def apply(self, input_df: DataFrame) -> DataFrame:
processed_col_names = []
top_categories_per_col: dict[str, list] = {}

for current_col in self.cols:
processed_col_names.append(current_col + "_processed")
distinct_category_counts = input_df.groupBy(current_col).count() # type: DataFrame
@@ -157,14 +165,95 @@ def apply(self, input_df: DataFrame) -> DataFrame:

# see get_json_representation() docstring for structure
self.json_representation = {
"string_indexer_labels_array": str_indexer_model.labelsArray,
"string_indexer_labels_arrays": str_indexer_model.labelsArray,
"cols": self.cols,
"per_col_label_to_one_hot_idx": per_col_label_to_one_hot_idx,
"transformation_name": self.get_transformation_name(),
}

return dense_vector_features

def apply_precomputed_transformation(self, input_df: DataFrame) -> DataFrame:

# List of StringIndexerModel labelsArray lists, each one containing the strings
# for one column. See docs for pyspark.ml.feature.StringIndexerModel.labelsArray
labels_arrays: list[list[str]] = self.json_representation["string_indexer_labels_arrays"]
# More verbose representation of the mapping from string to one hot index location,
# for each column in the input.
per_col_label_to_one_hot_idx: dict[str, dict[str, int]] = self.json_representation[
"per_col_label_to_one_hot_idx"
]
# The list of cols the transformation was originally applied to.
precomputed_cols: list[str] = self.json_representation["cols"]

# Assertions to ensure correctness of representation
assert set(precomputed_cols) == set(self.cols), (
f"Mismatched columns in precomputed transformation: "
f"pre-computed cols: {sorted(precomputed_cols)}, "
f"columns in current config: {sorted(self.cols)}, "
f"different items: {set(precomputed_cols).symmetric_difference(set(self.cols))}"
)
for col_labels, col in zip(labels_arrays, precomputed_cols):
for idx, label in enumerate(col_labels):
assert idx == per_col_label_to_one_hot_idx[col][label], (
"Mismatch between Spark labelsArray and pre-computed array index "
f"for col {col}, string: {label}, "
f"{idx} != {per_col_label_to_one_hot_idx[col][label]}"
)

# For each column in the transformation, we create a defaultdict
# with each unique value as keys, and the one-hot vector encoding
# of the value as value. Values not in the dict get the all zeroes (missing)
# vector
# Do this for each column in the transformation and return the resulting DF

# We need to define these outside the loop to avoid
# https://pylint.readthedocs.io/en/latest/user_guide/messages/warning/cell-var-from-loop.html
def replace_col_in_row(val: str, str_to_vec: dict):
return str_to_vec[val]

def create_zeroes_list(vec_size: int):
return [0] * vec_size

transformed_df = None
already_transformed_cols = []
remaining_cols = list(self.cols)

for col_idx, current_col in enumerate(precomputed_cols):
vector_size = len(labels_arrays[col_idx])
# Mapping from string to one-hot vector,
# with all-zeroes default for unknown/missing values
string_to_vector = defaultdict(partial(create_zeroes_list, vector_size))

string_to_one_hot_idx = per_col_label_to_one_hot_idx[current_col]

# Populate the one-hot vectors for known strings
for string_val, one_hot_idx in string_to_one_hot_idx.items():
one_hot_vec = [0] * vector_size
one_hot_vec[one_hot_idx] = 1
string_to_vector[string_val] = one_hot_vec

# UDF that replaces string values with their one-hot encoding (ohe)
replace_cur_col = partial(replace_col_in_row, str_to_vec=string_to_vector)
replace_cur_col_udf = F.udf(replace_cur_col, ArrayType(IntegerType()))

partial_df = transformed_df if transformed_df else input_df

transformed_col = f"{current_col}_ohe"
remaining_cols.remove(current_col)
# We maintain only the already transformed cols, and the ones yet to be transformed
transformed_df = partial_df.select(
replace_cur_col_udf(F.col(current_col)).alias(transformed_col),
*remaining_cols,
*already_transformed_cols,
).drop(current_col)
already_transformed_cols.append(transformed_col)

assert transformed_df
transformed_df = transformed_df.select(*already_transformed_cols).toDF(*self.cols)

return transformed_df

def get_json_representation(self) -> dict:
"""Representation of the single-category transformation for one or more columns.
48 changes: 26 additions & 22 deletions graphstorm-processing/graphstorm_processing/distributed_executor.py
@@ -242,6 +242,25 @@ def __init__(
# Create the Spark session for execution
self.spark = spark_utils.create_spark_session(self.execution_env, self.filesystem_type)

# Initialize the graph loader
data_configs = create_config_objects(self.gsp_config_dict)
loader_config = HeterogeneousLoaderConfig(
add_reverse_edges=self.add_reverse_edges,
data_configs=data_configs,
enable_assertions=False,
graph_name=self.graph_name,
input_prefix=self.input_prefix,
local_input_path=self.local_config_path,
local_metadata_output_path=self.local_metadata_output_path,
num_output_files=self.num_output_files,
output_prefix=self.output_prefix,
precomputed_transformations=self.precomputed_transformations,
)
self.loader = DistHeterogeneousGraphLoader(
self.spark,
loader_config,
)

def _upload_output_files(self, loader: DistHeterogeneousGraphLoader, force=False):
"""Upload output files to S3
@@ -273,27 +292,10 @@ def run(self) -> None:
Executes the Spark processing job.
"""
logging.info("Performing data processing with PySpark...")
data_configs = create_config_objects(self.gsp_config_dict)

t0 = time.time()
# Prefer explicit arguments for clarity
loader_config = HeterogeneousLoaderConfig(
add_reverse_edges=self.add_reverse_edges,
data_configs=data_configs,
enable_assertions=False,
graph_name=self.graph_name,
input_prefix=self.input_prefix,
local_input_path=self.local_config_path,
local_metadata_output_path=self.local_metadata_output_path,
num_output_files=self.num_output_files,
output_prefix=self.output_prefix,
precomputed_transformations=self.precomputed_transformations,
)
loader = DistHeterogeneousGraphLoader(
self.spark,
loader_config,
)
processed_representations: ProcessedGraphRepresentation = loader.load()

processed_representations: ProcessedGraphRepresentation = self.loader.load()
graph_meta_dict = processed_representations.processed_graph_metadata_dict

t1 = time.time()
@@ -343,7 +345,9 @@ def run(self) -> None:

# If any of the metadata modification took place, write an updated metadata file
if updated_metadata:
updated_meta_path = os.path.join(loader.output_path, "updated_row_counts_metadata.json")
updated_meta_path = os.path.join(
self.loader.output_path, "updated_row_counts_metadata.json"
)
with open(
updated_meta_path,
"w",
@@ -384,7 +388,7 @@ def run(self) -> None:
# since we can't rely on SageMaker to do it
if self.filesystem_type == FilesystemType.S3:
self._upload_output_files(
loader, force=(not self.execution_env == ExecutionEnv.SAGEMAKER)
self.loader, force=(not self.execution_env == ExecutionEnv.SAGEMAKER)
)

def _merge_config_with_transformations(
Expand All @@ -406,7 +410,7 @@ def _merge_config_with_transformations(
"node_features": {
"node_type1": {
"feature_name1": {
"transformation": # transformation type
"transformation_name": # transformation name, e.g. "numerical"
# feature1 representation goes here
},
"feature_name2": {}, ...

