[GSProcessing] Categorical Feature Transformation (#623)
*Issue #, if available:*

*Description of changes:*


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Theodore Vasiloudis <[email protected]>
jalencato and thvasilo authored Nov 9, 2023
1 parent ea18f9a commit 82f8e20
Showing 9 changed files with 204 additions and 14 deletions.
19 changes: 19 additions & 0 deletions docs/source/gs-processing/developer/input-configuration.rst
@@ -426,6 +426,25 @@ arguments.
by specifying a sliding-window size ``s``, where ``s`` can be an integer or float. GSProcessing then transforms each
numeric value ``v`` of the property into the range ``v - s/2`` through ``v + s/2``, and assigns the value ``v``
to every bucket that the range covers.

- ``categorical``

- Transforms values from a fixed list of possible values (categorical features) to a one-hot encoding.
The length of the resulting vector is the number of categories in the data, with a 1 at the index
of the value's category and 0 everywhere else. For example, with categories ``a``, ``b``, ``c``, the
value ``b`` becomes ``[0, 1, 0]``.

.. note::
The maximum number of categories in any categorical feature is 100. If a property has more than 100 distinct values,
only the 99 most common ones are kept as distinct categories, and all remaining values are placed in a special category named OTHER.

- ``multi-categorical``

- Encodes vector-like data from a fixed list of possible values (i.e. multi-label/multi-categorical data) using a multi-hot encoding. The length of the resulting vector is the number of categories in the data, with a 1 at the index of every category that appears in the input and 0 everywhere else. A configuration sketch for both transformations follows below.
- ``kwargs``:
- ``separator`` (String, optional): As in the No-op operation, the separator is used to
split multiple input values for CSV files, e.g. ``detective|noir``. If it is not provided, the whole
value is treated as a single-element array. For Parquet files, if the input type is ``ArrayType(StringType())``,
the separator is ignored; if it is ``StringType()``, the same splitting logic as for CSV input applies.
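For reference, a sketch of feature entries using these two transformations in the input configuration; the column names are hypothetical, and the shapes mirror the converted configurations asserted in test_converter.py below.

categorical_feature = {
    "column": "genre",  # hypothetical column name
    "transformation": {"name": "categorical", "kwargs": {}},
}

multi_categorical_feature = {
    "column": "genres",  # hypothetical column name
    "transformation": {
        "name": "multi-categorical",
        # e.g. splits "detective|noir" into two categories
        "kwargs": {"separator": "|"},
    },
}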
--------------

Examples
37 changes: 37 additions & 0 deletions graphstorm-processing/graphstorm_processing/config/categorical_configs.py
@@ -0,0 +1,37 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Mapping
from .feature_config_base import FeatureConfig


class MultiCategoricalFeatureConfig(FeatureConfig):
"""Feature configuration for multi-column categorical features.
Supported kwargs
----------------
separator: str, optional
A separator to use when splitting a delimited string into multiple numerical values
as a vector. Only applicable to CSV input. Example: for a separator `'|'` the CSV
value `1|2|3` would be transformed to a vector, `[1, 2, 3]`. When `None` the expected
input format is an array of string values.
"""

def __init__(self, config: Mapping):
super().__init__(config)
self.separator = self._transformation_kwargs.get("separator", None)

self._sanity_check()
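A minimal usage sketch, assuming the feature dict shape produced by the converter (see test_converter.py below); the column name is hypothetical.

feature_dict = {
    "column": "genres",  # hypothetical column name
    "transformation": {"name": "multi-categorical", "kwargs": {"separator": "|"}},
}
config = MultiCategoricalFeatureConfig(feature_dict)
assert config.separator == "|"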
@@ -125,6 +125,15 @@ def _convert_feature(feats: list[dict]) -> list[dict]:
"normalizer": "rank-gauss",
"imputer": "none",
}
elif gconstruct_transform_dict["name"] == "to_categorical":
if "separator" in gconstruct_transform_dict:
gsp_transformation_dict["name"] = "multi-categorical"
gsp_transformation_dict["kwargs"] = {
"separator": gconstruct_transform_dict["separator"]
}
else:
gsp_transformation_dict["name"] = "categorical"
gsp_transformation_dict["kwargs"] = {}
# TODO: Add support for other common transformations here
else:
raise ValueError(
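To illustrate the branch above, a sketch of the mapping, with input/output shapes taken from test_converter.py below:

gconstruct_with_sep = {"name": "to_categorical", "separator": ","}
# ...converts to the GSProcessing form:
gsp_multi = {"name": "multi-categorical", "kwargs": {"separator": ","}}

gconstruct_no_sep = {"name": "to_categorical"}
# ...converts to:
gsp_single = {"name": "categorical", "kwargs": {}}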
@@ -26,6 +26,7 @@
MultiNumericalFeatureConfig,
NumericalFeatureConfig,
)
from .categorical_configs import MultiCategoricalFeatureConfig
from .data_config_base import DataStorageConfig


@@ -62,6 +63,10 @@ def parse_feat_config(feature_dict: Dict) -> FeatureConfig:
return MultiNumericalFeatureConfig(feature_dict)
elif transformation_name == "bucket-numerical":
return BucketNumericalFeatureConfig(feature_dict)
elif transformation_name == "categorical":
return FeatureConfig(feature_dict)
elif transformation_name == "multi-categorical":
return MultiCategoricalFeatureConfig(feature_dict)
else:
raise RuntimeError(f"Unknown transformation name: '{transformation_name}'")

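A dispatch sketch for the new branches, assuming the feature dict shape used in test_converter.py (the column name is hypothetical):

feat_config = parse_feat_config(
    {
        "column": "genres",  # hypothetical column name
        "transformation": {"name": "multi-categorical", "kwargs": {"separator": "|"}},
    }
)
assert isinstance(feat_config, MultiCategoricalFeatureConfig)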
@@ -24,6 +24,8 @@
DistNumericalTransformation,
DistMultiNumericalTransformation,
DistBucketNumericalTransformation,
DistCategoryTransformation,
DistMultiCategoryTransformation,
)


@@ -51,6 +53,10 @@ def __init__(self, feature_config: FeatureConfig):
self.transformation = DistMultiNumericalTransformation(**default_kwargs, **args_dict)
elif feat_type == "bucket-numerical":
self.transformation = DistBucketNumericalTransformation(**default_kwargs, **args_dict)
elif feat_type == "categorical":
self.transformation = DistCategoryTransformation(**default_kwargs, **args_dict)
elif feat_type == "multi-categorical":
self.transformation = DistMultiCategoryTransformation(**default_kwargs, **args_dict)
else:
raise NotImplementedError(
f"Feature {feat_name} has type: {feat_type} that is not supported"
graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py
@@ -19,7 +19,7 @@

from pyspark.sql import DataFrame, functions as F
from pyspark.sql.functions import when
from pyspark.sql.types import ArrayType, FloatType
from pyspark.sql.types import ArrayType, FloatType, StringType
from pyspark.ml.feature import StringIndexer, OneHotEncoder
from pyspark.ml.functions import vector_to_array
from pyspark.ml.linalg import Vectors
@@ -141,12 +141,27 @@ def get_transformation_name() -> str:
return "DistMultiCategoryTransformation"

def apply(self, input_df: DataFrame) -> DataFrame:
# Count category frequency
distinct_category_counts = (
input_df.select(self.multi_column)
.withColumn(
SINGLE_CATEGORY_COL, F.explode(F.split(F.col(self.multi_column), self.separator))
col_datatype = input_df.schema[self.multi_column].dataType
is_array_col = False
if col_datatype.typeName() == "array":
assert isinstance(col_datatype, ArrayType)
if not isinstance(col_datatype.elementType, StringType):
raise ValueError(
f"Unsupported array type {col_datatype.elementType} "
f"for column {self.multi_column}, expected StringType"
)

is_array_col = True

if is_array_col:
list_df = input_df.select(self.multi_column).alias(self.multi_column)
else:
list_df = input_df.select(
F.split(F.col(self.multi_column), self.separator).alias(self.multi_column)
)

distinct_category_counts = (
list_df.withColumn(SINGLE_CATEGORY_COL, F.explode(F.col(self.multi_column)))
.groupBy(SINGLE_CATEGORY_COL)
.count()
)
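To make the two input shapes concrete, a sketch with hypothetical data, assuming an active SparkSession named spark:

# CSV-style input: a delimited string column that gets split on the separator.
string_df = spark.createDataFrame([("a|b",), ("b|c",)], ["tags"])
# Parquet-style input: already array<string>, so the split above is skipped.
array_df = spark.createDataFrame([(["a", "b"],), (["b", "c"],)], ["tags"])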
@@ -206,11 +221,6 @@ def apply(self, input_df: DataFrame) -> DataFrame:
# The encoding for the missing category is an all-zeroes vector
category_map[MISSING_CATEGORY] = np.array([0] * len(valid_categories))

# Split tokens along separator to create List objects
token_list_df = input_df.select(
F.split(F.col(self.multi_column), self.separator).alias(self.multi_column)
)

# Use mapping to convert token list to a multi-hot vector by summing one-hot vectors
missing_vector = (
category_map[RARE_CATEGORY]
@@ -241,7 +251,7 @@ def token_list_to_multihot(token_list: Optional[List[str]]) -> Optional[List[float]]:
token_list_to_multihot, ArrayType(FloatType(), containsNull=False)
)

multihot_df = token_list_df.withColumn(
multihot_df = list_df.withColumn(
self.multi_column, token_list_to_multihot_udf(F.col(self.multi_column))
)

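The multi-hot construction can be illustrated standalone: sum the one-hot vectors of every token in a row. A sketch with a hypothetical three-category map:

import numpy as np

category_map = {
    "a": np.array([1, 0, 0]),
    "b": np.array([0, 1, 0]),
    "c": np.array([0, 0, 1]),
}

def to_multihot(tokens):
    # Summing one-hot vectors yields a 1 for every category present.
    return np.sum([category_map[t] for t in tokens], axis=0).tolist()

assert to_multihot(["a", "c"]) == [1, 0, 1]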
@@ -92,7 +92,12 @@ def determine_spark_feature_type(feature_type: str) -> Type[DataType]:
In case an unsupported feature_type is provided.
"""
# TODO: Replace with pattern matching after moving to Python 3.10?
if feature_type in ["no-op", "multi-numerical"] or feature_type.startswith("text"):
if feature_type in [
"no-op",
"multi-numerical",
"categorical",
"multi-categorical",
] or feature_type.startswith("text"):
return StringType
if feature_type in ["numerical", "bucket-numerical", "none"]:
return FloatType
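As a consequence, both new feature types are read from storage as plain strings; a usage sketch:

assert determine_spark_feature_type("categorical") is StringType
assert determine_spark_feature_type("multi-categorical") is StringType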
22 changes: 22 additions & 0 deletions graphstorm-processing/tests/test_converter.py
@@ -227,6 +227,14 @@ def test_convert_gsprocessing(converter: GConstructConfigConverter):
"feature_name": "rank_gauss2",
"transform": {"name": "rank_gauss", "epsilon": 0.1},
},
{
"feature_col": ["num_citations"],
"transform": {"name": "to_categorical", "mapping": {"1", "2", "3"}},
},
{
"feature_col": ["num_citations"],
"transform": {"name": "to_categorical", "separator": ","},
},
],
"labels": [
{"label_col": "label", "task_type": "classification", "split_pct": [0.8, 0.1, 0.1]}
@@ -299,6 +307,20 @@ def test_convert_gsprocessing(converter: GConstructConfigConverter):
"kwargs": {"epsilon": 0.1, "normalizer": "rank-gauss", "imputer": "none"},
},
},
{
"column": "num_citations",
"transformation": {
"name": "categorical",
"kwargs": {},
},
},
{
"column": "num_citations",
"transformation": {
"name": "multi-categorical",
"kwargs": {"separator": ","},
},
},
]
assert nodes_output["labels"] == [
{
79 changes: 78 additions & 1 deletion graphstorm-processing/tests/test_dist_category_transformation.py
@@ -16,11 +16,13 @@
from typing import Tuple, Iterator
import os
import pytest
import pandas as pd
import tempfile

import mock
from numpy.testing import assert_array_equal
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.types import StructField, StructType, StringType
from pyspark.sql.types import StructField, StructType, StringType, ArrayType

from graphstorm_processing.data_transformations.dist_transformations import (
DistCategoryTransformation,
@@ -174,3 +176,78 @@ def test_multi_category_limited_categories(multi_cat_df_and_separator):
transformed_values = [row[col_name] for row in transformed_df.collect()]

assert_array_equal(expected_values, transformed_values)


def test_csv_input_categorical(spark: SparkSession, check_df_schema):
data_path = os.path.join(_ROOT, "resources/multi_num_numerical/multi_num.csv")
long_vector_df = spark.read.csv(data_path, sep=",", header=True)
dist_categorical_transformation = DistCategoryTransformation(cols=["id"])

transformed_df = dist_categorical_transformation.apply(long_vector_df)
check_df_schema(transformed_df)
transformed_rows = transformed_df.collect()
expected_rows = [
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
]
for row, expected_row in zip(transformed_rows, expected_rows):
assert row["id"] == expected_row


def test_csv_input_multi_categorical(spark: SparkSession, check_df_schema):
data_path = os.path.join(_ROOT, "resources/multi_num_numerical/multi_num.csv")
long_vector_df = spark.read.csv(data_path, sep=",", header=True)
dist_categorical_transformation = DistMultiCategoryTransformation(cols=["feat"], separator=";")

transformed_df = dist_categorical_transformation.apply(long_vector_df)
check_df_schema(transformed_df)
transformed_rows = transformed_df.collect()
expected_rows = [[1] * 100 for _ in range(5)]
for row, expected_row in zip(transformed_rows, expected_rows):
assert row["feat"] == expected_row


def test_parquet_input_multi_categorical(spark: SparkSession, check_df_schema):
# Define the schema for the DataFrame
schema = StructType([StructField("names", ArrayType(StringType()), True)])

# Sample data with arrays of strings
data = [
(["Alice", "Alicia"],),
(["Bob", "Bobby"],),
(["Cathy", "Catherine"],),
(["David", "Dave"],),
]

# Create a DataFrame using the sample data and the defined schema
df = spark.createDataFrame(data, schema)

with tempfile.TemporaryDirectory() as tmpdirname:
# Define the path for the Parquet file
parquet_path = f"{tmpdirname}/people_name.parquet"

# Write the DataFrame to a Parquet file
df.write.mode("overwrite").parquet(parquet_path)

# Read the Parquet file into a DataFrame
df_parquet = spark.read.parquet(parquet_path)

# Apply the multi-categorical transformation; no separator is needed for array input
dist_categorical_transformation = DistMultiCategoryTransformation(cols=["names"], separator=None)

transformed_df = dist_categorical_transformation.apply(df_parquet)
check_df_schema(transformed_df)
transformed_rows = transformed_df.collect()
expected_rows = [
[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
]
for row, expected_row in zip(transformed_rows, expected_rows):
assert row["names"] == expected_row
