Skip to content

Commit

Permalink
[GSProcessing] Add support for bucket transformations (#583)
Browse files Browse the repository at this point in the history
*Issue #, if available:*

*Description of changes:*

It is still under development, first to create the PR and add logging
information for feature developing use

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Theodore Vasiloudis <[email protected]>
  • Loading branch information
jalencato and thvasilo authored Nov 1, 2023
1 parent d1e8b9c commit 841718e
Show file tree
Hide file tree
Showing 10 changed files with 357 additions and 4 deletions.
19 changes: 19 additions & 0 deletions docs/source/gs-processing/developer/input-configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,26 @@ arguments.
- ``separator`` (String, optional): Same as for ``no-op`` transformation, used to separate numerical
values in CSV input. If the input data are in Parquet format, each value in the
column is assumed to be an array of floats.
- ``bucket-numerical``

- Transforms a numerical column to a one-hot or multi-hot bucket representation, using bucketization.
Also supports optional missing value imputation through the `imputer` kwarg.```
- ``kwargs``:

- ``imputer`` (String, optional): A method to fill in missing values in the data.
Valid values are:
``none`` (Default), ``mean``, ``median``, and ``most_frequent``. Missing values will be replaced
with the respective value computed from the data.
- ``range`` (List[float], required), The range defines the start and end point of the buckets with ``[a, b]``. It should be
a list of two floats. For example, ``[10, 30]`` defines a bucketing range between 10 and 30.
- ``bucket_cnt`` (Integer, required), The count of bucket lists used in the bucket feature transform. GSProcessing
calculates the size of each bucket as ``( b - a ) / c`` , and encodes each numeric value as the number
of whatever bucket it falls into. Any value less than a is considered to belong in the first bucket,
and any value greater than b is considered to belong in the last bucket.
- ``slide_window_size`` (Integer, optional), slide_window_size can be used to make numeric values fall into more than one bucket,
by specifying a slide-window size ``s``, where ``s`` can an integer or float. GSProcessing then transforms each
numeric value ``v`` of the property into a range from ``v - s/2`` through ``v + s/2`` , and assigns the value v
to every bucket that the range covers.
--------------

Examples
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,21 @@ def _convert_feature(feats: list[dict]) -> list[dict]:

if gconstruct_transform_dict["name"] == "max_min_norm":
gsp_transformation_dict["name"] = "numerical"
gsp_transformation_dict["kwargs"] = {"normalizer": "min-max", "imputer": "mean"}
gsp_transformation_dict["kwargs"] = {"normalizer": "min-max", "imputer": "none"}
elif gconstruct_transform_dict["name"] == "bucket_numerical":
gsp_transformation_dict["name"] = "bucket-numerical"
assert (
"bucket_cnt" in gconstruct_transform_dict
), "bucket_cnt should be in the gconstruct bucket feature transform field"
assert (
"range" in gconstruct_transform_dict
), "range should be in the gconstruct bucket feature transform field"
gsp_transformation_dict["kwargs"] = {
"bucket_cnt": gconstruct_transform_dict["bucket_cnt"],
"range": gconstruct_transform_dict["range"],
"slide_window_size": gconstruct_transform_dict["slide_window_size"],
"imputer": "none",
}
# TODO: Add support for other common transformations here
else:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
from graphstorm_processing.constants import SUPPORTED_FILE_TYPES
from .label_config_base import LabelConfig, EdgeLabelConfig, NodeLabelConfig
from .feature_config_base import FeatureConfig, NoopFeatureConfig
from .numerical_configs import MultiNumericalFeatureConfig, NumericalFeatureConfig
from .numerical_configs import (
BucketNumericalFeatureConfig,
MultiNumericalFeatureConfig,
NumericalFeatureConfig,
)
from .data_config_base import DataStorageConfig


Expand Down Expand Up @@ -56,6 +60,8 @@ def parse_feat_config(feature_dict: Dict) -> FeatureConfig:
return NumericalFeatureConfig(feature_dict)
elif transformation_name == "multi-numerical":
return MultiNumericalFeatureConfig(feature_dict)
elif transformation_name == "bucket-numerical":
return BucketNumericalFeatureConfig(feature_dict)
else:
raise RuntimeError(f"Unknown transformation name: '{transformation_name}'")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
limitations under the License.
"""
from typing import Mapping
import numbers

from graphstorm_processing.constants import VALID_IMPUTERS, VALID_NORMALIZERS
from .feature_config_base import FeatureConfig
Expand Down Expand Up @@ -92,3 +93,53 @@ def __init__(self, config: Mapping):
self.separator = self._transformation_kwargs.get("separator", None)

self._sanity_check()


class BucketNumericalFeatureConfig(FeatureConfig):
"""Feature configuration for bucket-numerical transformation.
Supported kwargs
----------------
imputer: str
A method to fill in missing values in the data. Valid values are:
"none" (Default), "mean", "median", and "most_frequent". Missing values will be replaced
with the respective value computed from the data.
bucket_cnt: int
The count of bucket lists used in the bucket feature transform. Each bucket will
have same length.
range: List[float]
The range of bucket lists only defining the start and end point. The range and
bucket_cnt will define the buckets together. For example, range = [10, 30] and
bucket_cnt = 2, will have two buckets: [10, 20] and [20, 30]
slide_window_size: float or none
Interval or range within which numeric values are grouped into buckets. Slide window
size will let one value possibly fall into multiple buckets.
"""

def __init__(self, config: Mapping):
super().__init__(config)
self.imputer = self._transformation_kwargs.get("imputer", "none")
self.bucket_cnt = self._transformation_kwargs.get("bucket_cnt", "none")
self.range = self._transformation_kwargs.get("range", "none")
self.slide_window_size = self._transformation_kwargs.get("slide_window_size", "none")
self._sanity_check()

def _sanity_check(self) -> None:
super()._sanity_check()
assert (
self.imputer in VALID_IMPUTERS
), f"Unknown imputer requested, expected one of {VALID_IMPUTERS}, got {self.imputer}"
assert isinstance(
self.bucket_cnt, int
), f"Expect bucket_cnt {self.bucket_cnt} be an integer"
assert (
isinstance(self.range, list)
and all(isinstance(x, numbers.Number) for x in self.range)
and len(self.range) == 2
), f"Expect range {self.range} be a list of two integers or two floats"
assert (
isinstance(self.slide_window_size, numbers.Number) or self.slide_window_size == "none"
), f"Expect no slide window size or expect {self.slide_window_size} is a number"
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
NoopTransformation,
DistNumericalTransformation,
DistMultiNumericalTransformation,
DistBucketNumericalTransformation,
)


Expand All @@ -48,6 +49,8 @@ def __init__(self, feature_config: FeatureConfig):
self.transformation = DistNumericalTransformation(**default_kwargs, **args_dict)
elif feat_type == "multi-numerical":
self.transformation = DistMultiNumericalTransformation(**default_kwargs, **args_dict)
elif feat_type == "bucket-numerical":
self.transformation = DistBucketNumericalTransformation(**default_kwargs, **args_dict)
else:
raise NotImplementedError(
f"Feature {feat_name} has type: {feat_type} that is not supported"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
DistMultiNumericalTransformation,
DistNumericalTransformation,
)
from .dist_bucket_numerical_transformation import DistBucketNumericalTransformation
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import List

from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, FloatType
import numpy as np

from .base_dist_transformation import DistributedTransformation
from .dist_numerical_transformation import apply_imputation


class DistBucketNumericalTransformation(DistributedTransformation):
"""Transformation to apply missing value imputation and bucket normalization
to a numerical input.
Parameters
----------
cols : Sequence[str]
The list of columns to apply the transformations on.
range: List[float]
The range of bucket lists only defining the start and end point.
bucket_cnt: int
The count of bucket lists used in the bucket feature transform.
slide_window_size: float
Interval or range within which numeric values are grouped into buckets.
imputer : str
The type of missing value imputation to apply to the column.
Valid values are "mean", "median" and "most_frequent".
"""

# pylint: disable=redefined-builtin
def __init__(
self,
cols: List[str],
range: List[float],
bucket_cnt: int,
slide_window_size: float = 0.0,
imputer: str = "none",
) -> None:
super().__init__(cols)
self.cols = cols
assert len(self.cols) == 1, "Bucket numerical transformation only supports single column"
self.range = range
self.bucket_count = bucket_cnt
self.slide_window_size = slide_window_size
# Spark uses 'mode' for the most frequent element
self.shared_imputation = "mode" if imputer == "most_frequent" else imputer

@staticmethod
def get_transformation_name() -> str:
return "DistBucketNumericalTransformation"

def apply(self, input_df: DataFrame) -> DataFrame:
imputed_df = apply_imputation(self.cols, self.shared_imputation, input_df)
# TODO: Make range optional by getting min/max from data.
min_val, max_val = self.range

bucket_size = (max_val - min_val) / self.bucket_count
epsilon = bucket_size / 10

# TODO: Test if pyspark.ml.feature.Bucketizer covers our requirements and is faster
def determine_bucket_membership(value: float) -> List[int]:
# Create value range, value -> [value - slide/2, value + slide/2]
high_val = value + self.slide_window_size / 2
low_val = value - self.slide_window_size / 2

# Early exits to avoid numpy calls
membership_list = [0.0] * self.bucket_count
if value >= max_val:
membership_list[-1] = 1.0
return membership_list
if value <= min_val:
membership_list[0] = 1.0
return membership_list

# Upper and lower threshold the value range
if low_val < min_val:
low_val = min_val
elif low_val >= max_val:
low_val = max_val - epsilon
if high_val < min_val:
high_val = min_val
elif high_val >= max_val:
high_val = max_val - epsilon

# Determine upper and lower bucket membership
low_val -= min_val
high_val -= min_val
low_idx = low_val // bucket_size
high_idx = (high_val // bucket_size) + 1

idx = np.arange(start=low_idx, stop=high_idx, dtype=int)
membership_array = np.zeros(self.bucket_count, dtype=float)
membership_array[idx] = 1.0

return membership_array.tolist()

# TODO: Try using a Pandas/Arrow UDF here and compare performance.
bucket_udf = F.udf(determine_bucket_membership, ArrayType(FloatType()))

bucketized_df = imputed_df.select(bucket_udf(F.col(self.cols[0])).alias(self.cols[0]))

return bucketized_df
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def determine_spark_feature_type(feature_type: str) -> Type[DataType]:
# TODO: Replace with pattern matching after moving to Python 3.10?
if feature_type in ["no-op", "multi-numerical"] or feature_type.startswith("text"):
return StringType
if feature_type in ["numerical", "none"]:
if feature_type in ["numerical", "bucket-numerical", "none"]:
return FloatType
else:
raise NotImplementedError(f"Unknown feature type: {feature_type}")
Expand Down
23 changes: 22 additions & 1 deletion graphstorm-processing/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,15 @@ def test_convert_gsprocessing(converter: GConstructConfigConverter):
"features": [
{"feature_col": ["citation_time"], "feature_name": "feat"},
{"feature_col": ["num_citations"], "transform": {"name": "max_min_norm"}},
{
"feature_col": ["num_citations"],
"transform": {
"name": "bucket_numerical",
"bucket_cnt": 9,
"range": [10, 100],
"slide_window_size": 5,
},
},
],
"labels": [
{"label_col": "label", "task_type": "classification", "split_pct": [0.8, 0.1, 0.1]}
Expand Down Expand Up @@ -249,7 +258,19 @@ def test_convert_gsprocessing(converter: GConstructConfigConverter):
"column": "num_citations",
"transformation": {
"name": "numerical",
"kwargs": {"normalizer": "min-max", "imputer": "mean"},
"kwargs": {"normalizer": "min-max", "imputer": "none"},
},
},
{
"column": "num_citations",
"transformation": {
"name": "bucket-numerical",
"kwargs": {
"bucket_cnt": 9,
"range": [10, 100],
"slide_window_size": 5,
"imputer": "none",
},
},
},
]
Expand Down
Loading

0 comments on commit 841718e

Please sign in to comment.