diff --git a/graphstorm-processing/graphstorm_processing/constants.py b/graphstorm-processing/graphstorm_processing/constants.py index 13b7af65ce..976ab57818 100644 --- a/graphstorm-processing/graphstorm_processing/constants.py +++ b/graphstorm-processing/graphstorm_processing/constants.py @@ -85,3 +85,6 @@ class FilesystemType(Enum): "3.4": "3.3.4", "3.3": "3.3.2", } + +########## Precomputed transformations ################ +TRANSFORMATIONS_FILENAME = "precomputed_transformations.json" diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py index 5527960b6a..9ef2fd0a33 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py @@ -16,7 +16,7 @@ import logging -from pyspark.sql import DataFrame +from pyspark.sql import DataFrame, SparkSession from graphstorm_processing.config.feature_config_base import FeatureConfig from .dist_transformations import ( @@ -37,13 +37,19 @@ class DistFeatureTransformer(object): which can then be be applied through a call to apply_transformation. """ - def __init__(self, feature_config: FeatureConfig): + def __init__( + self, feature_config: FeatureConfig, spark: SparkSession, json_representation: dict + ): feat_type = feature_config.feat_type feat_name = feature_config.feat_name args_dict = feature_config.transformation_kwargs self.transformation: DistributedTransformation + # TODO: We will use this to re-apply transformations + self.json_representation = json_representation - default_kwargs = {"cols": feature_config.cols} + default_kwargs = { + "cols": feature_config.cols, + } logging.info("Feature name: %s", feat_name) logging.info("Transformation type: %s", feat_type) @@ -56,7 +62,9 @@ def __init__(self, feature_config: FeatureConfig): elif feat_type == "bucket-numerical": self.transformation = DistBucketNumericalTransformation(**default_kwargs, **args_dict) elif feat_type == "categorical": - self.transformation = DistCategoryTransformation(**default_kwargs, **args_dict) + self.transformation = DistCategoryTransformation( + **default_kwargs, **args_dict, spark=spark + ) elif feat_type == "multi-categorical": self.transformation = DistMultiCategoryTransformation(**default_kwargs, **args_dict) elif feat_type == "huggingface": @@ -66,14 +74,24 @@ def __init__(self, feature_config: FeatureConfig): f"Feature {feat_name} has type: {feat_type} that is not supported" ) - def apply_transformation(self, input_df: DataFrame) -> DataFrame: + def apply_transformation(self, input_df: DataFrame) -> tuple[DataFrame, dict]: """ - Given an input dataframe, select only the relevant columns + Given an input DataFrame, select only the relevant columns and apply the expected transformation to them. + + Returns + ------- + tuple[DataFrame, dict] + A tuple with two items, the first is the transformed input DataFrame, + the second is a JSON representation of the transformation. This will + allow us to apply the same transformation to new data. 
""" input_df = input_df.select(self.transformation.cols) # type: ignore - return self.transformation.apply(input_df) + return ( + self.transformation.apply(input_df), + self.transformation.get_json_representation(), + ) def get_transformation_name(self) -> str: """ diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/base_dist_transformation.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/base_dist_transformation.py index 21b31e1154..221bdce0f7 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/base_dist_transformation.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/base_dist_transformation.py @@ -15,9 +15,9 @@ """ from abc import ABC, abstractmethod -from typing import Sequence +from typing import Optional, Sequence -from pyspark.sql import DataFrame +from pyspark.sql import DataFrame, SparkSession class DistributedTransformation(ABC): @@ -25,19 +25,34 @@ class DistributedTransformation(ABC): Base class for all distributed transformations. """ - def __init__(self, cols: Sequence[str]) -> None: + def __init__( + self, + cols: Sequence[str], + spark: Optional[SparkSession] = None, + json_representation: Optional[dict] = None, + ) -> None: self.cols = cols + self.spark = spark + self.json_representation = json_representation @abstractmethod def apply(self, input_df: DataFrame) -> DataFrame: """ - Applies the transformation to the input DataFrame. - The returned dataframe will only contain the columns specified during initialization. + Applies the transformation to the input DataFrame, and returns the modified + DataFrame. + + The returned DataFrame will only contain the columns specified during initialization. """ + def get_json_representation(self) -> dict: + """Get a JSON representation of the transformation.""" + # TODO: Should we try to guarantee apply() has ran before this? + if self.json_representation: + return self.json_representation + else: + return {} + @staticmethod @abstractmethod def get_transformation_name() -> str: - """ - Get the name of the transformation - """ + """Get the name of the transformation.""" diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py index b03a0c9333..8dfd5ac49f 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py @@ -14,11 +14,12 @@ limitations under the License. """ -from typing import Dict, List, Optional, Sequence +from typing import List, Optional, Sequence import numpy as np +import pandas as pd -from pyspark.sql import DataFrame, functions as F +from pyspark.sql import DataFrame, functions as F, SparkSession from pyspark.sql.functions import when from pyspark.sql.types import ArrayType, FloatType, StringType from pyspark.ml.feature import StringIndexer, OneHotEncoder @@ -40,8 +41,8 @@ class DistCategoryTransformation(DistributedTransformation): Transforms categorical features into a vector of one-hot-encoded values. 
""" - def __init__(self, cols: List[str]) -> None: - super().__init__(cols) + def __init__(self, cols: list[str], spark: SparkSession) -> None: + super().__init__(cols, spark) @staticmethod def get_transformation_name() -> str: @@ -49,9 +50,10 @@ def get_transformation_name() -> str: def apply(self, input_df: DataFrame) -> DataFrame: processed_col_names = [] - for col in self.cols: - processed_col_names.append(col + "_processed") - distinct_category_counts = input_df.groupBy(col).count() # type: DataFrame + top_categories_per_col: dict[str, list] = {} + for current_col in self.cols: + processed_col_names.append(current_col + "_processed") + distinct_category_counts = input_df.groupBy(current_col).count() # type: DataFrame num_distinct_categories = distinct_category_counts.count() # Conditionally replace rare categories with single placeholder @@ -60,17 +62,23 @@ def apply(self, input_df: DataFrame) -> DataFrame: MAX_CATEGORIES_PER_FEATURE - 1 ) top_categories_set = {row[0] for row in top_categories} + top_categories_per_col[current_col] = list(top_categories_set) # TODO: Ideally we don't want to use withColumn in a loop input_df = input_df.withColumn( - col, - when(input_df[col].isin(top_categories_set), input_df[col]).otherwise( - RARE_CATEGORY - ), + current_col, + when( + input_df[current_col].isin(top_categories_set), input_df[current_col] + ).otherwise(RARE_CATEGORY), ) + else: + top_categories_per_col[current_col] = [ + x[current_col] for x in distinct_category_counts.select(current_col).collect() + ] # Replace empty string cols with None input_df = input_df.withColumn( - col, when(input_df[col] == "", None).otherwise(input_df[col]) + current_col, + when(input_df[current_col] == "", None).otherwise(input_df[current_col]), ) # We first convert the strings to float indexes @@ -105,8 +113,80 @@ def apply(self, input_df: DataFrame) -> DataFrame: ] ) + # Structure: {column_name: {category_string: index_value, ...}. ...} + per_col_label_to_one_hot_idx: dict[str, dict[str, int]] = {} + + # To get the transformed values for each value in each col + # we need to create a DataFrame with the top categories for the current + # col, then fill in the rest of the values with placeholders + # and pass the generated DF through the one-hot encoder + for current_col, processed_col in zip(self.cols, processed_col_names): + other_cols = [x for x in self.cols if x != current_col] + top_str_categories_list = top_categories_per_col[current_col] + # Spark doesn't include missing/unknown values in the vector + # representation, just uses the all-zeroes vector for them, + # so we remove instances of None from the list of strings to model + if None in top_str_categories_list: + top_str_categories_list.remove(None) + + # Each col might have different number of top categories, we need one DF per col + num_current_col_cats = len(top_str_categories_list) + # We don't care about values for the other cols in this iteration, + # just fill with empty string + placeholder_vals = [""] * num_current_col_cats + placeholder_cols = [placeholder_vals for _ in range(len(self.cols) - 1)] + current_col_unique_vals = [list(top_str_categories_list)] + # We need to create a DF where all cols have num_rows == num_current_col_cats + # and the current col needs to be the first col in the DF. 
+ vals_dict = dict( + zip([current_col] + other_cols, current_col_unique_vals + placeholder_cols) + ) + + # One hot encoder expects a DF with all cols that were used to train it + # so we use the top-MAX_CATEGORIES_PER_FEATURE values for the current col, + # and the placeholders for the rest + top_str_categories_df = self.spark.createDataFrame(pd.DataFrame(vals_dict)) + top_indexed_categories_df = str_indexer_model.transform(top_str_categories_df) + + # For the current col, get the one-hot index for each of its category strings + # by passing the top-k values DF through the one-hot encoder model + per_col_label_to_one_hot_idx[current_col] = { + x[current_col]: int(x[processed_col]) + for x in one_hot_encoder_model.transform(top_indexed_categories_df).collect() + } + + # see get_json_representation() docstring for structure + self.json_representation = { + "string_indexer_labels_array": str_indexer_model.labelsArray, + "cols": self.cols, + "per_col_label_to_one_hot_idx": per_col_label_to_one_hot_idx, + "transformation_name": self.get_transformation_name(), + } + return dense_vector_features + def get_json_representation(self) -> dict: + """Representation of the single-category transformation for one or more columns. + + Returns + ------- + dict + Structure: + string_indexer_labels_array: + tuple[tuple[str]], outer tuple has num_cols elements, + each inner tuple has num_cats elements, each str is a category string. + Spark uses this to represent the one-hot index for each category, its + position in the inner tuple is the one-hot-index position for the string. + Categories are sorted by their frequency in the data. + cols: + list[str], with num_cols elements + per_col_label_to_one_hot_idx: + dict[str, dict[str, int]], with num_cols elements, each with num_categories elements + transformation_name: + str, will be 'DistCategoryTransformation' + """ + return self.json_representation + class DistMultiCategoryTransformation(DistributedTransformation): """ @@ -135,7 +215,7 @@ def __init__(self, cols: Sequence[str], separator: str) -> None: if self.separator in SPECIAL_CHARACTERS: self.separator = f"\\{self.separator}" - self.value_map = {} # type: Dict[str, int] + self.value_map: dict[str, int] = {} @staticmethod def get_transformation_name() -> str: diff --git a/graphstorm-processing/graphstorm_processing/distributed_executor.py b/graphstorm-processing/graphstorm_processing/distributed_executor.py index d97b9ff4b0..ffd2ee92b1 100644 --- a/graphstorm-processing/graphstorm_processing/distributed_executor.py +++ b/graphstorm-processing/graphstorm_processing/distributed_executor.py @@ -48,24 +48,29 @@ the Spark leader. 
""" -import dataclasses import argparse +import copy +import dataclasses import json import logging import os from pathlib import Path +import tempfile import time +from collections.abc import Mapping from typing import Any, Dict -import tempfile import boto3 import botocore from graphstorm_processing.graph_loaders.dist_heterogeneous_loader import ( DistHeterogeneousGraphLoader, + HeterogeneousLoaderConfig, + ProcessedGraphRepresentation, ) from graphstorm_processing.config.config_parser import create_config_objects from graphstorm_processing.config.config_conversion import GConstructConfigConverter +from graphstorm_processing.constants import TRANSFORMATIONS_FILENAME from graphstorm_processing.data_transformations import spark_utils, s3_utils from graphstorm_processing.repartition_files import ( repartition_files, @@ -83,8 +88,8 @@ class ExecutorConfig: Parameters ---------- local_config_path : str - Local path to the config file. - local_output_path : str + Local path to the input config file. + local_metadata_output_path : str Local path for output metadata files. input_prefix : str Prefix for input data. Can be S3 URI or local path. @@ -107,7 +112,7 @@ class ExecutorConfig: """ local_config_path: str - local_output_path: str + local_metadata_output_path: str input_prefix: str output_prefix: str num_output_files: int @@ -148,7 +153,7 @@ def __init__( executor_config: ExecutorConfig, ): self.local_config_path = executor_config.local_config_path - self.local_output_path = executor_config.local_output_path + self.local_metadata_output_path = executor_config.local_metadata_output_path self.input_prefix = executor_config.input_prefix self.output_prefix = executor_config.output_prefix self.num_output_files = executor_config.num_output_files @@ -158,6 +163,8 @@ def __init__( self.add_reverse_edges = executor_config.add_reverse_edges self.graph_name = executor_config.graph_name self.repartition_on_leader = executor_config.do_repartition + # Input config dict using GSProcessing schema + self.gsp_config_dict = {} # Ensure we have write access to the output path if self.filesystem_type == FilesystemType.LOCAL: @@ -190,23 +197,33 @@ def __init__( with open(graph_conf, "r", encoding="utf-8") as f: dataset_config_dict: Dict[str, Any] = json.load(f) - # Use appropriate config parser depending on file version + # Load the pre-computed transformations if the file exists + if os.path.exists(os.path.join(self.local_config_path, TRANSFORMATIONS_FILENAME)): + with open( + os.path.join(self.local_config_path, TRANSFORMATIONS_FILENAME), + "r", + encoding="utf-8", + ) as f: + self.precomputed_transformations = json.load(f) + else: + self.precomputed_transformations = {} + if "version" in dataset_config_dict: config_version = dataset_config_dict["version"] if config_version == "gsprocessing-v1.0": logging.info("Parsing config file as GSProcessing config") - self.graph_config_dict = dataset_config_dict["graph"] + self.gsp_config_dict = dataset_config_dict["graph"] elif config_version == "gconstruct-v1.0": logging.info("Parsing config file as GConstruct config") converter = GConstructConfigConverter() - self.graph_config_dict = converter.convert_to_gsprocessing(dataset_config_dict)[ + self.gsp_config_dict = converter.convert_to_gsprocessing(dataset_config_dict)[ "graph" ] else: logging.warning("Unrecognized configuration file version name: %s", config_version) try: converter = GConstructConfigConverter() - self.graph_config_dict = converter.convert_to_gsprocessing(dataset_config_dict)[ + self.gsp_config_dict = 
converter.convert_to_gsprocessing(dataset_config_dict)[ "graph" ] except Exception: # pylint: disable=broad-exception-caught @@ -214,13 +231,13 @@ def __init__( assert ( "graph" in dataset_config_dict ), "Top-level element 'graph' needs to exist in a GSProcessing config" - self.graph_config_dict = dataset_config_dict["graph"] + self.gsp_config_dict = dataset_config_dict["graph"] logging.info("Parsed config file as GSProcessing config") else: # Older versions of GConstruct configs might be missing a version entry logging.warning("No configuration file version name, trying to parse as GConstruct...") converter = GConstructConfigConverter() - self.graph_config_dict = converter.convert_to_gsprocessing(dataset_config_dict)["graph"] + self.gsp_config_dict = converter.convert_to_gsprocessing(dataset_config_dict)["graph"] # Create the Spark session for execution self.spark = spark_utils.create_spark_session(self.execution_env, self.filesystem_type) @@ -256,23 +273,28 @@ def run(self) -> None: Executes the Spark processing job. """ logging.info("Performing data processing with PySpark...") - data_configs = create_config_objects(self.graph_config_dict) + data_configs = create_config_objects(self.gsp_config_dict) t0 = time.time() # Prefer explicit arguments for clarity - loader = DistHeterogeneousGraphLoader( - spark=self.spark, - local_input_path=self.local_config_path, - local_output_path=self.local_output_path, - data_configs=data_configs, - input_prefix=self.input_prefix, - output_prefix=self.output_prefix, - num_output_files=self.num_output_files, + loader_config = HeterogeneousLoaderConfig( add_reverse_edges=self.add_reverse_edges, + data_configs=data_configs, enable_assertions=False, graph_name=self.graph_name, + input_prefix=self.input_prefix, + local_input_path=self.local_config_path, + local_metadata_output_path=self.local_metadata_output_path, + num_output_files=self.num_output_files, + output_prefix=self.output_prefix, + precomputed_transformations=self.precomputed_transformations, + ) + loader = DistHeterogeneousGraphLoader( + self.spark, + loader_config, ) - graph_meta_dict, timers_dict = loader.load() + processed_representations: ProcessedGraphRepresentation = loader.load() + graph_meta_dict = processed_representations.processed_graph_metadata_dict t1 = time.time() logging.info("Time to transform data for distributed partitioning: %s sec", t1 - t0) @@ -316,7 +338,8 @@ def run(self) -> None: ) else: logging.warning("gs-repartition will need to run as a follow-up job on the data!") - timers_dict["repartition"] = time.perf_counter() - repartition_start + + processed_representations.timers["repartition"] = time.perf_counter() - repartition_start # If any of the metadata modification took place, write an updated metadata file if updated_metadata: @@ -334,18 +357,142 @@ def run(self) -> None: ) with open( - os.path.join(self.local_output_path, "perf_counters.json"), "w", encoding="utf-8" + os.path.join(self.local_metadata_output_path, "perf_counters.json"), + "w", + encoding="utf-8", ) as f: - sorted_timers = dict(sorted(timers_dict.items(), key=lambda x: x[1], reverse=True)) + sorted_timers = dict( + sorted(processed_representations.timers.items(), key=lambda x: x[1], reverse=True) + ) json.dump(sorted_timers, f, indent=4) - # This is used to upload the output JSON files to S3 on non-SageMaker runs, + # If pre-computed representations exist, merge them with the input dict and save to disk + with open( + os.path.join( + self.local_metadata_output_path, + 
f"{os.path.splitext(self.config_filename)[0]}_with_transformations.json", + ), + "w", + encoding="utf-8", + ) as f: + input_dict_with_transforms = self._merge_config_with_transformations( + self.gsp_config_dict, processed_representations.transformation_representations + ) + json.dump(input_dict_with_transforms, f, indent=4) + + # This is used to upload the output output JSON files to S3 on non-SageMaker runs, # since we can't rely on SageMaker to do it if self.filesystem_type == FilesystemType.S3: self._upload_output_files( - loader, force=not self.execution_env == ExecutionEnv.SAGEMAKER + loader, force=(not self.execution_env == ExecutionEnv.SAGEMAKER) ) + def _merge_config_with_transformations( + self, + gsp_config_dict: dict, + transformations: Mapping[str, Mapping[str, Mapping]], + ) -> dict: + """Merge the config dict with the transformations dict and return a copy. + + Parameters + ---------- + gsp_config_dict : dict + The input configuration dictionary, using GSProcessing schema + transformations : Mapping[str, Mapping[str, Mapping]] + The processed graph representations containing the transformations. + Expected dict schema: + + { + "node_features": { + "node_type1": { + "feature_name1": { + "transformation": # transformation type + # feature1 representation goes here + }, + "feature_name2": {}, ... + }, + "node_type2": {...} + }, + "edges_features": {...} + } + + Returns + ------- + dict + A copy of the ``gsp_config_dict`` modified to be merged with + the transformation representations. + """ + gsp_config_dict_copy = copy.deepcopy(gsp_config_dict) + + edge_transformations: Mapping[str, Mapping[str, Mapping]] = transformations["edge_features"] + node_transformations: Mapping[str, Mapping[str, Mapping]] = transformations["node_features"] + edge_input_dicts: list[dict] = gsp_config_dict_copy["edges"] + node_input_dicts: list[dict] = gsp_config_dict_copy["nodes"] + + def get_structure_type(single_input_dict: dict, structure_type: str) -> str: + """Gets the node or edge type name from input dict.""" + if structure_type == "node": + type_name = single_input_dict["type"] + elif structure_type == "edge": + src_type = single_input_dict["source"]["type"] + dst_type = single_input_dict["dest"]["type"] + relation = single_input_dict["relation"]["type"] + type_name = f"{src_type}:{relation}:{dst_type}" + else: + raise ValueError( + f"Invalid structure type: {structure_type}. Expected 'node' or 'edge'." + ) + + return type_name + + def append_transformations( + structure_input_dicts: list[dict], + structure_transforms: Mapping[str, Mapping[str, Mapping]], + structure_type: str, + ): + """Appends the pre-computed transformations to the input dicts.""" + assert structure_type in ["edge", "node"] + for input_dict in structure_input_dicts: + # type_name is the name of either a node type or edge type + type_name = get_structure_type(input_dict, structure_type) + # If we have pre-computed transformations for this type + if type_name in structure_transforms: + # type_transforms holds the transformation representations for + # every feature that has one for type_name, from feature name to + # feature representation dict. 
+ type_transforms: Mapping[str, Mapping] = structure_transforms[type_name] + assert ( + "features" in input_dict + ), f"Expected type {type_name} to have have features in the input config" + + # Iterate over every feature for the node/edge type, + # and append representation to its input dict, if one exists + for type_feat_dict in input_dict["features"]: + # We take a feature's name either explicitly if it exists, + # or from the column name otherwise. + feat_name = ( + type_feat_dict["name"] + if "name" in type_feat_dict + else type_feat_dict["column"] + ) + if feat_name in type_transforms: + # Feature representation needs to contain all the + # necessary information to re-apply the feature transformation + feature_representation = type_transforms[feat_name] + type_feat_dict["precomputed_transformation"] = feature_representation + + if edge_transformations: + append_transformations(edge_input_dicts, edge_transformations, "edge") + if node_transformations: + append_transformations(node_input_dicts, node_transformations, "node") + + gsp_top_level_dict = { + "graph": gsp_config_dict_copy, + "version": "gsprocessing-v1.0", + } + + return gsp_top_level_dict + def parse_args() -> argparse.Namespace: """Parse the arguments for the execution.""" @@ -474,29 +621,40 @@ def main(): f"s3://{input_bucket}/{input_s3_prefix}/" f"{gsprocessing_args.config_filename}" ) from e + # Try to download the pre-computed transformations file, if it exists + try: + s3.download_file( + input_bucket, + f"{input_s3_prefix}/{TRANSFORMATIONS_FILENAME}", + os.path.join(tempdir.name, TRANSFORMATIONS_FILENAME), + ) + except botocore.exceptions.ClientError as _: + pass local_config_path = tempdir.name # local output location for metadata files if execution_env == ExecutionEnv.SAGEMAKER: - local_output_path = "/opt/ml/processing/output" + local_metadata_output_path = "/opt/ml/processing/output" else: if filesystem_type == FilesystemType.LOCAL: - local_output_path = gsprocessing_args.output_prefix + local_metadata_output_path = gsprocessing_args.output_prefix else: # Only needed for local execution with S3 data - local_output_path = tempdir.name + local_metadata_output_path = tempdir.name if not gsprocessing_args.num_output_files: gsprocessing_args.num_output_files = -1 # Save arguments to file for posterity - with open(os.path.join(local_output_path, "launch_arguments.json"), "w", encoding="utf-8") as f: + with open( + os.path.join(local_metadata_output_path, "launch_arguments.json"), "w", encoding="utf-8" + ) as f: json.dump(dataclasses.asdict(gsprocessing_args), f, indent=4) f.flush() executor_configuration = ExecutorConfig( local_config_path=local_config_path, - local_output_path=local_output_path, + local_metadata_output_path=local_metadata_output_path, input_prefix=gsprocessing_args.input_prefix, output_prefix=gsprocessing_args.output_prefix, num_output_files=gsprocessing_args.num_output_files, diff --git a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py index 79fc14d454..e0a1832c22 100644 --- a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py +++ b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py @@ -20,8 +20,10 @@ import numbers import os from collections import Counter, defaultdict +from collections.abc import Mapping, Sequence +from dataclasses import dataclass from time import perf_counter -from typing import Any, 
Dict, Mapping, Optional, Sequence, Set, Tuple +from typing import Any, Dict, Optional, Set, Tuple from pyspark import RDD from pyspark.sql import Row, SparkSession, DataFrame, functions as F @@ -45,6 +47,7 @@ SPECIAL_CHARACTERS, HUGGINGFACE_TRANFORM, HUGGINGFACE_TOKENIZE, + TRANSFORMATIONS_FILENAME, ) from ..config.config_parser import EdgeConfig, NodeConfig, StructureConfig from ..config.label_config_base import LabelConfig @@ -52,7 +55,6 @@ from ..data_transformations.dist_feature_transformer import DistFeatureTransformer from ..data_transformations.dist_label_loader import DistLabelLoader, SplitRates, CustomSplit from ..data_transformations import s3_utils, spark_utils -from .heterogeneous_graphloader import HeterogeneousGraphLoader # TODO: Remove the pylint disable once we add the rest of the code from . import schema_utils # pylint: disable=no-name-in-module @@ -64,88 +66,155 @@ NODE_MAPPING_INT = "new" -class DistHeterogeneousGraphLoader(HeterogeneousGraphLoader): +@dataclass +class HeterogeneousLoaderConfig: """ - A graph loader designed to run distributed processing of a heterogeneous graph. + Configuration object for the loader. - Parameters - ---------- - spark : SparkSession - The SparkSession that we use to perform the processing - local_input_path : str - Local path to input configuration data - local_metadata_path : str - Local path to where the output metadata files will be created. + add_reverse_edges : bool + Whether to add reverse edges to the graph. data_configs : Dict[str, Sequence[StructureConfig]] Dictionary of node and edge configurations objects. + enable_assertions : bool, optional + When true enables sanity checks for the output created. + However these are costly to compute, so we disable them by default. + graph_name: str + Name of the graph. input_prefix : str The prefix to the input data. Can be an S3 URI or an **absolute** local path. + local_input_path : str + Local path to input configuration data + local_metadata_output_path : str + Local path to where the output metadata files will be created. + num_output_files : int + The number of files (partitions) to create for the output, if None we + let Spark decide. output_prefix : str The prefix to where the output data will be created. Can be an S3 URI or an **absolute** local path. - num_output_files : Optional[int], optional - The number of files (partitions) to create for the output, if None we - let Spark decide. - enable_assertions : bool, optional - When true enables sanity checks for the output created. - However these are costly to compute, so we disable them by default. + precomputed_transformations: dict + A dictionary describing precomputed transformations for the features + of the graph. + """ + + add_reverse_edges: bool + data_configs: Mapping[str, Sequence[StructureConfig]] + enable_assertions: bool + graph_name: str + input_prefix: str + local_input_path: str + local_metadata_output_path: str + num_output_files: int + output_prefix: str + precomputed_transformations: dict + + +@dataclass +class ProcessedGraphRepresentation: + """JSON representations of a processed graph. + + Parameters + ---------- + processed_graph_metadata_dict : dict + A dictionary of metadata for the graph, in "chunked-graph" + format, with additional key "graph_info" that contains a more + verbose representation of th processed graph. + + The dict also contains a "raw_id_mappings" key, which is a dict + of dicts, one for each node type. 
Each entry contains files information + about the raw-to-integer ID mapping for each node. + + The returned value also contains an additional dict of dicts, + "graph_info" which contains additional information about the + graph in a more readable format. + + For chunked graph format see + https://docs.dgl.ai/guide/distributed-preprocessing.html#specification + transformation_representations : dict + A dictionary containing the processed graph transformations. + timers : dict + A dictionary containing the timing information for different steps of processing. + """ + + processed_graph_metadata_dict: dict + transformation_representations: dict + timers: dict + + +class DistHeterogeneousGraphLoader(object): + """ + A graph loader designed to run distributed processing of a heterogeneous graph. + + Parameters + ---------- + spark : SparkSession + The SparkSession that we use to perform the processing """ def __init__( self, spark: SparkSession, - local_input_path: str, - local_output_path: str, - data_configs: Dict[str, Sequence[StructureConfig]], - input_prefix: str, - output_prefix: str, - num_output_files: Optional[int] = None, - add_reverse_edges=True, - enable_assertions=False, - graph_name: Optional[str] = None, + loader_config: HeterogeneousLoaderConfig, ): - super().__init__(local_input_path, local_output_path, data_configs) + self.output_path = loader_config.local_metadata_output_path + self._data_configs = loader_config.data_configs + self.feature_configs: list[FeatureConfig] = [] # TODO: Pass as an argument? - if input_prefix.startswith("s3://"): + if loader_config.input_prefix.startswith("s3://"): self.filesystem_type = "s3" else: - assert os.path.isabs(input_prefix), "We expect an absolute path" + assert os.path.isabs(loader_config.input_prefix), "We expect an absolute path" self.filesystem_type = "local" self.spark = spark # type: SparkSession - self.add_reverse_edges = add_reverse_edges + self.add_reverse_edges = loader_config.add_reverse_edges # Remove trailing slash in s3 paths if self.filesystem_type == "s3": - self.input_prefix = s3_utils.s3_path_remove_trailing(input_prefix) - self.output_prefix = s3_utils.s3_path_remove_trailing(output_prefix) + self.input_prefix = s3_utils.s3_path_remove_trailing(loader_config.input_prefix) + self.output_prefix = s3_utils.s3_path_remove_trailing(loader_config.output_prefix) else: # TODO: Any checks for local paths? - self.input_prefix = input_prefix - self.output_prefix = output_prefix + self.input_prefix = loader_config.input_prefix + self.output_prefix = loader_config.output_prefix self.num_output_files = ( - num_output_files - if num_output_files and num_output_files > 0 + loader_config.num_output_files + if loader_config.num_output_files and loader_config.num_output_files > 0 else int(spark.sparkContext.defaultParallelism) ) assert self.num_output_files > 0 # Mapping from node type to filepath, each file is a node-str to node-int-id mapping - self.node_mapping_paths = {} # type: Dict[str, Sequence[str]] + self.node_mapping_paths: dict[str, Sequence[str]] = {} # Mapping from label name to value counts self.label_properties: Dict[str, Counter] = defaultdict(Counter) - self.timers = defaultdict(float) # type: Dict - self.enable_assertions = enable_assertions + self.timers = defaultdict(float) + self.enable_assertions = loader_config.enable_assertions # Column names that are valid in CSV can be invalid in Parquet # we collect an column name substitutions we had to make # here and later output as a JSON file. 
- self.column_substitutions = {} # type: Dict[str, str] - self.graph_info = {} # type: Dict[str, Any] - self.graph_name = graph_name + self.column_substitutions: dict[str, str] = {} + self.graph_info: dict[str, Any] = {} + # Structure: + # { + # "node_features": { + # "node_type1": { + # "feature_name1": { + # # feature1 representation goes here + # }, + # "feature_name2": {}, ... + # }, + # "node_type2": {...} + # }, + # "edges_features": {...} + # } + self.transformation_representations = {"node_features": {}, "edge_features": {}} + self.graph_name = loader_config.graph_name self.skip_train_masks = False + self.pre_computed_transformations = loader_config.precomputed_transformations def process_and_write_graph_data( self, data_configs: Mapping[str, Sequence[StructureConfig]] - ) -> tuple[dict, dict]: + ) -> ProcessedGraphRepresentation: """Process and encode all graph data. Extracts and encodes graph structure before writing to storage, then applies pre-processing @@ -162,22 +231,12 @@ def process_and_write_graph_data( Returns ------- - tuple[dict, dict] - A tuple with two dictionaries: - The first is the dictionary of metadata for the graph, in "chunked-graph" - format, with additional keys. - For chunked graph format see - https://docs.dgl.ai/guide/distributed-preprocessing.html#specification - - The dict also contains a "raw_id_mappings" key, which is a dict - of dicts, one for each node type. Each entry contains files information - about the raw-to-integer ID mapping for each node. - - The returned value also contains an additional dict of dicts, - "graph_info" which contains additional information about the - graph in a more readable format. - - The second is a dict of duration values for each part of the execution. + ProcessedGraphRepresentation + A dataclass object containing JSON representations of a graph after + it has been processed. + + See `dist_heterogeneous_loader.ProcessedGraphRepresentation` for more + details. 
""" # TODO: See if it's better to return some data structure # for the followup steps instead of just have side-effects @@ -250,6 +309,14 @@ def process_and_write_graph_data( with open(os.path.join(self.output_path, "metadata.json"), "w", encoding="utf-8") as f: json.dump(metadata_dict, f, indent=4) + # Write the transformations file + with open( + os.path.join(self.output_path, TRANSFORMATIONS_FILENAME), "w", encoding="utf-8" + ) as f: + json.dump(self.transformation_representations, f, indent=4) + + # Column substitutions contain any col names we needed to change because their original + # name did not fit Parquet requirements if len(self.column_substitutions) > 0: with open( os.path.join(self.output_path, "column_substitutions.json"), "w", encoding="utf-8" @@ -260,7 +327,13 @@ def process_and_write_graph_data( logging.info("Finished Distributed Graph Processing ...") - return metadata_dict, self.timers + processed_representations = ProcessedGraphRepresentation( + processed_graph_metadata_dict=metadata_dict, + transformation_representations=self.transformation_representations, + timers=self.timers, + ) + + return processed_representations @staticmethod def _at_least_one_label_exists(data_configs: Mapping[str, Sequence[StructureConfig]]) -> bool: @@ -462,7 +535,7 @@ def csv_row(data: Row): # Single column, but could be a multi-valued vector return f"{row_vals[0]}" - input_rdd = input_df.rdd.map(csv_row) # type: RDD + input_rdd: RDD = input_df.rdd.map(csv_row) input_rdd.saveAsTextFile(os.path.join(full_output_path, "csv")) prefix_with_format = os.path.join(output_prefix, "csv") else: @@ -949,7 +1022,7 @@ def _process_node_features( Returns ------- Tuple[Dict, Dict] - A tuple with two dicts, both dicts have names as their keys. + A tuple with two dicts, both dicts have feature names as their keys. The first dict has the output metadata for the feature as values, and the second has the lengths of the vector representations of the features as values. 
@@ -957,12 +1030,24 @@ def _process_node_features( node_type_feature_metadata = {} ntype_feat_sizes = {} # type: Dict[str, int] for feat_conf in feature_configs: - transformer = DistFeatureTransformer(feat_conf) + json_representation = self.transformation_representations.get(feat_conf.feat_name, {}) + transformer = DistFeatureTransformer(feat_conf, self.spark, json_representation) - transformed_feature_df = transformer.apply_transformation(nodes_df) + transformed_feature_df, json_representation = transformer.apply_transformation(nodes_df) transformed_feature_df.cache() - def write_processed_feature(feat_name, single_feature_df, node_type, transformer_name): + # Will evaluate False for empty dict + if json_representation: + self.transformation_representations["node_features"][ + feat_conf.feat_name + ] = json_representation + + def write_processed_feature( + feat_name: str, + single_feature_df: DataFrame, + node_type: str, + transformer: DistFeatureTransformer, + ): feature_output_path = os.path.join( self.output_prefix, f"node_data/{node_type}-{feat_name}" ) @@ -983,7 +1068,7 @@ def write_processed_feature(feat_name, single_feature_df, node_type, transformer nfeat_size = 1 if isinstance(feat_val, (int, float)) else len(feat_val) ntype_feat_sizes.update({feat_name: nfeat_size}) - self.timers[f"{transformer_name}-{node_type}-{feat_name}"] = ( + self.timers[f"{transformer.get_transformation_name()}-{node_type}-{feat_name}"] = ( perf_counter() - node_transformation_start ) @@ -1001,7 +1086,7 @@ def write_processed_feature(feat_name, single_feature_df, node_type, transformer bert_feat_name, single_feature_df, node_type, - transformer.get_transformation_name(), + transformer, ) else: single_feature_df = transformed_feature_df.select(feat_col).withColumnRenamed( @@ -1011,7 +1096,7 @@ def write_processed_feature(feat_name, single_feature_df, node_type, transformer feat_name, single_feature_df, node_type, - transformer.get_transformation_name(), + transformer, ) # Unpersist and move on to next feature @@ -1416,10 +1501,16 @@ def _process_edge_features( feat_conf.feat_name, edge_type, ) - transformer = DistFeatureTransformer(feat_conf) + json_representation = self.transformation_representations.get(feat_conf.feat_name, {}) + transformer = DistFeatureTransformer(feat_conf, self.spark, json_representation) - transformed_feature_df = transformer.apply_transformation(edges_df) + transformed_feature_df, json_representation = transformer.apply_transformation(edges_df) transformed_feature_df.cache() + # Will evaluate False for empty dict + if json_representation: + self.transformation_representations["node_features"][ + feat_conf.feat_name + ] = json_representation def write_feature(self, feat_name, single_feature_df, edge_type, transformer_name): feature_output_path = os.path.join( @@ -1610,7 +1701,7 @@ def _update_label_properties( In case an invalid task type name is specified in the label config. """ label_col = label_config.label_column - if not node_or_edge_type in self.label_properties: + if node_or_edge_type not in self.label_properties: self.label_properties[node_or_edge_type] = Counter() # TODO: Something wrong with the assignment here? 
Investigate self.label_properties[node_or_edge_type][COLUMN_NAME] = label_col @@ -1876,5 +1967,6 @@ def process_custom_mask_df(input_df, split_file, mask_type): ) return train_mask_df, val_mask_df, test_mask_df - def load(self) -> tuple[dict, dict]: + def load(self) -> ProcessedGraphRepresentation: + """Load graph and return JSON representations.""" return self.process_and_write_graph_data(self._data_configs) diff --git a/graphstorm-processing/graphstorm_processing/graph_loaders/heterogeneous_graphloader.py b/graphstorm-processing/graphstorm_processing/graph_loaders/heterogeneous_graphloader.py deleted file mode 100644 index 330a592f6d..0000000000 --- a/graphstorm-processing/graphstorm_processing/graph_loaders/heterogeneous_graphloader.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"). -You may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -from .loader_base import GraphLoader - - -class HeterogeneousGraphLoader(GraphLoader): - """DGL Heterogeneous Graph Loader - - Parameters - ---------- - data_path: str - Path to input. - local_metadata_path: str - Local metadata files output path. For SageMaker it is from os.environ['SM_MODEL_DIR']. - data_config: dict - Config for loading raw data. - """ - - def __init__(self, data_path, local_metadata_path, data_configs): - super().__init__( - data_path=data_path, - local_metadata_path=local_metadata_path, - data_configs=data_configs, - ) diff --git a/graphstorm-processing/graphstorm_processing/graph_loaders/loader_base.py b/graphstorm-processing/graphstorm_processing/graph_loaders/loader_base.py deleted file mode 100644 index ebdb2067c0..0000000000 --- a/graphstorm-processing/graphstorm_processing/graph_loaders/loader_base.py +++ /dev/null @@ -1,74 +0,0 @@ -""" -Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"). -You may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -Base class for graph data processing. -""" - -from abc import ABC -from typing import Dict, List -import os -import abc - -from graphstorm_processing.config.config_parser import StructureConfig -from graphstorm_processing.config.feature_config_base import FeatureConfig - - -class GraphLoader(ABC): - """Graph Loader base class - - Parameters - ---------- - data_path : str - Local path to input configuration file. - local_metadata_path : str - Output path for local metadata files. - data_configs : Dict[str, List[StructureConfig]] - Dictionary of graph structure configurations. 
- """ - - def __init__( - self, - data_path: str, - local_metadata_path: str, - data_configs: Dict[str, List[StructureConfig]], - ): - self._data_path = data_path - self._output_path = local_metadata_path - self._data_configs = data_configs - self._feats: List[FeatureConfig] = [] - - if not os.path.exists(local_metadata_path) and not local_metadata_path.startswith("s3://"): - os.makedirs(local_metadata_path) - - @abc.abstractmethod - def load(self) -> None: - """ - Performs the graph loading, reading input files from storage, processing graph - structure and node/edge features and writes processed output to storage. - """ - - @property - def output_path(self): - """ - Returns output path for local metadata files. - """ - return self._output_path - - @property - def features(self) -> List[FeatureConfig]: - """ - Returns list of feature configurations. - """ - return self._feats diff --git a/graphstorm-processing/tests/conftest.py b/graphstorm-processing/tests/conftest.py index f34f1418f2..3be8cd2f68 100644 --- a/graphstorm-processing/tests/conftest.py +++ b/graphstorm-processing/tests/conftest.py @@ -18,6 +18,7 @@ import sys import logging import tempfile +from typing import Iterator import numpy as np import pytest @@ -61,7 +62,7 @@ def temp_output_root(): @pytest.fixture(scope="session", name="spark") -def spark_fixture(): +def spark_fixture() -> Iterator[SparkSession]: """Create the main SparkContext we use throughout the tests""" spark_context = ( SparkSession.builder.master("local[4]") diff --git a/graphstorm-processing/tests/resources/small_heterogeneous_graph/gsprocessing-config.json b/graphstorm-processing/tests/resources/small_heterogeneous_graph/gsprocessing-config.json index 7967f5089b..56212b540b 100644 --- a/graphstorm-processing/tests/resources/small_heterogeneous_graph/gsprocessing-config.json +++ b/graphstorm-processing/tests/resources/small_heterogeneous_graph/gsprocessing-config.json @@ -57,16 +57,22 @@ } } }, - { + { "column": "occupation", "transformation": { "name": "huggingface", "kwargs": { "action": "tokenize_hf", "hf_model": "bert-base-uncased", - "max_seq_length":16 + "max_seq_length": 16 } } + }, + { + "column": "state", + "transformation": { + "name": "categorical" + } } ], "labels": [ diff --git a/graphstorm-processing/tests/resources/small_heterogeneous_graph/nodes/user.csv b/graphstorm-processing/tests/resources/small_heterogeneous_graph/nodes/user.csv index 5f958fa455..4bf7865c3d 100644 --- a/graphstorm-processing/tests/resources/small_heterogeneous_graph/nodes/user.csv +++ b/graphstorm-processing/tests/resources/small_heterogeneous_graph/nodes/user.csv @@ -1,6 +1,6 @@ -~id,age,occupation,gender,salary,multi -mark,30,actor,male,10000,1|2 -john,22,student,male,,3|4 -tara,33,lawyer,female,30000,5|6 -kate,29,doctor,,35000,7|8 -george,22,student,male,0,9|10 \ No newline at end of file +~id,age,occupation,gender,salary,multi,state +mark,30,actor,male,10000,1|2,wa +john,22,student,male,,3|4,ca +tara,33,lawyer,female,30000,5|6,wa +kate,29,doctor,,35000,7|8, +george,22,student,male,0,9|10,ny \ No newline at end of file diff --git a/graphstorm-processing/tests/test_dist_category_transformation.py b/graphstorm-processing/tests/test_dist_category_transformation.py index 51a5662013..14b1ff99c5 100644 --- a/graphstorm-processing/tests/test_dist_category_transformation.py +++ b/graphstorm-processing/tests/test_dist_category_transformation.py @@ -16,12 +16,11 @@ from typing import Tuple, Iterator import os -import pytest -import pandas as pd import tempfile import mock from 
numpy.testing import assert_array_equal +import pytest from pyspark.sql import SparkSession, DataFrame from pyspark.sql.types import StructField, StructType, StringType, ArrayType @@ -40,7 +39,7 @@ def multi_cat_df_and_separator_fixture( spark: SparkSession, separator="," ) -> Iterator[Tuple[DataFrame, str]]: - """Gneerate multi-category df, yields the DF and its separator""" + """Generate multi-category df, yields the DF and its separator""" data = [ (f"Actor{separator}Director",), (f"Director{separator}Writer",), @@ -63,9 +62,9 @@ def multi_cat_df_and_separator_fixture( ), 3, ) -def test_limited_category_transformation(user_df): +def test_limited_category_transformation(user_df: DataFrame, spark: SparkSession): """Test single-cat transformation with limited categories""" - dist_category_transformation = DistCategoryTransformation(["occupation"]) + dist_category_transformation = DistCategoryTransformation(["occupation"], spark) transformed_df = dist_category_transformation.apply(user_df) group_counts = ( @@ -77,9 +76,9 @@ def test_limited_category_transformation(user_df): assert row["count"] == expected_count -def test_all_categories_transformation(user_df, check_df_schema): +def test_all_categories_transformation(user_df, check_df_schema, spark): """Test single-cat transformation with all categories""" - dist_category_transformation = DistCategoryTransformation(["occupation"]) + dist_category_transformation = DistCategoryTransformation(["occupation"], spark) transformed_df = dist_category_transformation.apply(user_df) @@ -102,7 +101,7 @@ def test_category_transformation_with_null_values(spark: SparkSession): columns = ["name", "occupation", "gender"] df = spark.createDataFrame(data, schema=columns) - dist_category_transformation = DistCategoryTransformation(["occupation"]) + dist_category_transformation = DistCategoryTransformation(["occupation"], spark) transformed_df = dist_category_transformation.apply(df) @@ -111,9 +110,9 @@ def test_category_transformation_with_null_values(spark: SparkSession): assert transformed_distinct_values == 4 -def test_multiple_categories_transformation(user_df): - """Test transforming multiple cat columns""" - dist_category_transformation = DistCategoryTransformation(["occupation", "gender"]) +def test_multiple_categories_transformation(user_df, spark): + """Test transforming multiple single-cat columns""" + dist_category_transformation = DistCategoryTransformation(["occupation", "gender"], spark) transformed_df = dist_category_transformation.apply(user_df) @@ -124,8 +123,29 @@ def test_multiple_categories_transformation(user_df): assert gender_distinct_values == 3 +def test_multiple_single_cat_cols_json(user_df, spark): + """Test JSON representation when transforming multiple single-cat columns""" + dist_category_transformation = DistCategoryTransformation(["occupation", "gender"], spark) + + _ = dist_category_transformation.apply(user_df) + + multi_cols_rep = dist_category_transformation.get_json_representation() + + labels_array = multi_cols_rep["string_indexer_labels_array"] + one_hot_index_for_string = multi_cols_rep["per_col_label_to_one_hot_idx"] + cols = multi_cols_rep["cols"] + name = multi_cols_rep["transformation_name"] + + assert name == "DistCategoryTransformation" + + # The Spark-generated and our own one-hot-index mappings should match + for col_labels, col in zip(labels_array, cols): + for idx, label in enumerate(col_labels): + assert idx == one_hot_index_for_string[col][label] + + def 
test_multi_category_transformation(multi_cat_df_and_separator, check_df_schema): - """Test transforming multi-category column""" + """Test transforming single multi-category column""" df, separator = multi_cat_df_and_separator col_name = df.columns[0] @@ -180,9 +200,10 @@ def test_multi_category_limited_categories(multi_cat_df_and_separator): def test_csv_input_categorical(spark: SparkSession, check_df_schema): + """Test categorical transformations with CSV input, as we treat them separately""" data_path = os.path.join(_ROOT, "resources/multi_num_numerical/multi_num.csv") long_vector_df = spark.read.csv(data_path, sep=",", header=True) - dist_categorical_transormation = DistCategoryTransformation(cols=["id"]) + dist_categorical_transormation = DistCategoryTransformation(cols=["id"], spark=spark) transformed_df = dist_categorical_transormation.apply(long_vector_df) check_df_schema(transformed_df) @@ -199,6 +220,7 @@ def test_csv_input_categorical(spark: SparkSession, check_df_schema): def test_csv_input_multi_categorical(spark: SparkSession, check_df_schema): + """Test mulit-categorical transformations with CSV input""" data_path = os.path.join(_ROOT, "resources/multi_num_numerical/multi_num.csv") long_vector_df = spark.read.csv(data_path, sep=",", header=True) dist_categorical_transormation = DistMultiCategoryTransformation(cols=["feat"], separator=";") @@ -207,13 +229,14 @@ def test_csv_input_multi_categorical(spark: SparkSession, check_df_schema): check_df_schema(transformed_df) transformed_rows = transformed_df.collect() expected_rows = [] - for i in range(5): + for _ in range(5): expected_rows.append([1] * 100) for row, expected_row in zip(transformed_rows, expected_rows): assert row["feat"] == expected_row def test_parquet_input_multi_categorical(spark: SparkSession, check_df_schema): + """Test multi-categorical transformations with Parquet input""" # Define the schema for the DataFrame schema = StructType([StructField("names", ArrayType(StringType()), True)]) @@ -240,7 +263,7 @@ def test_parquet_input_multi_categorical(spark: SparkSession, check_df_schema): # Show the DataFrame loaded from the Parquet file dist_categorical_transormation = DistMultiCategoryTransformation( - cols=["names"], separator=None + cols=["names"], separator="" ) transformed_df = dist_categorical_transormation.apply(df_parquet) diff --git a/graphstorm-processing/tests/test_dist_executor.py b/graphstorm-processing/tests/test_dist_executor.py new file mode 100644 index 0000000000..f99e087491 --- /dev/null +++ b/graphstorm-processing/tests/test_dist_executor.py @@ -0,0 +1,80 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"). +You may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import os +import shutil +import tempfile + +import pytest + +from graphstorm_processing.distributed_executor import DistributedExecutor, ExecutorConfig +from graphstorm_processing.constants import ExecutionEnv, FilesystemType + +_ROOT = os.path.abspath(os.path.dirname(__file__)) + + +@pytest.fixture(autouse=True, name="tempdir") +def tempdir_fixture(): + """Create temp dir for output files""" + tempdirectory = tempfile.mkdtemp( + prefix=os.path.join(_ROOT, "resources/test_output/"), + ) + yield tempdirectory + shutil.rmtree(tempdirectory) + + +def test_merge_input_and_transform_dicts(tempdir: str): + """Test run function with local data""" + input_path = os.path.join(_ROOT, "resources/small_heterogeneous_graph") + executor_configuration = ExecutorConfig( + local_config_path=input_path, + local_metadata_output_path=tempdir, + input_prefix=input_path, + output_prefix=tempdir, + num_output_files=-1, + config_filename="gsprocessing-config.json", + execution_env=ExecutionEnv.LOCAL, + filesystem_type=FilesystemType.LOCAL, + add_reverse_edges=True, + graph_name="small_heterogeneous_graph", + do_repartition=True, + ) + + dist_executor = DistributedExecutor(executor_configuration) + + pre_comp_transormations = { + "node_features": { + "user": { + "state": { + "transformation_name": "categorical", + } + } + }, + "edge_features": {}, + } + + input_config_with_transforms = dist_executor._merge_config_with_transformations( + dist_executor.gsp_config_dict, + pre_comp_transormations, + ) + + # Ensure the "user" node type's "age" feature includes a transformation entry + for node_input_dict in input_config_with_transforms["graph"]["nodes"]: + if "user" == node_input_dict["type"]: + for feature in node_input_dict["features"]: + if "state" == feature["column"]: + transform_for_feature = feature["precomputed_transformation"] + assert transform_for_feature["transformation_name"] == "categorical" diff --git a/graphstorm-processing/tests/test_dist_heterogenous_loader.py b/graphstorm-processing/tests/test_dist_heterogenous_loader.py index ae285d42fa..068aed50bb 100644 --- a/graphstorm-processing/tests/test_dist_heterogenous_loader.py +++ b/graphstorm-processing/tests/test_dist_heterogenous_loader.py @@ -28,6 +28,7 @@ from graphstorm_processing.graph_loaders.dist_heterogeneous_loader import ( DistHeterogeneousGraphLoader, + HeterogeneousLoaderConfig, NODE_MAPPING_INT, NODE_MAPPING_STR, ) @@ -46,6 +47,7 @@ MIN_VALUE, MAX_VALUE, VALUE_COUNTS, + TRANSFORMATIONS_FILENAME, ) pytestmark = pytest.mark.usefixtures("spark") @@ -99,16 +101,21 @@ def no_label_data_configs_fixture(): def dghl_loader_fixture(spark, data_configs_with_label, tempdir) -> DistHeterogeneousGraphLoader: """Create a re-usable loader that includes labels""" input_path = os.path.join(_ROOT, "resources/small_heterogeneous_graph") - dhgl = DistHeterogeneousGraphLoader( - spark, - local_input_path=input_path, - local_output_path=tempdir, - output_prefix=tempdir, - input_prefix=input_path, - data_configs=data_configs_with_label, - num_output_files=1, + loader_config = HeterogeneousLoaderConfig( add_reverse_edges=True, + data_configs=data_configs_with_label, enable_assertions=True, + graph_name="small_heterogeneous_graph", + input_prefix=input_path, + local_input_path=input_path, + local_metadata_output_path=tempdir, + num_output_files=1, + output_prefix=tempdir, + precomputed_transformations={}, + ) + dhgl = DistHeterogeneousGraphLoader( + spark, + loader_config=loader_config, ) return dhgl @@ -119,16 +126,21 @@ def dghl_loader_no_label_fixture( 
) -> DistHeterogeneousGraphLoader: """Create a re-usable loader without labels""" input_path = os.path.join(_ROOT, "resources/small_heterogeneous_graph") - dhgl = DistHeterogeneousGraphLoader( - spark, - local_input_path=input_path, - local_output_path=tempdir, - output_prefix=tempdir, - input_prefix=input_path, - data_configs=no_label_data_configs, - num_output_files=1, + loader_config = HeterogeneousLoaderConfig( add_reverse_edges=True, + data_configs=no_label_data_configs, enable_assertions=True, + graph_name="small_heterogeneous_graph", + input_prefix=input_path, + local_input_path=input_path, + local_metadata_output_path=tempdir, + num_output_files=1, + output_prefix=tempdir, + precomputed_transformations={}, + ) + dhgl = DistHeterogeneousGraphLoader( + spark, + loader_config, ) return dhgl @@ -139,16 +151,21 @@ def dghl_loader_no_reverse_edges_fixture( ) -> DistHeterogeneousGraphLoader: """Create a re-usable loader that doesn't produce reverse edegs""" input_path = os.path.join(_ROOT, "resources/small_heterogeneous_graph") - dhgl = DistHeterogeneousGraphLoader( - spark, - local_input_path=input_path, - local_output_path=tempdir, - output_prefix=tempdir, - input_prefix=input_path, - data_configs=data_configs_with_label, - num_output_files=1, + loader_config = HeterogeneousLoaderConfig( add_reverse_edges=False, + data_configs=data_configs_with_label, enable_assertions=True, + graph_name="small_heterogeneous_graph", + input_prefix=input_path, + local_input_path=input_path, + local_metadata_output_path=tempdir, + num_output_files=1, + output_prefix=tempdir, + precomputed_transformations={}, + ) + dhgl = DistHeterogeneousGraphLoader( + spark, + loader_config, ) return dhgl @@ -246,6 +263,7 @@ def test_load_dist_heterogen_node_class(dghl_loader: DistHeterogeneousGraphLoade "input_ids": 16, "token_type_ids": 16, "multi": 2, + "state": 3, } }, "efeat_size": {}, @@ -273,6 +291,7 @@ def test_load_dist_heterogen_node_class(dghl_loader: DistHeterogeneousGraphLoade "test_mask", "age", "multi", + "state", "input_ids", "attention_mask", "token_type_ids", @@ -282,6 +301,13 @@ def test_load_dist_heterogen_node_class(dghl_loader: DistHeterogeneousGraphLoade for node_type in metadata["node_data"]: assert metadata["node_data"][node_type].keys() == expected_node_data[node_type] + with open( + os.path.join(dghl_loader.output_path, TRANSFORMATIONS_FILENAME), "r", encoding="utf-8" + ) as transformation_file: + transformations_dict = json.load(transformation_file) + + assert "state" in transformations_dict["node_features"] + def test_load_dist_hgl_without_labels( dghl_loader_no_label: DistHeterogeneousGraphLoader, @@ -683,6 +709,7 @@ def test_update_label_properties_multilabel( def test_node_custom_label(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp_path): + """Test using custom label splits for nodes""" data = [(i,) for i in range(1, 11)] # Create DataFrame @@ -725,6 +752,7 @@ def test_node_custom_label(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp def test_edge_custom_label(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp_path): + """Test using custom label splits for edges""" data = [(i, j) for i in range(1, 4) for j in range(11, 14)] # Create DataFrame edges_df = spark.createDataFrame(data, ["src_str_id", "dst_str_id"])
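Note on the round trip introduced above: the loader now writes precomputed_transformations.json (TRANSFORMATIONS_FILENAME) next to its metadata output, and DistributedExecutor loads that same file when it is present next to the input configuration. Below is a minimal, hedged sketch of that hand-off between two runs; the directory names and the helper function are hypothetical, only the filename constant and the executor's lookup behavior come from the diff.

import json
import os
import shutil

TRANSFORMATIONS_FILENAME = "precomputed_transformations.json"


def stage_precomputed_transformations(previous_output: str, new_input_config_dir: str) -> dict:
    """Copy a previous run's transformations file next to the new input config.

    DistributedExecutor checks its local config path for TRANSFORMATIONS_FILENAME
    and, if the file exists, loads it into ``self.precomputed_transformations``;
    otherwise it falls back to an empty dict.
    """
    src = os.path.join(previous_output, TRANSFORMATIONS_FILENAME)
    dst = os.path.join(new_input_config_dir, TRANSFORMATIONS_FILENAME)
    if not os.path.exists(src):
        # Nothing precomputed yet, so the follow-up run starts from scratch
        return {}
    shutil.copyfile(src, dst)
    with open(dst, "r", encoding="utf-8") as f:
        return json.load(f)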
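The new test_multiple_single_cat_cols_json test checks that Spark's labelsArray agrees with the per_col_label_to_one_hot_idx mapping stored in the categorical transformation's JSON representation. As a sketch of how that mapping could later be used to re-encode raw values outside Spark: this assumes a dense vector with one slot per known category and that unknown or missing values keep the all-zeroes vector, as described in DistCategoryTransformation.apply(); the helper itself is illustrative, not part of the library.

import numpy as np


def reencode_category(representation: dict, col: str, value: str) -> np.ndarray:
    """Map a raw category string to the one-hot vector index Spark assigned to it."""
    label_to_idx = representation["per_col_label_to_one_hot_idx"][col]
    one_hot = np.zeros(len(label_to_idx), dtype=np.float32)
    # Values Spark never saw during fitting stay all-zero, mirroring how the
    # OneHotEncoder treats missing/unknown categories in apply().
    if value in label_to_idx:
        one_hot[label_to_idx[value]] = 1.0
    return one_hot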