From 3ac5839a85552c7e555d7355b0d01e95991866a4 Mon Sep 17 00:00:00 2001 From: Theodore Vasiloudis Date: Mon, 27 May 2024 06:58:07 +0000 Subject: [PATCH] [GSProcessing] Add structure for saving transformation JSON files. First implemented for categorical transformation. This commit only adds saving the categorical transformation in a JSON representation. --- .../graphstorm_processing/constants.py | 3 + .../dist_feature_transformer.py | 32 +++- .../base_dist_transformation.py | 31 +++- .../dist_category_transformation.py | 98 ++++++++++-- .../distributed_executor.py | 45 ++++-- .../dist_heterogeneous_loader.py | 144 ++++++++++++------ graphstorm-processing/tests/conftest.py | 3 +- .../gsprocessing-config.json | 10 +- .../small_heterogeneous_graph/nodes/user.csv | 12 +- .../test_dist_category_transformation.py | 50 ++++-- .../tests/test_dist_heterogenous_loader.py | 68 ++++++--- 11 files changed, 368 insertions(+), 128 deletions(-) diff --git a/graphstorm-processing/graphstorm_processing/constants.py b/graphstorm-processing/graphstorm_processing/constants.py index 13b7af65ce..976ab57818 100644 --- a/graphstorm-processing/graphstorm_processing/constants.py +++ b/graphstorm-processing/graphstorm_processing/constants.py @@ -85,3 +85,6 @@ class FilesystemType(Enum): "3.4": "3.3.4", "3.3": "3.3.2", } + +########## Precomputed transformations ################ +TRANSFORMATIONS_FILENAME = "precomputed_transformations.json" diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py index 5527960b6a..9ef2fd0a33 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py @@ -16,7 +16,7 @@ import logging -from pyspark.sql import DataFrame +from pyspark.sql import DataFrame, SparkSession from graphstorm_processing.config.feature_config_base import FeatureConfig from .dist_transformations import ( @@ -37,13 +37,19 @@ class DistFeatureTransformer(object): which can then be be applied through a call to apply_transformation. 
""" - def __init__(self, feature_config: FeatureConfig): + def __init__( + self, feature_config: FeatureConfig, spark: SparkSession, json_representation: dict + ): feat_type = feature_config.feat_type feat_name = feature_config.feat_name args_dict = feature_config.transformation_kwargs self.transformation: DistributedTransformation + # TODO: We will use this to re-apply transformations + self.json_representation = json_representation - default_kwargs = {"cols": feature_config.cols} + default_kwargs = { + "cols": feature_config.cols, + } logging.info("Feature name: %s", feat_name) logging.info("Transformation type: %s", feat_type) @@ -56,7 +62,9 @@ def __init__(self, feature_config: FeatureConfig): elif feat_type == "bucket-numerical": self.transformation = DistBucketNumericalTransformation(**default_kwargs, **args_dict) elif feat_type == "categorical": - self.transformation = DistCategoryTransformation(**default_kwargs, **args_dict) + self.transformation = DistCategoryTransformation( + **default_kwargs, **args_dict, spark=spark + ) elif feat_type == "multi-categorical": self.transformation = DistMultiCategoryTransformation(**default_kwargs, **args_dict) elif feat_type == "huggingface": @@ -66,14 +74,24 @@ def __init__(self, feature_config: FeatureConfig): f"Feature {feat_name} has type: {feat_type} that is not supported" ) - def apply_transformation(self, input_df: DataFrame) -> DataFrame: + def apply_transformation(self, input_df: DataFrame) -> tuple[DataFrame, dict]: """ - Given an input dataframe, select only the relevant columns + Given an input DataFrame, select only the relevant columns and apply the expected transformation to them. + + Returns + ------- + tuple[DataFrame, dict] + A tuple with two items, the first is the transformed input DataFrame, + the second is a JSON representation of the transformation. This will + allow us to apply the same transformation to new data. """ input_df = input_df.select(self.transformation.cols) # type: ignore - return self.transformation.apply(input_df) + return ( + self.transformation.apply(input_df), + self.transformation.get_json_representation(), + ) def get_transformation_name(self) -> str: """ diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/base_dist_transformation.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/base_dist_transformation.py index 21b31e1154..221bdce0f7 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/base_dist_transformation.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/base_dist_transformation.py @@ -15,9 +15,9 @@ """ from abc import ABC, abstractmethod -from typing import Sequence +from typing import Optional, Sequence -from pyspark.sql import DataFrame +from pyspark.sql import DataFrame, SparkSession class DistributedTransformation(ABC): @@ -25,19 +25,34 @@ class DistributedTransformation(ABC): Base class for all distributed transformations. """ - def __init__(self, cols: Sequence[str]) -> None: + def __init__( + self, + cols: Sequence[str], + spark: Optional[SparkSession] = None, + json_representation: Optional[dict] = None, + ) -> None: self.cols = cols + self.spark = spark + self.json_representation = json_representation @abstractmethod def apply(self, input_df: DataFrame) -> DataFrame: """ - Applies the transformation to the input DataFrame. - The returned dataframe will only contain the columns specified during initialization. 
+ Applies the transformation to the input DataFrame, and returns the modified + DataFrame. + + The returned DataFrame will only contain the columns specified during initialization. """ + def get_json_representation(self) -> dict: + """Get a JSON representation of the transformation.""" + # TODO: Should we try to guarantee apply() has ran before this? + if self.json_representation: + return self.json_representation + else: + return {} + @staticmethod @abstractmethod def get_transformation_name() -> str: - """ - Get the name of the transformation - """ + """Get the name of the transformation.""" diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py index b03a0c9333..440b78bffd 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py @@ -14,11 +14,12 @@ limitations under the License. """ -from typing import Dict, List, Optional, Sequence +from typing import List, Optional, Sequence import numpy as np +import pandas as pd -from pyspark.sql import DataFrame, functions as F +from pyspark.sql import DataFrame, functions as F, SparkSession from pyspark.sql.functions import when from pyspark.sql.types import ArrayType, FloatType, StringType from pyspark.ml.feature import StringIndexer, OneHotEncoder @@ -40,8 +41,8 @@ class DistCategoryTransformation(DistributedTransformation): Transforms categorical features into a vector of one-hot-encoded values. """ - def __init__(self, cols: List[str]) -> None: - super().__init__(cols) + def __init__(self, cols: List[str], spark: SparkSession) -> None: + super().__init__(cols, spark) @staticmethod def get_transformation_name() -> str: @@ -49,9 +50,10 @@ def get_transformation_name() -> str: def apply(self, input_df: DataFrame) -> DataFrame: processed_col_names = [] - for col in self.cols: - processed_col_names.append(col + "_processed") - distinct_category_counts = input_df.groupBy(col).count() # type: DataFrame + top_categories_per_col: dict[str, list] = {} + for current_col in self.cols: + processed_col_names.append(current_col + "_processed") + distinct_category_counts = input_df.groupBy(current_col).count() # type: DataFrame num_distinct_categories = distinct_category_counts.count() # Conditionally replace rare categories with single placeholder @@ -60,17 +62,23 @@ def apply(self, input_df: DataFrame) -> DataFrame: MAX_CATEGORIES_PER_FEATURE - 1 ) top_categories_set = {row[0] for row in top_categories} + top_categories_per_col[current_col] = list(top_categories_set) # TODO: Ideally we don't want to use withColumn in a loop input_df = input_df.withColumn( - col, - when(input_df[col].isin(top_categories_set), input_df[col]).otherwise( - RARE_CATEGORY - ), + current_col, + when( + input_df[current_col].isin(top_categories_set), input_df[current_col] + ).otherwise(RARE_CATEGORY), ) + else: + top_categories_per_col[current_col] = [ + x[current_col] for x in distinct_category_counts.select(current_col).collect() + ] # Replace empty string cols with None input_df = input_df.withColumn( - col, when(input_df[col] == "", None).otherwise(input_df[col]) + current_col, + when(input_df[current_col] == "", None).otherwise(input_df[current_col]), ) # We first convert the strings to 
float indexes @@ -105,8 +113,74 @@ def apply(self, input_df: DataFrame) -> DataFrame: ] ) + # Structure: {column_name: {category_string: index_value, ...}. ...} + per_col_label_to_one_hot_idx: dict[str, dict[str, int]] = {} + + # To get the transformed values for each value in each col + # we need to create a DataFrame with the top categories for the current + # col, then fill in the rest of the values with placeholders + # and pass the generated DF through the one-hot encoder + for current_col, processed_col in zip(self.cols, processed_col_names): + other_cols = [x for x in self.cols if x != current_col] + top_str_categories_list = top_categories_per_col[current_col] + # Spark doesn't model missing values, the all-zeroes vector is used + top_str_categories_list.remove(None) + print(top_str_categories_list) + # Each col might have different number of top categories, we need one DF per col + num_current_col_cats = len(top_str_categories_list) + # We don't care about values for the other cols in this iteration, + # just fill with empty string + placeholder_vals = [""] * num_current_col_cats + placeholder_cols = [placeholder_vals for _ in range(len(self.cols) - 1)] + current_col_unique_vals = [list(top_str_categories_list)] + # We need to create a DF where all cols have num_rows == num_current_col_cats + # and the current col needs to be the first col in the DF. + vals_dict = dict( + zip([current_col] + other_cols, current_col_unique_vals + placeholder_cols) + ) + + # One hot encoder expects a DF with all cols that were used to train it + # so we use the top-MAX_CATEGORIES_PER_FEATURE values for the current col, + # and the placeholders for the rest + top_str_categories_df = self.spark.createDataFrame(pd.DataFrame(vals_dict)) + top_indexed_categories_df = str_indexer_model.transform(top_str_categories_df) + + # For the current col, get the one-hot index for each of its category strings + # by passing the top-k values DF through the one-hot encoder model + per_col_label_to_one_hot_idx[current_col] = { + x[current_col]: int(x[processed_col]) + for x in one_hot_encoder_model.transform(top_indexed_categories_df).collect() + } + + # see get_json_representation() docstring for structure + self.json_representation = { + "string_indexer_labels_array": str_indexer_model.labelsArray, + "cols": self.cols, + "per_col_label_to_one_hot_idx": per_col_label_to_one_hot_idx, + } + return dense_vector_features + def get_json_representation(self) -> dict: + """Representation of the single-category transformation for one or more columns. + + Returns + ------- + dict + Structure: + string_indexer_labels_array: + tuple[tuple[str]], outer tuple has num_cols elements, + each inner tuple has num_cats elements, each str is a category string. + Spark uses this to represent the one-hot index for each category, its + position in the inner tuple is the one-hot-index position for the string. + Categories are sorted by their frequency in the data. 
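# A small, self-contained sketch of reading the JSON representation assembled
# above. The category values for the "occupation" column are illustrative only;
# the key names are the ones stored in self.json_representation.
from typing import Optional

example_representation = {
    "cols": ["occupation"],
    "string_indexer_labels_array": [["student", "actor", "doctor"]],
    "per_col_label_to_one_hot_idx": {
        "occupation": {"student": 0, "actor": 1, "doctor": 2}
    },
}


def one_hot_index(representation: dict, col: str, value: str) -> Optional[int]:
    """Look up the one-hot index for a category string, or None if it was never seen."""
    # Missing or unseen values get no index; the transformation represents them
    # with the all-zeroes vector, as noted in the comments above.
    return representation["per_col_label_to_one_hot_idx"][col].get(value)


assert one_hot_index(example_representation, "occupation", "actor") == 1
assert one_hot_index(example_representation, "occupation", "astronaut") is None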
+ cols: + list[str], with num_cols elements + per_col_label_to_one_hot_idx: + dict[str, dict[str, int]], with num_cols elements, each with num_categories elements + """ + return self.json_representation + class DistMultiCategoryTransformation(DistributedTransformation): """ diff --git a/graphstorm-processing/graphstorm_processing/distributed_executor.py b/graphstorm-processing/graphstorm_processing/distributed_executor.py index d97b9ff4b0..2bc7356287 100644 --- a/graphstorm-processing/graphstorm_processing/distributed_executor.py +++ b/graphstorm-processing/graphstorm_processing/distributed_executor.py @@ -63,9 +63,11 @@ from graphstorm_processing.graph_loaders.dist_heterogeneous_loader import ( DistHeterogeneousGraphLoader, + HeterogeneousLoaderConfig, ) from graphstorm_processing.config.config_parser import create_config_objects from graphstorm_processing.config.config_conversion import GConstructConfigConverter +from graphstorm_processing.constants import TRANSFORMATIONS_FILENAME from graphstorm_processing.data_transformations import spark_utils, s3_utils from graphstorm_processing.repartition_files import ( repartition_files, @@ -190,7 +192,17 @@ def __init__( with open(graph_conf, "r", encoding="utf-8") as f: dataset_config_dict: Dict[str, Any] = json.load(f) - # Use appropriate config parser depending on file version + # Load the pre-computed transformations if the file exists + if os.path.exists(os.path.join(self.local_config_path, TRANSFORMATIONS_FILENAME)): + with open( + os.path.join(self.local_config_path, TRANSFORMATIONS_FILENAME), + "r", + encoding="utf-8", + ) as f: + self.precomputed_transformations = json.load(f) + else: + self.precomputed_transformations = {} + if "version" in dataset_config_dict: config_version = dataset_config_dict["version"] if config_version == "gsprocessing-v1.0": @@ -259,18 +271,22 @@ def run(self) -> None: data_configs = create_config_objects(self.graph_config_dict) t0 = time.time() - # Prefer explicit arguments for clarity - loader = DistHeterogeneousGraphLoader( - spark=self.spark, - local_input_path=self.local_config_path, - local_output_path=self.local_output_path, - data_configs=data_configs, - input_prefix=self.input_prefix, - output_prefix=self.output_prefix, - num_output_files=self.num_output_files, + loader_config = HeterogeneousLoaderConfig( add_reverse_edges=self.add_reverse_edges, + data_configs=data_configs, enable_assertions=False, graph_name=self.graph_name, + input_prefix=self.input_prefix, + local_input_path=self.local_config_path, + local_output_path=self.local_output_path, + num_output_files=self.num_output_files, + output_prefix=self.output_prefix, + precomputed_transformations=self.precomputed_transformations, + ) + # Prefer explicit arguments for clarity + loader = DistHeterogeneousGraphLoader( + self.spark, + loader_config, ) graph_meta_dict, timers_dict = loader.load() @@ -474,6 +490,15 @@ def main(): f"s3://{input_bucket}/{input_s3_prefix}/" f"{gsprocessing_args.config_filename}" ) from e + # Try to download the pre-computed transformations file, if it exists + try: + s3.download_file( + input_bucket, + f"{input_s3_prefix}/{TRANSFORMATIONS_FILENAME}", + os.path.join(tempdir.name, TRANSFORMATIONS_FILENAME), + ) + except botocore.exceptions.ClientError as _: + pass local_config_path = tempdir.name # local output location for metadata files diff --git a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py 
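# A condensed sketch of the optional-load pattern the executor uses above: the
# transformations file may be absent both on S3 and locally, in which case an
# empty dict is used. The bucket/prefix arguments are hypothetical; the constant
# and control flow come from the distributed_executor.py changes above.
import json
import os

import boto3
from botocore.exceptions import ClientError

from graphstorm_processing.constants import TRANSFORMATIONS_FILENAME


def try_download_transformations(bucket: str, prefix: str, local_dir: str) -> None:
    """Best-effort download of the precomputed transformations file from S3."""
    s3 = boto3.client("s3")
    try:
        s3.download_file(
            bucket,
            f"{prefix}/{TRANSFORMATIONS_FILENAME}",
            os.path.join(local_dir, TRANSFORMATIONS_FILENAME),
        )
    except ClientError:
        # The file is optional, so a missing object is not an error
        pass


def load_precomputed_transformations(local_config_path: str) -> dict:
    """Return the parsed transformations JSON if present, otherwise an empty dict."""
    transformations_path = os.path.join(local_config_path, TRANSFORMATIONS_FILENAME)
    if os.path.exists(transformations_path):
        with open(transformations_path, "r", encoding="utf-8") as f:
            return json.load(f)
    return {}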
index 79fc14d454..23fe5a6446 100644 --- a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py +++ b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py @@ -20,6 +20,7 @@ import numbers import os from collections import Counter, defaultdict +from dataclasses import dataclass from time import perf_counter from typing import Any, Dict, Mapping, Optional, Sequence, Set, Tuple @@ -45,6 +46,7 @@ SPECIAL_CHARACTERS, HUGGINGFACE_TRANFORM, HUGGINGFACE_TOKENIZE, + TRANSFORMATIONS_FILENAME, ) from ..config.config_parser import EdgeConfig, NodeConfig, StructureConfig from ..config.label_config_base import LabelConfig @@ -64,68 +66,90 @@ NODE_MAPPING_INT = "new" -class DistHeterogeneousGraphLoader(HeterogeneousGraphLoader): +@dataclass +class HeterogeneousLoaderConfig: """ - A graph loader designed to run distributed processing of a heterogeneous graph. + Configuration object for the loader. - Parameters - ---------- - spark : SparkSession - The SparkSession that we use to perform the processing - local_input_path : str - Local path to input configuration data - local_metadata_path : str - Local path to where the output metadata files will be created. + add_reverse_edges : bool + Whether to add reverse edges to the graph. data_configs : Dict[str, Sequence[StructureConfig]] Dictionary of node and edge configurations objects. + enable_assertions : bool, optional + When true enables sanity checks for the output created. + However these are costly to compute, so we disable them by default. + graph_name: str + The name of the graph we will process. input_prefix : str The prefix to the input data. Can be an S3 URI or an **absolute** local path. + local_input_path : str + Local path to input configuration data + local_output_path : str + Local path to where the output metadata files will be created. + num_output_files : int + The number of files (partitions) to create for the output, if None we + let Spark decide. output_prefix : str The prefix to where the output data will be created. Can be an S3 URI or an **absolute** local path. - num_output_files : Optional[int], optional - The number of files (partitions) to create for the output, if None we - let Spark decide. - enable_assertions : bool, optional - When true enables sanity checks for the output created. - However these are costly to compute, so we disable them by default. + precomputed_transformations: dict + A dictionary describing precomputed transformations for the features + of the graph. + """ + + add_reverse_edges: bool + data_configs: Dict[str, Sequence[StructureConfig]] + enable_assertions: bool + graph_name: str + input_prefix: str + local_input_path: str + local_output_path: str + num_output_files: int + output_prefix: str + precomputed_transformations: dict + + +class DistHeterogeneousGraphLoader(HeterogeneousGraphLoader): + """ + A graph loader designed to run distributed processing of a heterogeneous graph. 
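# A minimal construction sketch for the new config-object API. Paths, the graph
# name, and graph_config_dict are placeholders; the field names and two-argument
# loader constructor match the dataclass above and the test fixtures further down.
from pyspark.sql import SparkSession

from graphstorm_processing.config.config_parser import create_config_objects
from graphstorm_processing.graph_loaders.dist_heterogeneous_loader import (
    DistHeterogeneousGraphLoader,
    HeterogeneousLoaderConfig,
)

spark = SparkSession.builder.master("local[4]").getOrCreate()
data_configs = create_config_objects(graph_config_dict)  # graph_config_dict: parsed input config (assumed)

loader_config = HeterogeneousLoaderConfig(
    add_reverse_edges=True,
    data_configs=data_configs,
    enable_assertions=False,
    graph_name="example-graph",
    input_prefix="/absolute/path/to/input",
    local_input_path="/absolute/path/to/input",
    local_output_path="/absolute/path/to/output",
    num_output_files=8,
    output_prefix="/absolute/path/to/output",
    precomputed_transformations={},
)
loader = DistHeterogeneousGraphLoader(spark, loader_config)
graph_meta_dict, timers_dict = loader.load()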
+ + Parameters + ---------- + spark : SparkSession + The SparkSession that we use to perform the processing """ def __init__( self, spark: SparkSession, - local_input_path: str, - local_output_path: str, - data_configs: Dict[str, Sequence[StructureConfig]], - input_prefix: str, - output_prefix: str, - num_output_files: Optional[int] = None, - add_reverse_edges=True, - enable_assertions=False, - graph_name: Optional[str] = None, + loader_config: HeterogeneousLoaderConfig, ): - super().__init__(local_input_path, local_output_path, data_configs) + super().__init__( + loader_config.local_input_path, + loader_config.local_output_path, + loader_config.data_configs, + ) # TODO: Pass as an argument? - if input_prefix.startswith("s3://"): + if loader_config.input_prefix.startswith("s3://"): self.filesystem_type = "s3" else: - assert os.path.isabs(input_prefix), "We expect an absolute path" + assert os.path.isabs(loader_config.input_prefix), "We expect an absolute path" self.filesystem_type = "local" self.spark = spark # type: SparkSession - self.add_reverse_edges = add_reverse_edges + self.add_reverse_edges = loader_config.add_reverse_edges # Remove trailing slash in s3 paths if self.filesystem_type == "s3": - self.input_prefix = s3_utils.s3_path_remove_trailing(input_prefix) - self.output_prefix = s3_utils.s3_path_remove_trailing(output_prefix) + self.input_prefix = s3_utils.s3_path_remove_trailing(loader_config.input_prefix) + self.output_prefix = s3_utils.s3_path_remove_trailing(loader_config.output_prefix) else: # TODO: Any checks for local paths? - self.input_prefix = input_prefix - self.output_prefix = output_prefix + self.input_prefix = loader_config.input_prefix + self.output_prefix = loader_config.output_prefix self.num_output_files = ( - num_output_files - if num_output_files and num_output_files > 0 + loader_config.num_output_files + if loader_config.num_output_files and loader_config.num_output_files > 0 else int(spark.sparkContext.defaultParallelism) ) assert self.num_output_files > 0 @@ -134,14 +158,16 @@ def __init__( # Mapping from label name to value counts self.label_properties: Dict[str, Counter] = defaultdict(Counter) self.timers = defaultdict(float) # type: Dict - self.enable_assertions = enable_assertions + self.enable_assertions = loader_config.enable_assertions # Column names that are valid in CSV can be invalid in Parquet # we collect an column name substitutions we had to make # here and later output as a JSON file. 
self.column_substitutions = {} # type: Dict[str, str] self.graph_info = {} # type: Dict[str, Any] - self.graph_name = graph_name + self.transformation_representations = {"node_features": {}, "edge_features": {}} + self.graph_name = loader_config.graph_name self.skip_train_masks = False + self.pre_computed_transformations = loader_config.precomputed_transformations def process_and_write_graph_data( self, data_configs: Mapping[str, Sequence[StructureConfig]] @@ -250,6 +276,14 @@ def process_and_write_graph_data( with open(os.path.join(self.output_path, "metadata.json"), "w", encoding="utf-8") as f: json.dump(metadata_dict, f, indent=4) + # Write the transformations file + with open( + os.path.join(self.output_path, TRANSFORMATIONS_FILENAME), "w", encoding="utf-8" + ) as f: + json.dump(self.transformation_representations, f, indent=4) + + # Column substitutions contain any col names we needed to change because their original + # name did not fit Parquet requirements if len(self.column_substitutions) > 0: with open( os.path.join(self.output_path, "column_substitutions.json"), "w", encoding="utf-8" @@ -949,7 +983,7 @@ def _process_node_features( Returns ------- Tuple[Dict, Dict] - A tuple with two dicts, both dicts have names as their keys. + A tuple with two dicts, both dicts have feature names as their keys. The first dict has the output metadata for the feature as values, and the second has the lengths of the vector representations of the features as values. @@ -957,12 +991,24 @@ def _process_node_features( node_type_feature_metadata = {} ntype_feat_sizes = {} # type: Dict[str, int] for feat_conf in feature_configs: - transformer = DistFeatureTransformer(feat_conf) + json_representation = self.transformation_representations.get(feat_conf.feat_name, {}) + transformer = DistFeatureTransformer(feat_conf, self.spark, json_representation) - transformed_feature_df = transformer.apply_transformation(nodes_df) + transformed_feature_df, json_representation = transformer.apply_transformation(nodes_df) transformed_feature_df.cache() - def write_processed_feature(feat_name, single_feature_df, node_type, transformer_name): + # Will evaluate False for empty dict + if json_representation: + self.transformation_representations["node_features"][ + feat_conf.feat_name + ] = json_representation + + def write_processed_feature( + feat_name: str, + single_feature_df: DataFrame, + node_type: str, + transformer: DistFeatureTransformer, + ): feature_output_path = os.path.join( self.output_prefix, f"node_data/{node_type}-{feat_name}" ) @@ -983,7 +1029,7 @@ def write_processed_feature(feat_name, single_feature_df, node_type, transformer nfeat_size = 1 if isinstance(feat_val, (int, float)) else len(feat_val) ntype_feat_sizes.update({feat_name: nfeat_size}) - self.timers[f"{transformer_name}-{node_type}-{feat_name}"] = ( + self.timers[f"{transformer.get_transformation_name()}-{node_type}-{feat_name}"] = ( perf_counter() - node_transformation_start ) @@ -1001,7 +1047,7 @@ def write_processed_feature(feat_name, single_feature_df, node_type, transformer bert_feat_name, single_feature_df, node_type, - transformer.get_transformation_name(), + transformer, ) else: single_feature_df = transformed_feature_df.select(feat_col).withColumnRenamed( @@ -1011,7 +1057,7 @@ def write_processed_feature(feat_name, single_feature_df, node_type, transformer feat_name, single_feature_df, node_type, - transformer.get_transformation_name(), + transformer, ) # Unpersist and move on to next feature @@ -1416,10 +1462,16 @@ def 
_process_edge_features( feat_conf.feat_name, edge_type, ) - transformer = DistFeatureTransformer(feat_conf) + json_representation = self.transformation_representations.get(feat_conf.feat_name, {}) + transformer = DistFeatureTransformer(feat_conf, self.spark, json_representation) - transformed_feature_df = transformer.apply_transformation(edges_df) + transformed_feature_df, json_representation = transformer.apply_transformation(edges_df) transformed_feature_df.cache() + # Will evaluate False for empty dict + if json_representation: + self.transformation_representations["node_features"][ + feat_conf.feat_name + ] = json_representation def write_feature(self, feat_name, single_feature_df, edge_type, transformer_name): feature_output_path = os.path.join( diff --git a/graphstorm-processing/tests/conftest.py b/graphstorm-processing/tests/conftest.py index f34f1418f2..3be8cd2f68 100644 --- a/graphstorm-processing/tests/conftest.py +++ b/graphstorm-processing/tests/conftest.py @@ -18,6 +18,7 @@ import sys import logging import tempfile +from typing import Iterator import numpy as np import pytest @@ -61,7 +62,7 @@ def temp_output_root(): @pytest.fixture(scope="session", name="spark") -def spark_fixture(): +def spark_fixture() -> Iterator[SparkSession]: """Create the main SparkContext we use throughout the tests""" spark_context = ( SparkSession.builder.master("local[4]") diff --git a/graphstorm-processing/tests/resources/small_heterogeneous_graph/gsprocessing-config.json b/graphstorm-processing/tests/resources/small_heterogeneous_graph/gsprocessing-config.json index 7967f5089b..56212b540b 100644 --- a/graphstorm-processing/tests/resources/small_heterogeneous_graph/gsprocessing-config.json +++ b/graphstorm-processing/tests/resources/small_heterogeneous_graph/gsprocessing-config.json @@ -57,16 +57,22 @@ } } }, - { + { "column": "occupation", "transformation": { "name": "huggingface", "kwargs": { "action": "tokenize_hf", "hf_model": "bert-base-uncased", - "max_seq_length":16 + "max_seq_length": 16 } } + }, + { + "column": "state", + "transformation": { + "name": "categorical" + } } ], "labels": [ diff --git a/graphstorm-processing/tests/resources/small_heterogeneous_graph/nodes/user.csv b/graphstorm-processing/tests/resources/small_heterogeneous_graph/nodes/user.csv index 5f958fa455..4bf7865c3d 100644 --- a/graphstorm-processing/tests/resources/small_heterogeneous_graph/nodes/user.csv +++ b/graphstorm-processing/tests/resources/small_heterogeneous_graph/nodes/user.csv @@ -1,6 +1,6 @@ -~id,age,occupation,gender,salary,multi -mark,30,actor,male,10000,1|2 -john,22,student,male,,3|4 -tara,33,lawyer,female,30000,5|6 -kate,29,doctor,,35000,7|8 -george,22,student,male,0,9|10 \ No newline at end of file +~id,age,occupation,gender,salary,multi,state +mark,30,actor,male,10000,1|2,wa +john,22,student,male,,3|4,ca +tara,33,lawyer,female,30000,5|6,wa +kate,29,doctor,,35000,7|8, +george,22,student,male,0,9|10,ny \ No newline at end of file diff --git a/graphstorm-processing/tests/test_dist_category_transformation.py b/graphstorm-processing/tests/test_dist_category_transformation.py index 51a5662013..7b082c7520 100644 --- a/graphstorm-processing/tests/test_dist_category_transformation.py +++ b/graphstorm-processing/tests/test_dist_category_transformation.py @@ -16,12 +16,11 @@ from typing import Tuple, Iterator import os -import pytest -import pandas as pd import tempfile import mock from numpy.testing import assert_array_equal +import pytest from pyspark.sql import SparkSession, DataFrame from pyspark.sql.types 
import StructField, StructType, StringType, ArrayType @@ -40,7 +39,7 @@ def multi_cat_df_and_separator_fixture( spark: SparkSession, separator="," ) -> Iterator[Tuple[DataFrame, str]]: - """Gneerate multi-category df, yields the DF and its separator""" + """Generate multi-category df, yields the DF and its separator""" data = [ (f"Actor{separator}Director",), (f"Director{separator}Writer",), @@ -63,9 +62,9 @@ def multi_cat_df_and_separator_fixture( ), 3, ) -def test_limited_category_transformation(user_df): +def test_limited_category_transformation(user_df: DataFrame, spark: SparkSession): """Test single-cat transformation with limited categories""" - dist_category_transformation = DistCategoryTransformation(["occupation"]) + dist_category_transformation = DistCategoryTransformation(["occupation"], spark) transformed_df = dist_category_transformation.apply(user_df) group_counts = ( @@ -77,9 +76,9 @@ def test_limited_category_transformation(user_df): assert row["count"] == expected_count -def test_all_categories_transformation(user_df, check_df_schema): +def test_all_categories_transformation(user_df, check_df_schema, spark): """Test single-cat transformation with all categories""" - dist_category_transformation = DistCategoryTransformation(["occupation"]) + dist_category_transformation = DistCategoryTransformation(["occupation"], spark) transformed_df = dist_category_transformation.apply(user_df) @@ -102,7 +101,7 @@ def test_category_transformation_with_null_values(spark: SparkSession): columns = ["name", "occupation", "gender"] df = spark.createDataFrame(data, schema=columns) - dist_category_transformation = DistCategoryTransformation(["occupation"]) + dist_category_transformation = DistCategoryTransformation(["occupation"], spark) transformed_df = dist_category_transformation.apply(df) @@ -111,9 +110,9 @@ def test_category_transformation_with_null_values(spark: SparkSession): assert transformed_distinct_values == 4 -def test_multiple_categories_transformation(user_df): - """Test transforming multiple cat columns""" - dist_category_transformation = DistCategoryTransformation(["occupation", "gender"]) +def test_multiple_categories_transformation(user_df, spark): + """Test transforming multiple single-cat columns""" + dist_category_transformation = DistCategoryTransformation(["occupation", "gender"], spark) transformed_df = dist_category_transformation.apply(user_df) @@ -124,8 +123,26 @@ def test_multiple_categories_transformation(user_df): assert gender_distinct_values == 3 +def test_multiple_single_cat_cols_json(user_df, spark): + """Test JSON representation when transforming multiple single-cat columns""" + dist_category_transformation = DistCategoryTransformation(["occupation", "gender"], spark) + + _ = dist_category_transformation.apply(user_df) + + multi_cols_rep = dist_category_transformation.get_json_representation() + + labels_array = multi_cols_rep["string_indexer_labels_array"] + one_hot_index_for_string = multi_cols_rep["per_col_label_to_one_hot_idx"] + cols = multi_cols_rep["cols"] + + # The Spark-generated and our own one-hot-index mappings should match + for col_labels, col in zip(labels_array, cols): + for idx, label in enumerate(col_labels): + assert idx == one_hot_index_for_string[col][label] + + def test_multi_category_transformation(multi_cat_df_and_separator, check_df_schema): - """Test transforming multi-category column""" + """Test transforming single multi-category column""" df, separator = multi_cat_df_and_separator col_name = df.columns[0] @@ -180,9 +197,10 @@ def 
test_multi_category_limited_categories(multi_cat_df_and_separator): def test_csv_input_categorical(spark: SparkSession, check_df_schema): + """Test categorical transformations with CSV input, as we treat them separately""" data_path = os.path.join(_ROOT, "resources/multi_num_numerical/multi_num.csv") long_vector_df = spark.read.csv(data_path, sep=",", header=True) - dist_categorical_transormation = DistCategoryTransformation(cols=["id"]) + dist_categorical_transormation = DistCategoryTransformation(cols=["id"], spark=spark) transformed_df = dist_categorical_transormation.apply(long_vector_df) check_df_schema(transformed_df) @@ -199,6 +217,7 @@ def test_csv_input_categorical(spark: SparkSession, check_df_schema): def test_csv_input_multi_categorical(spark: SparkSession, check_df_schema): + """Test mulit-categorical transformations with CSV input""" data_path = os.path.join(_ROOT, "resources/multi_num_numerical/multi_num.csv") long_vector_df = spark.read.csv(data_path, sep=",", header=True) dist_categorical_transormation = DistMultiCategoryTransformation(cols=["feat"], separator=";") @@ -207,13 +226,14 @@ def test_csv_input_multi_categorical(spark: SparkSession, check_df_schema): check_df_schema(transformed_df) transformed_rows = transformed_df.collect() expected_rows = [] - for i in range(5): + for _ in range(5): expected_rows.append([1] * 100) for row, expected_row in zip(transformed_rows, expected_rows): assert row["feat"] == expected_row def test_parquet_input_multi_categorical(spark: SparkSession, check_df_schema): + """Test multi-categorical transformations with Parquet input""" # Define the schema for the DataFrame schema = StructType([StructField("names", ArrayType(StringType()), True)]) @@ -240,7 +260,7 @@ def test_parquet_input_multi_categorical(spark: SparkSession, check_df_schema): # Show the DataFrame loaded from the Parquet file dist_categorical_transormation = DistMultiCategoryTransformation( - cols=["names"], separator=None + cols=["names"], separator="" ) transformed_df = dist_categorical_transormation.apply(df_parquet) diff --git a/graphstorm-processing/tests/test_dist_heterogenous_loader.py b/graphstorm-processing/tests/test_dist_heterogenous_loader.py index ae285d42fa..5cfa4988c4 100644 --- a/graphstorm-processing/tests/test_dist_heterogenous_loader.py +++ b/graphstorm-processing/tests/test_dist_heterogenous_loader.py @@ -28,6 +28,7 @@ from graphstorm_processing.graph_loaders.dist_heterogeneous_loader import ( DistHeterogeneousGraphLoader, + HeterogeneousLoaderConfig, NODE_MAPPING_INT, NODE_MAPPING_STR, ) @@ -46,6 +47,7 @@ MIN_VALUE, MAX_VALUE, VALUE_COUNTS, + TRANSFORMATIONS_FILENAME, ) pytestmark = pytest.mark.usefixtures("spark") @@ -99,16 +101,21 @@ def no_label_data_configs_fixture(): def dghl_loader_fixture(spark, data_configs_with_label, tempdir) -> DistHeterogeneousGraphLoader: """Create a re-usable loader that includes labels""" input_path = os.path.join(_ROOT, "resources/small_heterogeneous_graph") - dhgl = DistHeterogeneousGraphLoader( - spark, + loader_config = HeterogeneousLoaderConfig( + add_reverse_edges=True, + data_configs=data_configs_with_label, + enable_assertions=True, + graph_name="small_heterogeneous_graph", + input_prefix=input_path, local_input_path=input_path, local_output_path=tempdir, - output_prefix=tempdir, - input_prefix=input_path, - data_configs=data_configs_with_label, num_output_files=1, - add_reverse_edges=True, - enable_assertions=True, + output_prefix=tempdir, + precomputed_transformations={}, + ) + dhgl = 
DistHeterogeneousGraphLoader( + spark, + loader_config=loader_config, ) return dhgl @@ -119,16 +126,21 @@ def dghl_loader_no_label_fixture( ) -> DistHeterogeneousGraphLoader: """Create a re-usable loader without labels""" input_path = os.path.join(_ROOT, "resources/small_heterogeneous_graph") - dhgl = DistHeterogeneousGraphLoader( - spark, + loader_config = HeterogeneousLoaderConfig( + add_reverse_edges=True, + data_configs=no_label_data_configs, + enable_assertions=True, + graph_name="small_heterogeneous_graph", + input_prefix=input_path, local_input_path=input_path, local_output_path=tempdir, - output_prefix=tempdir, - input_prefix=input_path, - data_configs=no_label_data_configs, num_output_files=1, - add_reverse_edges=True, - enable_assertions=True, + output_prefix=tempdir, + precomputed_transformations={}, + ) + dhgl = DistHeterogeneousGraphLoader( + spark, + loader_config, ) return dhgl @@ -139,16 +151,21 @@ def dghl_loader_no_reverse_edges_fixture( ) -> DistHeterogeneousGraphLoader: """Create a re-usable loader that doesn't produce reverse edegs""" input_path = os.path.join(_ROOT, "resources/small_heterogeneous_graph") - dhgl = DistHeterogeneousGraphLoader( - spark, + loader_config = HeterogeneousLoaderConfig( + add_reverse_edges=False, + data_configs=data_configs_with_label, + enable_assertions=True, + graph_name="small_heterogeneous_graph", + input_prefix=input_path, local_input_path=input_path, local_output_path=tempdir, - output_prefix=tempdir, - input_prefix=input_path, - data_configs=data_configs_with_label, num_output_files=1, - add_reverse_edges=False, - enable_assertions=True, + output_prefix=tempdir, + precomputed_transformations={}, + ) + dhgl = DistHeterogeneousGraphLoader( + spark, + loader_config, ) return dhgl @@ -246,6 +263,7 @@ def test_load_dist_heterogen_node_class(dghl_loader: DistHeterogeneousGraphLoade "input_ids": 16, "token_type_ids": 16, "multi": 2, + "state": 3, } }, "efeat_size": {}, @@ -273,6 +291,7 @@ def test_load_dist_heterogen_node_class(dghl_loader: DistHeterogeneousGraphLoade "test_mask", "age", "multi", + "state", "input_ids", "attention_mask", "token_type_ids", @@ -282,6 +301,13 @@ def test_load_dist_heterogen_node_class(dghl_loader: DistHeterogeneousGraphLoade for node_type in metadata["node_data"]: assert metadata["node_data"][node_type].keys() == expected_node_data[node_type] + with open( + os.path.join(dghl_loader.output_path, TRANSFORMATIONS_FILENAME), "r", encoding="utf-8" + ) as transformation_file: + transformations_dict = json.load(transformation_file) + + assert "state" in transformations_dict["node_features"] + def test_load_dist_hgl_without_labels( dghl_loader_no_label: DistHeterogeneousGraphLoader,