Commit

add test
jalencato committed Nov 8, 2023
1 parent a775818 commit 77e91ae
Showing 2 changed files with 60 additions and 29 deletions.
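The change below teaches DistMultiCategoryTransformation.apply to accept either a separator-delimited string column or a native array<string> column (the latter is what the new Parquet test exercises). As context, here is a minimal standalone sketch of that normalization step in plain PySpark; the helper name to_list_column and the sample data are illustrative only, not part of the repository:

from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import ArrayType, StringType

spark = SparkSession.builder.getOrCreate()

# Two equivalent inputs: a delimited string column and a native array<string> column.
str_df = spark.createDataFrame([("a;b",), ("b;c",)], ["names"])
arr_df = spark.createDataFrame([(["a", "b"],), (["b", "c"],)], ["names"])

def to_list_column(df, column, separator):
    """Normalize the multi-category column into an array<string> column."""
    dtype = df.schema[column].dataType
    if isinstance(dtype, ArrayType):
        if not isinstance(dtype.elementType, StringType):
            raise ValueError(f"Expected array<string> for {column}, got {dtype}")
        return df.select(column)
    return df.select(F.split(F.col(column), separator).alias(column))

# Both inputs yield the same per-category counts once the list column is exploded.
for df in (to_list_column(str_df, "names", ";"), to_list_column(arr_df, "names", ";")):
    df.withColumn("category", F.explode(F.col("names"))).groupBy("category").count().show()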
@@ -19,7 +19,7 @@

from pyspark.sql import DataFrame, functions as F
from pyspark.sql.functions import when
from pyspark.sql.types import ArrayType, FloatType
from pyspark.sql.types import ArrayType, FloatType, StringType
from pyspark.ml.feature import StringIndexer, OneHotEncoder
from pyspark.ml.functions import vector_to_array
from pyspark.ml.linalg import Vectors
@@ -141,12 +141,27 @@ def get_transformation_name() -> str:
return "DistMultiCategoryTransformation"

def apply(self, input_df: DataFrame) -> DataFrame:
# Count category frequency
distinct_category_counts = (
input_df.select(self.multi_column)
.withColumn(
SINGLE_CATEGORY_COL, F.explode(F.split(F.col(self.multi_column), self.separator))
col_datatype = input_df.schema[self.multi_column].dataType
is_array_col = False
if col_datatype.typeName() == "array":
assert isinstance(col_datatype, ArrayType)
if not isinstance(col_datatype.elementType, StringType):
raise ValueError(
f"Unsupported array type {col_datatype.elementType} "
f"for column {self.multi_column}, expected StringType"
)

is_array_col = True

if is_array_col:
list_df = input_df.select(self.multi_column).alias(self.multi_column)
else:
list_df = input_df.select(
F.split(F.col(self.multi_column), self.separator).alias(self.multi_column)
)

distinct_category_counts = (
list_df.withColumn(SINGLE_CATEGORY_COL, F.explode(F.col(self.multi_column)))
.groupBy(SINGLE_CATEGORY_COL)
.count()
)
@@ -206,11 +221,6 @@ def apply(self, input_df: DataFrame) -> DataFrame:
# The encoding for the missing category is an all-zeroes vector
category_map[MISSING_CATEGORY] = np.array([0] * len(valid_categories))

# Split tokens along separator to create List objects
token_list_df = input_df.select(
F.split(F.col(self.multi_column), self.separator).alias(self.multi_column)
)

# Use mapping to convert token list to a multi-hot vector by summing one-hot vectors
missing_vector = (
category_map[RARE_CATEGORY]
@@ -241,7 +251,7 @@ def token_list_to_multihot(token_list: Optional[List[str]]) -> Optional[List[flo
token_list_to_multihot, ArrayType(FloatType(), containsNull=False)
)

multihot_df = token_list_df.withColumn(
multihot_df = list_df.withColumn(
self.multi_column, token_list_to_multihot_udf(F.col(self.multi_column))
)

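For reference, the multi-hot construction performed by token_list_to_multihot above amounts to summing the one-hot vectors looked up in the category map. A simplified NumPy sketch follows; the three-category map is made up for illustration, and the rare/missing-category handling from the real code is omitted:

import numpy as np

# Hypothetical category map; the real one is built from the observed categories
# and also contains entries for the rare and missing categories.
category_map = {
    "a": np.array([1, 0, 0]),
    "b": np.array([0, 1, 0]),
    "c": np.array([0, 0, 1]),
}

def token_list_to_multihot(token_list):
    # Sum the one-hot vectors of the tokens to produce one multi-hot vector.
    if token_list is None:
        return None
    return np.sum([category_map[t] for t in token_list], axis=0).astype(float).tolist()

print(token_list_to_multihot(["a", "c"]))  # [1.0, 0.0, 1.0]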
55 changes: 38 additions & 17 deletions graphstorm-processing/tests/test_dist_category_transformation.py
@@ -17,6 +17,7 @@
import os
import pytest
import pandas as pd
import shutil

import mock
from numpy.testing import assert_array_equal
@@ -180,24 +181,26 @@ def test_multi_category_limited_categories(multi_cat_df_and_separator):
def test_csv_input_categorical(spark: SparkSession, check_df_schema):
data_path = os.path.join(_ROOT, "resources/multi_num_numerical/multi_num.csv")
long_vector_df = spark.read.csv(data_path, sep=",", header=True)
dist_categorical_transormation = DistCategoryTransformation(
cols=["id"]
)
dist_categorical_transormation = DistCategoryTransformation(cols=["id"])

transformed_df = dist_categorical_transormation.apply(long_vector_df)
check_df_schema(transformed_df)
transformed_rows = transformed_df.collect()
expected_rows = [[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]]
expected_rows = [
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
]
for row, expected_row in zip(transformed_rows, expected_rows):
assert row["id"] == expected_row


def test_csv_input_multi_categorical(spark: SparkSession, check_df_schema):
data_path = os.path.join(_ROOT, "resources/multi_num_numerical/multi_num.csv")
long_vector_df = spark.read.csv(data_path, sep=",", header=True)
dist_categorical_transormation = DistMultiCategoryTransformation(
cols=["feat"], separator=";"
)
dist_categorical_transormation = DistMultiCategoryTransformation(cols=["feat"], separator=";")

transformed_df = dist_categorical_transormation.apply(long_vector_df)
check_df_schema(transformed_df)
@@ -211,23 +214,41 @@ def test_csv_input_multi_categorical(spark: SparkSession, check_df_schema):

def test_parquet_input_multi_categorical(spark: SparkSession, check_df_schema):
# Define the schema for the DataFrame
schema = StructType([
StructField("id", IntegerType(), True),
StructField("names", ArrayType(StringType()), True)
])
schema = StructType([StructField("names", ArrayType(StringType()), True)])

# Sample data with arrays of strings
data = [
(1, ["Alice", "Alicia"]),
(2, ["Bob", "Bobby"]),
(3, ["Cathy", "Catherine"]),
(4, ["David", "Dave"])
(["Alice", "Alicia"],),
(["Bob", "Bobby"],),
(["Cathy", "Catherine"],),
(["David", "Dave"],),
]

# Create a DataFrame using the sample data and the defined schema
df = spark.createDataFrame(data, schema)

# Define the path for the Parquet file
parquet_path = "people_name.parquet"

# Write the DataFrame to a Parquet file
df.write.parquet('people_with_names_array.parquet')
df.write.mode("overwrite").parquet(parquet_path)

# Read the Parquet file into a DataFrame
df_parquet = spark.read.parquet(parquet_path)

# Show the DataFrame loaded from the Parquet file
dist_categorical_transormation = DistMultiCategoryTransformation(cols=["names"], separator=None)

transformed_df = dist_categorical_transormation.apply(df_parquet)
check_df_schema(transformed_df)
transformed_rows = transformed_df.collect()
expected_rows = [
[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
]
for row, expected_row in zip(transformed_rows, expected_rows):
assert row["names"] == expected_row

shutil.rmtree('people_with_names_array.parquet')
shutil.rmtree(parquet_path)
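Note on the expected output above: the four test rows contain eight distinct names in total, so the transformation builds an eight-category vocabulary, and each row's multi-hot vector carries exactly two 1.0 entries, one per name in that row's array.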
