From 77e91ae33b5b81ed4b6f15ed3029812d89564c21 Mon Sep 17 00:00:00 2001
From: JalenCato
Date: Wed, 8 Nov 2023 20:37:42 +0000
Subject: [PATCH] add test

---
 .../dist_category_transformation.py          | 34 ++++++++----
 .../test_dist_category_transformation.py     | 55 +++++++++++++------
 2 files changed, 60 insertions(+), 29 deletions(-)

diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py
index e57ef0af09..1ae3060a8a 100644
--- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py
+++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py
@@ -19,7 +19,7 @@
 
 from pyspark.sql import DataFrame, functions as F
 from pyspark.sql.functions import when
-from pyspark.sql.types import ArrayType, FloatType
+from pyspark.sql.types import ArrayType, FloatType, StringType
 from pyspark.ml.feature import StringIndexer, OneHotEncoder
 from pyspark.ml.functions import vector_to_array
 from pyspark.ml.linalg import Vectors
@@ -141,12 +141,27 @@ def get_transformation_name() -> str:
         return "DistMultiCategoryTransformation"
 
     def apply(self, input_df: DataFrame) -> DataFrame:
-        # Count category frequency
-        distinct_category_counts = (
-            input_df.select(self.multi_column)
-            .withColumn(
-                SINGLE_CATEGORY_COL, F.explode(F.split(F.col(self.multi_column), self.separator))
+        col_datatype = input_df.schema[self.multi_column].dataType
+        is_array_col = False
+        if col_datatype.typeName() == "array":
+            assert isinstance(col_datatype, ArrayType)
+            if not isinstance(col_datatype.elementType, StringType):
+                raise ValueError(
+                    f"Unsupported array type {col_datatype.elementType} "
+                    f"for column {self.multi_column}, expected StringType"
+                )
+
+            is_array_col = True
+
+        if is_array_col:
+            list_df = input_df.select(self.multi_column).alias(self.multi_column)
+        else:
+            list_df = input_df.select(
+                F.split(F.col(self.multi_column), self.separator).alias(self.multi_column)
             )
+
+        distinct_category_counts = (
+            list_df.withColumn(SINGLE_CATEGORY_COL, F.explode(F.col(self.multi_column)))
             .groupBy(SINGLE_CATEGORY_COL)
             .count()
         )
@@ -206,11 +221,6 @@ def apply(self, input_df: DataFrame) -> DataFrame:
         # The encoding for the missing category is an all-zeroes vector
         category_map[MISSING_CATEGORY] = np.array([0] * len(valid_categories))
 
-        # Split tokens along separator to create List objects
-        token_list_df = input_df.select(
-            F.split(F.col(self.multi_column), self.separator).alias(self.multi_column)
-        )
-
         # Use mapping to convert token list to a multi-hot vector by summing one-hot vectors
         missing_vector = (
             category_map[RARE_CATEGORY]
@@ -241,7 +251,7 @@ def token_list_to_multihot(token_list: Optional[List[str]]) -> Optional[List[flo
             token_list_to_multihot, ArrayType(FloatType(), containsNull=False)
         )
 
-        multihot_df = token_list_df.withColumn(
+        multihot_df = list_df.withColumn(
             self.multi_column, token_list_to_multihot_udf(F.col(self.multi_column))
         )
 
diff --git a/graphstorm-processing/tests/test_dist_category_transformation.py b/graphstorm-processing/tests/test_dist_category_transformation.py
index f384015688..87ae1dfbb9 100644
--- a/graphstorm-processing/tests/test_dist_category_transformation.py
+++ b/graphstorm-processing/tests/test_dist_category_transformation.py
@@ -17,6 +17,7 @@
 import os
 import pytest
 import pandas as pd
+import shutil
 import mock
 
 from numpy.testing import assert_array_equal
@@ -180,14 +181,18 @@ def test_multi_category_limited_categories(multi_cat_df_and_separator):
 def test_csv_input_categorical(spark: SparkSession, check_df_schema):
     data_path = os.path.join(_ROOT, "resources/multi_num_numerical/multi_num.csv")
     long_vector_df = spark.read.csv(data_path, sep=",", header=True)
-    dist_categorical_transormation = DistCategoryTransformation(
-        cols=["id"]
-    )
+    dist_categorical_transormation = DistCategoryTransformation(cols=["id"])
 
     transformed_df = dist_categorical_transormation.apply(long_vector_df)
     check_df_schema(transformed_df)
     transformed_rows = transformed_df.collect()
-    expected_rows = [[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]]
+    expected_rows = [
+        [1, 0, 0, 0, 0],
+        [0, 1, 0, 0, 0],
+        [0, 0, 1, 0, 0],
+        [0, 0, 0, 1, 0],
+        [0, 0, 0, 0, 1],
+    ]
     for row, expected_row in zip(transformed_rows, expected_rows):
         assert row["id"] == expected_row
 
@@ -195,9 +200,7 @@ def test_csv_input_categorical(spark: SparkSession, check_df_schema):
 def test_csv_input_multi_categorical(spark: SparkSession, check_df_schema):
     data_path = os.path.join(_ROOT, "resources/multi_num_numerical/multi_num.csv")
     long_vector_df = spark.read.csv(data_path, sep=",", header=True)
-    dist_categorical_transormation = DistMultiCategoryTransformation(
-        cols=["feat"], separator=";"
-    )
+    dist_categorical_transormation = DistMultiCategoryTransformation(cols=["feat"], separator=";")
 
     transformed_df = dist_categorical_transormation.apply(long_vector_df)
     check_df_schema(transformed_df)
@@ -211,23 +214,41 @@ def test_csv_input_multi_categorical(spark: SparkSession, check_df_schema):
 
 def test_parquet_input_multi_categorical(spark: SparkSession, check_df_schema):
     # Define the schema for the DataFrame
-    schema = StructType([
-        StructField("id", IntegerType(), True),
-        StructField("names", ArrayType(StringType()), True)
-    ])
+    schema = StructType([StructField("names", ArrayType(StringType()), True)])
 
     # Sample data with arrays of strings
     data = [
-        (1, ["Alice", "Alicia"]),
-        (2, ["Bob", "Bobby"]),
-        (3, ["Cathy", "Catherine"]),
-        (4, ["David", "Dave"])
+        (["Alice", "Alicia"],),
+        (["Bob", "Bobby"],),
+        (["Cathy", "Catherine"],),
+        (["David", "Dave"],),
     ]
 
     # Create a DataFrame using the sample data and the defined schema
     df = spark.createDataFrame(data, schema)
 
+    # Define the path for the Parquet file
+    parquet_path = "people_name.parquet"
+
     # Write the DataFrame to a Parquet file
-    df.write.parquet('people_with_names_array.parquet')
+    df.write.mode("overwrite").parquet(parquet_path)
+
+    # Read the Parquet file into a DataFrame
+    df_parquet = spark.read.parquet(parquet_path)
+
+    # Apply the multi-category transformation to the array-of-strings column
+    dist_categorical_transormation = DistMultiCategoryTransformation(cols=["names"], separator=None)
+
+    transformed_df = dist_categorical_transormation.apply(df_parquet)
+    check_df_schema(transformed_df)
+    transformed_rows = transformed_df.collect()
+    expected_rows = [
+        [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+        [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
+        [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
+        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
+    ]
+    for row, expected_row in zip(transformed_rows, expected_rows):
+        assert row["names"] == expected_row
 
-    shutil.rmtree('people_with_names_array.parquet')
+    shutil.rmtree(parquet_path)
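
Note (illustrative sketch, not part of the patch): with this change, DistMultiCategoryTransformation accepts either a separator-delimited string column or an array-of-strings column. The snippet below sketches both call patterns as exercised by the tests in this patch; the SparkSession setup and the toy column values are assumptions, and the import path simply mirrors the module path shown in the diff header.

from pyspark.sql import SparkSession

from graphstorm_processing.data_transformations.dist_transformations.dist_category_transformation import (
    DistMultiCategoryTransformation,
)

spark = SparkSession.builder.getOrCreate()

# Delimited-string input: tokens are split on the separator before multi-hot encoding.
str_df = spark.createDataFrame([("a;b",), ("b;c",)], ["feat"])
str_multihot = DistMultiCategoryTransformation(cols=["feat"], separator=";").apply(str_df)

# Array-of-strings input (e.g. loaded from Parquet): used as-is, no splitting needed.
arr_df = spark.createDataFrame([(["a", "b"],), (["b", "c"],)], ["feat"])
arr_multihot = DistMultiCategoryTransformation(cols=["feat"], separator=None).apply(arr_df)

str_multihot.show()
arr_multihot.show()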