diff --git a/docs/source/gs-processing/developer/input-configuration.rst b/docs/source/gs-processing/developer/input-configuration.rst
index 7cbb9039d9..8ce897bb56 100644
--- a/docs/source/gs-processing/developer/input-configuration.rst
+++ b/docs/source/gs-processing/developer/input-configuration.rst
@@ -426,6 +426,25 @@ arguments.
   by specifying a slide-window size ``s``, where ``s`` can be an integer or float.
   GSProcessing then transforms each numeric value ``v`` of the property into a range
   from ``v - s/2`` through ``v + s/2``, and assigns the value ``v`` to every bucket
   that the range covers.
+
+- ``categorical``
+
+  - Transforms values from a fixed list of possible values (categorical features) to
+    a one-hot encoding. The length of the resulting vector is the number of
+    categories in the data, with a 1 at the index of the matching category and 0
+    everywhere else.
+
+.. note::
+    The maximum number of categories in any categorical feature is 100. If a property
+    has more than 100 distinct values, only the 99 most common ones are assigned
+    distinct categories, and the rest are placed in a special category named OTHER.
+
+- ``multi-categorical``
+
+  - Encodes vector-like data from a fixed list of possible values (i.e.
+    multi-label/multi-categorical data) using a multi-hot encoding. The length of
+    the resulting vector is the number of categories in the data, with a 1 for
+    every category that appears in the value and 0 everywhere else.
+  - ``kwargs``:
+
+    - ``separator`` (String, optional): Same as in the No-op operation, the
+      separator is used to split multiple input values for CSV files, e.g.
+      ``detective|noir``. If it is not provided, the value is expected to already
+      be an array of strings. For Parquet files, if the input type is
+      ``ArrayType(StringType())`` the separator is ignored; if it is
+      ``StringType()``, the same logic as for CSV applies.
 
 --------------
 
 Examples
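A minimal sketch of feature entries using the two transformations documented above, assuming the GSProcessing configuration shape that `test_converter.py` asserts later in this diff; the column names `genre` and `genres` are hypothetical:

```python
# Hypothetical feature entries in GSProcessing format (column names invented)
categorical_feature = {
    "column": "genre",
    "transformation": {"name": "categorical", "kwargs": {}},
}

multi_categorical_feature = {
    "column": "genres",
    "transformation": {
        "name": "multi-categorical",
        # Split CSV values such as "detective|noir" on '|'
        "kwargs": {"separator": "|"},
    },
}
```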
+ + """ + + def __init__(self, config: Mapping): + super().__init__(config) + self.separator = self._transformation_kwargs.get("separator", None) + + self._sanity_check() diff --git a/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py b/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py index 89bfd568b9..7761acdcf8 100644 --- a/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py +++ b/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py @@ -125,6 +125,15 @@ def _convert_feature(feats: list[dict]) -> list[dict]: "normalizer": "rank-gauss", "imputer": "none", } + elif gconstruct_transform_dict["name"] == "to_categorical": + if "separator" in gconstruct_transform_dict: + gsp_transformation_dict["name"] = "multi-categorical" + gsp_transformation_dict["kwargs"] = { + "separator": gconstruct_transform_dict["separator"] + } + else: + gsp_transformation_dict["name"] = "categorical" + gsp_transformation_dict["kwargs"] = {} # TODO: Add support for other common transformations here else: raise ValueError( diff --git a/graphstorm-processing/graphstorm_processing/config/config_parser.py b/graphstorm-processing/graphstorm_processing/config/config_parser.py index 0193032b0c..3e32dbe91a 100644 --- a/graphstorm-processing/graphstorm_processing/config/config_parser.py +++ b/graphstorm-processing/graphstorm_processing/config/config_parser.py @@ -26,6 +26,7 @@ MultiNumericalFeatureConfig, NumericalFeatureConfig, ) +from .categorical_configs import MultiCategoricalFeatureConfig from .data_config_base import DataStorageConfig @@ -62,6 +63,10 @@ def parse_feat_config(feature_dict: Dict) -> FeatureConfig: return MultiNumericalFeatureConfig(feature_dict) elif transformation_name == "bucket-numerical": return BucketNumericalFeatureConfig(feature_dict) + elif transformation_name == "categorical": + return FeatureConfig(feature_dict) + elif transformation_name == "multi-categorical": + return MultiCategoricalFeatureConfig(feature_dict) else: raise RuntimeError(f"Unknown transformation name: '{transformation_name}'") diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py index aeab3830fb..e325e67243 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py @@ -24,6 +24,8 @@ DistNumericalTransformation, DistMultiNumericalTransformation, DistBucketNumericalTransformation, + DistCategoryTransformation, + DistMultiCategoryTransformation, ) @@ -51,6 +53,10 @@ def __init__(self, feature_config: FeatureConfig): self.transformation = DistMultiNumericalTransformation(**default_kwargs, **args_dict) elif feat_type == "bucket-numerical": self.transformation = DistBucketNumericalTransformation(**default_kwargs, **args_dict) + elif feat_type == "categorical": + self.transformation = DistCategoryTransformation(**default_kwargs, **args_dict) + elif feat_type == "multi-categorical": + self.transformation = DistMultiCategoryTransformation(**default_kwargs, **args_dict) else: raise NotImplementedError( f"Feature {feat_name} has type: {feat_type} that is not supported" diff --git 
diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py
index e57ef0af09..1ae3060a8a 100644
--- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py
+++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py
@@ -19,7 +19,7 @@
 from pyspark.sql import DataFrame, functions as F
 from pyspark.sql.functions import when
-from pyspark.sql.types import ArrayType, FloatType
+from pyspark.sql.types import ArrayType, FloatType, StringType
 from pyspark.ml.feature import StringIndexer, OneHotEncoder
 from pyspark.ml.functions import vector_to_array
 from pyspark.ml.linalg import Vectors
@@ -141,12 +141,27 @@ def get_transformation_name() -> str:
         return "DistMultiCategoryTransformation"
 
     def apply(self, input_df: DataFrame) -> DataFrame:
-        # Count category frequency
-        distinct_category_counts = (
-            input_df.select(self.multi_column)
-            .withColumn(
-                SINGLE_CATEGORY_COL, F.explode(F.split(F.col(self.multi_column), self.separator))
+        col_datatype = input_df.schema[self.multi_column].dataType
+        is_array_col = False
+        if col_datatype.typeName() == "array":
+            assert isinstance(col_datatype, ArrayType)
+            if not isinstance(col_datatype.elementType, StringType):
+                raise ValueError(
+                    f"Unsupported array type {col_datatype.elementType} "
+                    f"for column {self.multi_column}, expected StringType"
+                )
+
+            is_array_col = True
+
+        if is_array_col:
+            list_df = input_df.select(self.multi_column).alias(self.multi_column)
+        else:
+            list_df = input_df.select(
+                F.split(F.col(self.multi_column), self.separator).alias(self.multi_column)
             )
+
+        # Count category frequency
+        distinct_category_counts = (
+            list_df.withColumn(SINGLE_CATEGORY_COL, F.explode(F.col(self.multi_column)))
             .groupBy(SINGLE_CATEGORY_COL)
             .count()
         )
@@ -206,11 +221,6 @@ def apply(self, input_df: DataFrame) -> DataFrame:
         # The encoding for the missing category is an all-zeroes vector
         category_map[MISSING_CATEGORY] = np.array([0] * len(valid_categories))
 
-        # Split tokens along separator to create List objects
-        token_list_df = input_df.select(
-            F.split(F.col(self.multi_column), self.separator).alias(self.multi_column)
-        )
-
         # Use mapping to convert token list to a multi-hot vector by summing one-hot vectors
         missing_vector = (
             category_map[RARE_CATEGORY]
@@ -241,7 +251,7 @@ def token_list_to_multihot(token_list: Optional[List[str]]) -> Optional[List[float]]:
             token_list_to_multihot, ArrayType(FloatType(), containsNull=False)
         )
 
-        multihot_df = token_list_df.withColumn(
+        multihot_df = list_df.withColumn(
             self.multi_column, token_list_to_multihot_udf(F.col(self.multi_column))
        )
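The `apply` hunk above normalizes both input shapes (Parquet `array<string>` columns and delimited strings) into one list-valued column before counting categories. A standalone sketch of that normalization, with `to_list_column` as an illustrative helper that is not part of the codebase:

```python
from pyspark.sql import DataFrame, functions as F
from pyspark.sql.types import ArrayType, StringType


def to_list_column(input_df: DataFrame, column: str, separator: str) -> DataFrame:
    """Return a single-column DataFrame whose values are arrays of string tokens."""
    dtype = input_df.schema[column].dataType
    if isinstance(dtype, ArrayType):
        if not isinstance(dtype.elementType, StringType):
            raise ValueError(f"Expected array<string> for {column}, got {dtype}")
        # Array input (e.g. Parquet): already tokenized, the separator is ignored
        return input_df.select(column)
    # String input (e.g. CSV): split each value on the configured separator
    return input_df.select(F.split(F.col(column), separator).alias(column))
```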
diff --git a/graphstorm-processing/graphstorm_processing/graph_loaders/schema_utils.py b/graphstorm-processing/graphstorm_processing/graph_loaders/schema_utils.py
index 19e28be399..c124760da6 100644
--- a/graphstorm-processing/graphstorm_processing/graph_loaders/schema_utils.py
+++ b/graphstorm-processing/graphstorm_processing/graph_loaders/schema_utils.py
@@ -92,7 +92,12 @@ def determine_spark_feature_type(feature_type: str) -> Type[DataType]:
         In case an unsupported feature_type is provided.
     """
     # TODO: Replace with pattern matching after moving to Python 3.10?
-    if feature_type in ["no-op", "multi-numerical"] or feature_type.startswith("text"):
+    if feature_type in [
+        "no-op",
+        "multi-numerical",
+        "categorical",
+        "multi-categorical",
+    ] or feature_type.startswith("text"):
         return StringType
     if feature_type in ["numerical", "bucket-numerical", "none"]:
         return FloatType
diff --git a/graphstorm-processing/tests/test_converter.py b/graphstorm-processing/tests/test_converter.py
index 5bad285451..7afb0df4b2 100644
--- a/graphstorm-processing/tests/test_converter.py
+++ b/graphstorm-processing/tests/test_converter.py
@@ -227,6 +227,14 @@ def test_convert_gsprocessing(converter: GConstructConfigConverter):
                     "feature_name": "rank_gauss2",
                     "transform": {"name": "rank_gauss", "epsilon": 0.1},
                 },
+                {
+                    "feature_col": ["num_citations"],
+                    "transform": {"name": "to_categorical", "mapping": {"1", "2", "3"}},
+                },
+                {
+                    "feature_col": ["num_citations"],
+                    "transform": {"name": "to_categorical", "separator": ","},
+                },
             ],
             "labels": [
                 {"label_col": "label", "task_type": "classification", "split_pct": [0.8, 0.1, 0.1]}
@@ -299,6 +307,20 @@ def test_convert_gsprocessing(converter: GConstructConfigConverter):
                 "kwargs": {"epsilon": 0.1, "normalizer": "rank-gauss", "imputer": "none"},
             },
         },
+        {
+            "column": "num_citations",
+            "transformation": {
+                "name": "categorical",
+                "kwargs": {},
+            },
+        },
+        {
+            "column": "num_citations",
+            "transformation": {
+                "name": "multi-categorical",
+                "kwargs": {"separator": ","},
+            },
+        },
     ]
     assert nodes_output["labels"] == [
         {
diff --git a/graphstorm-processing/tests/test_dist_category_transformation.py b/graphstorm-processing/tests/test_dist_category_transformation.py
index 2f144fb007..662fe32a1a 100644
--- a/graphstorm-processing/tests/test_dist_category_transformation.py
+++ b/graphstorm-processing/tests/test_dist_category_transformation.py
@@ -16,11 +16,13 @@
 from typing import Tuple, Iterator
 import os
+import tempfile
 import pytest
+import pandas as pd
 import mock
 
 from numpy.testing import assert_array_equal
 from pyspark.sql import SparkSession, DataFrame
-from pyspark.sql.types import StructField, StructType, StringType
+from pyspark.sql.types import StructField, StructType, StringType, ArrayType
 
 from graphstorm_processing.data_transformations.dist_transformations import (
     DistCategoryTransformation,
@@ -174,3 +176,78 @@ def test_multi_category_limited_categories(multi_cat_df_and_separator):
     transformed_values = [row[col_name] for row in transformed_df.collect()]
 
     assert_array_equal(expected_values, transformed_values)
+
+
+def test_csv_input_categorical(spark: SparkSession, check_df_schema):
+    data_path = os.path.join(_ROOT, "resources/multi_num_numerical/multi_num.csv")
+    long_vector_df = spark.read.csv(data_path, sep=",", header=True)
+    dist_categorical_transformation = DistCategoryTransformation(cols=["id"])
+
+    transformed_df = dist_categorical_transformation.apply(long_vector_df)
+    check_df_schema(transformed_df)
+    transformed_rows = transformed_df.collect()
+    expected_rows = [
+        [1, 0, 0, 0, 0],
+        [0, 1, 0, 0, 0],
+        [0, 0, 1, 0, 0],
+        [0, 0, 0, 1, 0],
+        [0, 0, 0, 0, 1],
+    ]
+    for row, expected_row in zip(transformed_rows, expected_rows):
+        assert row["id"] == expected_row
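For intuition, the single-category transformation can also be exercised outside the test fixtures; a hedged sketch with an in-memory DataFrame, assuming a local SparkSession (the `color` column and its values are invented):

```python
from pyspark.sql import SparkSession

from graphstorm_processing.data_transformations.dist_transformations import (
    DistCategoryTransformation,
)

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("red",), ("green",), ("blue",)], ["color"])

# Three distinct values should yield one-hot vectors of length three
transformation = DistCategoryTransformation(cols=["color"])
transformation.apply(df).show(truncate=False)
```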
+
+
+def test_csv_input_multi_categorical(spark: SparkSession, check_df_schema):
+    data_path = os.path.join(_ROOT, "resources/multi_num_numerical/multi_num.csv")
+    long_vector_df = spark.read.csv(data_path, sep=",", header=True)
+    dist_categorical_transformation = DistMultiCategoryTransformation(cols=["feat"], separator=";")
+
+    transformed_df = dist_categorical_transformation.apply(long_vector_df)
+    check_df_schema(transformed_df)
+    transformed_rows = transformed_df.collect()
+    expected_rows = []
+    for _ in range(5):
+        expected_rows.append([1] * 100)
+    for row, expected_row in zip(transformed_rows, expected_rows):
+        assert row["feat"] == expected_row
+
+
+def test_parquet_input_multi_categorical(spark: SparkSession, check_df_schema):
+    # Define the schema for the DataFrame
+    schema = StructType([StructField("names", ArrayType(StringType()), True)])
+
+    # Sample data with arrays of strings
+    data = [
+        (["Alice", "Alicia"],),
+        (["Bob", "Bobby"],),
+        (["Cathy", "Catherine"],),
+        (["David", "Dave"],),
+    ]
+
+    # Create a DataFrame using the sample data and the defined schema
+    df = spark.createDataFrame(data, schema)
+
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        # Define the path for the Parquet file
+        parquet_path = f"{tmpdirname}/people_name.parquet"
+
+        # Write the DataFrame to a Parquet file
+        df.write.mode("overwrite").parquet(parquet_path)
+
+        # Read the Parquet file back into a DataFrame
+        df_parquet = spark.read.parquet(parquet_path)
+
+        # Apply the multi-categorical transformation without a separator,
+        # since the input column is already an array of strings
+        dist_categorical_transformation = DistMultiCategoryTransformation(
+            cols=["names"], separator=None
+        )
+
+        transformed_df = dist_categorical_transformation.apply(df_parquet)
+        check_df_schema(transformed_df)
+        transformed_rows = transformed_df.collect()
+        expected_rows = [
+            [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
+            [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
+            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
+        ]
+        for row, expected_row in zip(transformed_rows, expected_rows):
+            assert row["names"] == expected_row
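Finally, a toy pure-Python illustration of the multi-hot construction the UDF in `dist_category_transformation.py` performs (summing per-token one-hot vectors); the category names and vector length here are invented:

```python
import numpy as np

# Hypothetical category -> one-hot mapping over three valid categories
category_map = {
    "detective": np.array([1.0, 0.0, 0.0]),
    "noir": np.array([0.0, 1.0, 0.0]),
    "comedy": np.array([0.0, 0.0, 1.0]),
}


def token_list_to_multihot(tokens):
    """Sum the one-hot vectors of the tokens and clip duplicates to 1."""
    summed = np.sum([category_map[t] for t in tokens], axis=0)
    return np.clip(summed, 0.0, 1.0).tolist()


assert token_list_to_multihot(["detective", "noir"]) == [1.0, 1.0, 0.0]
```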