diff --git a/graphstorm-processing/docker/push_gsprocessing_image.sh b/graphstorm-processing/docker/push_gsprocessing_image.sh index 60e97b0416..be1b29912e 100644 --- a/graphstorm-processing/docker/push_gsprocessing_image.sh +++ b/graphstorm-processing/docker/push_gsprocessing_image.sh @@ -9,7 +9,7 @@ usage() { cat < at the respective Spark version +# https://github.com/apache/spark/blob/v3.5.1/pom.xml#L125 +# replace both Hadoop versions below with the one there +SPARK_HADOOP_VERSIONS = { + "3.5": "3.3.4", + "3.4": "3.3.4", + "3.3": "3.3.2", +} diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_hf_transformation.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_hf_transformation.py index cf53d67993..3d232fcfe9 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_hf_transformation.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_hf_transformation.py @@ -112,6 +112,7 @@ def tokenize(text): logging.warning("The device to run huggingface transformation is %s", device) tokenizer = AutoTokenizer.from_pretrained(hf_model) if max_seq_length > tokenizer.model_max_length: + # TODO: Could we possibly raise this at config time? raise RuntimeError( f"max_seq_length {max_seq_length} is larger " f"than expected {tokenizer.model_max_length}" diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/spark_utils.py b/graphstorm-processing/graphstorm_processing/data_transformations/spark_utils.py index 4814c8e64d..e33b5f7cbf 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/spark_utils.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/spark_utils.py @@ -11,12 +11,15 @@ import logging import uuid -from typing import Tuple, Sequence +from typing import Optional, Tuple, Sequence import psutil +import pyspark from pyspark.sql import SparkSession, DataFrame, functions as F +from pyspark.util import VersionUtils from graphstorm_processing import constants +from graphstorm_processing.constants import ExecutionEnv, FilesystemType, SPARK_HADOOP_VERSIONS try: from smspark.bootstrapper import Bootstrapper @@ -31,13 +34,15 @@ def load_instance_type_info(self): return None -def create_spark_session(sm_execution: bool, filesystem_type: str) -> SparkSession: +def create_spark_session( + execution_env: ExecutionEnv, filesystem_type: FilesystemType +) -> SparkSession: """ Create a SparkSession with the appropriate configuration for the execution context. Parameters ---------- - sm_execution + execution_env Whether or not this is being executed on a SageMaker instance. filesystem_type The filesystem type to use. 
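For context, the new `SPARK_HADOOP_VERSIONS` constant introduced above is keyed by the PySpark `major.minor` version string. A minimal sketch of how it is consumed, mirroring the `VersionUtils.majorMinorVersion` lookup this diff adds to `spark_utils.py` (no names beyond what the patch itself introduces):

```python
# Sketch: resolve the Hadoop AWS jar version matching the installed PySpark.
import pyspark
from pyspark.util import VersionUtils

from graphstorm_processing.constants import SPARK_HADOOP_VERSIONS

major, minor = VersionUtils.majorMinorVersion(pyspark.__version__)
hadoop_ver = SPARK_HADOOP_VERSIONS[f"{major}.{minor}"]
print(f"PySpark {major}.{minor} -> org.apache.hadoop:hadoop-aws:{hadoop_ver}")
```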
@@ -54,6 +59,69 @@ def create_spark_session(sm_execution: bool, filesystem_type: str) -> SparkSessi processing_job_config = bootstraper.load_processing_job_config() instance_type_info = bootstraper.load_instance_type_info() + spark_builder = ( + SparkSession.builder.appName("GSProcessing") + .config("spark.hadoop.validateOutputSpecs", "false") + .config("spark.logConf", "true") + ) + + if execution_env != ExecutionEnv.EMR_SERVERLESS: + spark_builder = _configure_spark_env( + spark_builder, processing_job_config, instance_type_info + ) + + major, minor = VersionUtils.majorMinorVersion(pyspark.__version__) + hadoop_ver = SPARK_HADOOP_VERSIONS[f"{major}.{minor}"] + # Only used for local testing and container execution + if execution_env == ExecutionEnv.LOCAL and filesystem_type == FilesystemType.S3: + logging.info("Setting up local Spark instance for S3 access...") + spark_builder.config( + "spark.jars.packages", + f"org.apache.hadoop:hadoop-aws:{hadoop_ver}," + f"org.apache.hadoop:hadoop-client:{hadoop_ver}", + ).config("spark.jars.excludes", "com.google.guava:guava").config( + "spark.executor.extraJavaOptions", "-Dcom.amazonaws.services.s3.enableV4=true" + ).config( + "spark.driver.extraJavaOptions", "-Dcom.amazonaws.services.s3.enableV4=true" + ) + + spark = spark_builder.getOrCreate() + + spark.sparkContext.setLogLevel("ERROR") + logger = spark.sparkContext._jvm.org.apache.log4j + logger.LogManager.getLogger("org").setLevel(logger.Level.ERROR) + logger.LogManager.getLogger("py4j").setLevel(logger.Level.ERROR) + spark_logger = logging.getLogger("py4j.java_gateway") + spark_logger.setLevel(logging.ERROR) + + hadoop_config = spark.sparkContext._jsc.hadoopConfiguration() + # This is needed to save RDDs which is the only way to write nested Dataframes into CSV format + # hadoop_config.set( + # "mapred.output.committer.class", "org.apache.hadoop.mapred.FileOutputCommitter" + # ) + # See https://aws.amazon.com/premiumsupport/knowledge-center/emr-timeout-connection-wait/ + hadoop_config.set("fs.s3.maxConnections", "5000") + hadoop_config.set("fs.s3.maxRetries", "20") + hadoop_config.set("fs.s3a.connection.maximum", "150") + + # Set up auth for local and EMR + if execution_env != ExecutionEnv.SAGEMAKER and filesystem_type == FilesystemType.S3: + hadoop_config.set( + "fs.s3a.aws.credentials.provider", + "com.amazonaws.auth.DefaultAWSCredentialsProviderChain", + ) + hadoop_config.set("fs.s3.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") + hadoop_config.set("fs.AbstractFileSystem.s3a.imp", "org.apache.hadoop.fs.s3a.S3A") + spark.sparkContext.setSystemProperty("com.amazonaws.services.s3.enableV4", "true") + + return spark + + +def _configure_spark_env( + spark_builder: SparkSession.Builder, + processing_job_config: Optional[dict], + instance_type_info: Optional[dict], +) -> SparkSession.Builder: if processing_job_config and instance_type_info: instance_type = processing_job_config["ProcessingResources"]["ClusterConfig"][ "InstanceType" @@ -70,8 +138,8 @@ def create_spark_session(sm_execution: bool, filesystem_type: str) -> SparkSessi else: instance_mem_mb = int(psutil.virtual_memory().total / (1024 * 1024)) instance_cores = psutil.cpu_count(logical=True) - logging.warning( - "Failed to detect instance type config. Found total memory: %d MiB and total cores: %d", + logging.info( + "Configuring Spark execution env. 
Found total memory: %d MiB and total cores: %d", instance_mem_mb, instance_cores, ) @@ -99,57 +167,14 @@ def create_spark_session(sm_execution: bool, filesystem_type: str) -> SparkSessi # Avoid timeout errors due to connection pool starving # Allow sending large results to driver spark_builder = ( - SparkSession.builder.appName("GSProcessing") - .config("spark.hadoop.validateOutputSpecs", "false") - .config("spark.driver.memory", f"{driver_mem_mb}m") + spark_builder.config("spark.driver.memory", f"{driver_mem_mb}m") .config("spark.driver.memoryOverhead", f"{driver_mem_overhead_mb}m") .config("spark.driver.maxResultSize", f"{driver_max_result}m") .config("spark.executor.memory", f"{executor_mem_mb}m") .config("spark.executor.memoryOverhead", f"{executor_mem_overhead_mb}m") - .config("spark.logConf", "true") ) - # TODO: These settings shouldn't be necessary for container execution, - # can we create such a Spark context only for testing? - if not sm_execution and filesystem_type == "s3": - spark_builder.config( - "spark.jars.packages", - "org.apache.hadoop:hadoop-aws:2.10.2," "org.apache.hadoop:hadoop-client:2.10.2", - ).config("spark.jars.excludes", "com.google.guava:guava").config( - "spark.executor.extraJavaOptions", "-Dcom.amazonaws.services.s3.enableV4=true" - ).config( - "spark.driver.extraJavaOptions", "-Dcom.amazonaws.services.s3.enableV4=true" - ) - - spark = spark_builder.getOrCreate() - - spark.sparkContext.setLogLevel("ERROR") - logger = spark.sparkContext._jvm.org.apache.log4j - logger.LogManager.getLogger("org").setLevel(logger.Level.ERROR) - logger.LogManager.getLogger("py4j").setLevel(logger.Level.ERROR) - spark_logger = logging.getLogger("py4j.java_gateway") - spark_logger.setLevel(logging.ERROR) - - hadoop_config = spark.sparkContext._jsc.hadoopConfiguration() - # This is needed to save RDDs which is the only way to write nested Dataframes into CSV format - hadoop_config.set( - "mapred.output.committer.class", "org.apache.hadoop.mapred.FileOutputCommitter" - ) - # See https://aws.amazon.com/premiumsupport/knowledge-center/emr-timeout-connection-wait/ - hadoop_config.set("fs.s3.maxConnections", "5000") - hadoop_config.set("fs.s3a.connection.maximum", "150") - # Only used for local testing and container execution - if not sm_execution and filesystem_type == "s3": - logging.info("Setting up local Spark instance for S3 access...") - hadoop_config.set( - "fs.s3a.aws.credentials.provider", - "com.amazonaws.auth.DefaultAWSCredentialsProviderChain", - ) - hadoop_config.set("fs.s3.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem") - hadoop_config.set("fs.AbstractFileSystem.s3a.imp", "org.apache.hadoop.fs.s3a.S3A") - spark.sparkContext.setSystemProperty("com.amazonaws.services.s3.enableV4", "true") - - return spark + return spark_builder def safe_rename_column( diff --git a/graphstorm-processing/graphstorm_processing/distributed_executor.py b/graphstorm-processing/graphstorm_processing/distributed_executor.py index 6ef207c750..d97b9ff4b0 100644 --- a/graphstorm-processing/graphstorm_processing/distributed_executor.py +++ b/graphstorm-processing/graphstorm_processing/distributed_executor.py @@ -73,6 +73,7 @@ ParquetRepartitioner, ) from graphstorm_processing.graph_loaders.row_count_utils import verify_metadata_match +from graphstorm_processing.constants import ExecutionEnv, FilesystemType @dataclasses.dataclass @@ -91,12 +92,12 @@ class ExecutorConfig: Prefix for output data. Can be S3 URI or local path. num_output_files : int The number of output files Spark will try to create. 
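As a quick usage sketch of the refactored entry point above: `create_spark_session` now takes the two enums instead of a bool/str pair. The call below is hypothetical (a local run reading from S3) and uses only names introduced by this diff:

```python
from graphstorm_processing.constants import ExecutionEnv, FilesystemType
from graphstorm_processing.data_transformations.spark_utils import create_spark_session

# Local/container execution against S3 data: this is the branch that pulls in the
# hadoop-aws/hadoop-client packages and the DefaultAWSCredentialsProviderChain.
spark = create_spark_session(
    execution_env=ExecutionEnv.LOCAL,
    filesystem_type=FilesystemType.S3,
)
```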
- sm_execution : bool - Whether the execution context is a SageMaker container. + execution_env : ExecutionEnv + The kind of execution environment the job will run on. config_filename : str The filename for the configuration file. - filesystem_type : str - The filesystem type, can be 'local' or 's3'. + filesystem_type : FilesystemType + The filesystem type, can be LOCAL or S3 add_reverse_edges : bool Whether to create reverse edges for each edge type. graph_name: str @@ -110,9 +111,9 @@ class ExecutorConfig: input_prefix: str output_prefix: str num_output_files: int - sm_execution: bool + execution_env: ExecutionEnv config_filename: str - filesystem_type: str + filesystem_type: FilesystemType add_reverse_edges: bool graph_name: str do_repartition: bool @@ -153,13 +154,13 @@ def __init__( self.num_output_files = executor_config.num_output_files self.config_filename = executor_config.config_filename self.filesystem_type = executor_config.filesystem_type - self.sm_execution = executor_config.sm_execution + self.execution_env = executor_config.execution_env self.add_reverse_edges = executor_config.add_reverse_edges self.graph_name = executor_config.graph_name self.repartition_on_leader = executor_config.do_repartition # Ensure we have write access to the output path - if self.filesystem_type == "local": + if self.filesystem_type == FilesystemType.LOCAL: if not os.path.exists(self.output_prefix): try: os.makedirs(self.output_prefix, exist_ok=True) @@ -171,18 +172,25 @@ def __init__( s3 = boto3.resource("s3") bucket_name, prefix = s3_utils.extract_bucket_and_key(self.output_prefix) head_bucket_response = s3.meta.client.head_bucket(Bucket=bucket_name) - assert head_bucket_response["ResponseMetadata"]["HTTPStatusCode"] == 200 + assert head_bucket_response["ResponseMetadata"]["HTTPStatusCode"] == 200, ( + f"Could not read objects at S3 output prefix: {self.output_prefix} " + "Check permissions for execution role." + ) bucket_resouce = s3.Bucket(bucket_name) bucket_resouce.put_object(Key=f"{prefix}/test_file.txt", Body=b"test") response = bucket_resouce.delete_objects( Delete={"Objects": [{"Key": f"{prefix}/test_file.txt"}], "Quiet": True} ) - assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 + assert response["ResponseMetadata"]["HTTPStatusCode"] == 200, ( + f"Could not delete objects at S3 output prefix: {self.output_prefix} " + "Check permissions for execution role." 
+ ) graph_conf = os.path.join(self.local_config_path, self.config_filename) with open(graph_conf, "r", encoding="utf-8") as f: dataset_config_dict: Dict[str, Any] = json.load(f) + # Use appropriate config parser depending on file version if "version" in dataset_config_dict: config_version = dataset_config_dict["version"] if config_version == "gsprocessing-v1.0": @@ -215,7 +223,7 @@ def __init__( self.graph_config_dict = converter.convert_to_gsprocessing(dataset_config_dict)["graph"] # Create the Spark session for execution - self.spark = spark_utils.create_spark_session(self.sm_execution, self.filesystem_type) + self.spark = spark_utils.create_spark_session(self.execution_env, self.filesystem_type) def _upload_output_files(self, loader: DistHeterogeneousGraphLoader, force=False): """Upload output files to S3 @@ -227,7 +235,11 @@ def _upload_output_files(self, loader: DistHeterogeneousGraphLoader, force=False force : bool, optional Enforce upload even in SageMaker, by default False """ - if (not self.sm_execution and self.filesystem_type == "s3") or force: + if ( + not self.execution_env == ExecutionEnv.SAGEMAKER + and self.filesystem_type == FilesystemType.S3 + ) or force: + # Output files need to be manually uploaded to S3 when not on SM bucket, s3_prefix = s3_utils.extract_bucket_and_key(self.output_prefix) s3 = boto3.resource("s3") @@ -260,7 +272,7 @@ def run(self) -> None: enable_assertions=False, graph_name=self.graph_name, ) - graph_meta_dict = loader.load() + graph_meta_dict, timers_dict = loader.load() t1 = time.time() logging.info("Time to transform data for distributed partitioning: %s sec", t1 - t0) @@ -276,29 +288,23 @@ def run(self) -> None: streaming_repartitioning=False, ) + repartition_start = time.perf_counter() if all_match: logging.info( "All file row counts match, applying Parquet metadata modification on Spark leader." ) modify_flat_array_metadata(graph_meta_dict, repartitioner) - logging.info("Data are now prepared for processing by the DistPart Partition pipeline.") + # modify_flat_array_metadata modifies file metadata in-place, + # so the original meta dict still applies + updated_metadata = graph_meta_dict else: if self.repartition_on_leader: logging.info("Attempting to repartition graph data on Spark leader...") try: # Upload existing output files before trying to re-partition - self._upload_output_files(loader, force=True) + if self.filesystem_type == FilesystemType.S3: + self._upload_output_files(loader, force=True) updated_metadata = repartition_files(graph_meta_dict, repartitioner) - with open( - os.path.join(loader.output_path, "updated_row_counts_metadata.json"), - "w", - encoding="utf-8", - ) as f: - json.dump(updated_metadata, f, indent=4) - f.flush() - logging.info( - "Data are now prepared for processing by the DistPart Partition pipeline." 
- ) except Exception as e: # pylint: disable=broad-exception-caught # If an error happens during re-partition, we don't want to fail the entire # gs-processing job, so we just post an error and continue @@ -310,10 +316,35 @@ def run(self) -> None: ) else: logging.warning("gs-repartition will need to run as a follow-up job on the data!") + timers_dict["repartition"] = time.perf_counter() - repartition_start + + # If any of the metadata modification took place, write an updated metadata file + if updated_metadata: + updated_meta_path = os.path.join(loader.output_path, "updated_row_counts_metadata.json") + with open( + updated_meta_path, + "w", + encoding="utf-8", + ) as f: + json.dump(updated_metadata, f, indent=4) + f.flush() + logging.info("Updated metadata written to %s", updated_meta_path) + logging.info( + "Data are now prepared for processing by the Distributed Partition pipeline." + ) + + with open( + os.path.join(self.local_output_path, "perf_counters.json"), "w", encoding="utf-8" + ) as f: + sorted_timers = dict(sorted(timers_dict.items(), key=lambda x: x[1], reverse=True)) + json.dump(sorted_timers, f, indent=4) # This is used to upload the output JSON files to S3 on non-SageMaker runs, # since we can't rely on SageMaker to do it - self._upload_output_files(loader, force=False) + if self.filesystem_type == FilesystemType.S3: + self._upload_output_files( + loader, force=not self.execution_env == ExecutionEnv.SAGEMAKER + ) def parse_args() -> argparse.Namespace: @@ -387,11 +418,16 @@ def main(): gsprocessing_args = GSProcessingArguments(**vars(parse_args())) logging.basicConfig( level=gsprocessing_args.log_level, - format="%(asctime)s %(levelname)-8s %(message)s", + format="[GSPROCESSING] %(asctime)s %(levelname)-8s %(message)s", ) # Determine if we're running within a SageMaker container - is_sagemaker_execution = os.path.exists("/opt/ml/config/processingjobconfig.json") + if os.path.exists("/opt/ml/config/processingjobconfig.json"): + execution_env = ExecutionEnv.SAGEMAKER + elif os.path.exists("/emr-serverless-config.json"): + execution_env = ExecutionEnv.EMR_SERVERLESS + else: + execution_env = ExecutionEnv.LOCAL if gsprocessing_args.input_prefix.startswith("s3://"): assert gsprocessing_args.output_prefix.startswith("s3://"), ( @@ -400,7 +436,7 @@ def main(): f"and output: '{gsprocessing_args.output_prefix}'." 
) - filesystem_type = "s3" + filesystem_type = FilesystemType.S3 else: # Ensure input and output prefixes exist and convert to absolute paths gsprocessing_args.input_prefix = str( @@ -411,15 +447,14 @@ def main(): gsprocessing_args.output_prefix = str( Path(gsprocessing_args.output_prefix).resolve(strict=True) ) - filesystem_type = "local" + filesystem_type = FilesystemType.LOCAL # local input location for config file and execution script - if is_sagemaker_execution: + if execution_env == ExecutionEnv.SAGEMAKER: local_config_path = "/opt/ml/processing/input/data" else: - # If not on SageMaker, assume that we are running in a - # native env with local input or Docker execution with S3 input - if filesystem_type == "local": + # When not running on SM, we need to manually pull the config file from S3 + if filesystem_type == FilesystemType.LOCAL: local_config_path = gsprocessing_args.input_prefix else: tempdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with @@ -433,7 +468,7 @@ def main(): f"{input_s3_prefix}/{gsprocessing_args.config_filename}", os.path.join(tempdir.name, gsprocessing_args.config_filename), ) - except botocore.exceptions.ClientError as e: + except botocore.exceptions.ClientError as e: # type: ignore raise RuntimeError( "Unable to download config file at" f"s3://{input_bucket}/{input_s3_prefix}/" @@ -442,10 +477,10 @@ def main(): local_config_path = tempdir.name # local output location for metadata files - if is_sagemaker_execution: + if execution_env == ExecutionEnv.SAGEMAKER: local_output_path = "/opt/ml/processing/output" else: - if filesystem_type == "local": + if filesystem_type == FilesystemType.LOCAL: local_output_path = gsprocessing_args.output_prefix else: # Only needed for local execution with S3 data @@ -454,6 +489,11 @@ def main(): if not gsprocessing_args.num_output_files: gsprocessing_args.num_output_files = -1 + # Save arguments to file for posterity + with open(os.path.join(local_output_path, "launch_arguments.json"), "w", encoding="utf-8") as f: + json.dump(dataclasses.asdict(gsprocessing_args), f, indent=4) + f.flush() + executor_configuration = ExecutorConfig( local_config_path=local_config_path, local_output_path=local_output_path, @@ -461,7 +501,7 @@ def main(): output_prefix=gsprocessing_args.output_prefix, num_output_files=gsprocessing_args.num_output_files, config_filename=gsprocessing_args.config_filename, - sm_execution=is_sagemaker_execution, + execution_env=execution_env, filesystem_type=filesystem_type, add_reverse_edges=gsprocessing_args.add_reverse_edges, graph_name=gsprocessing_args.graph_name, @@ -472,25 +512,6 @@ def main(): dist_executor.run() - # Save arguments to file for posterity - with open(os.path.join(local_output_path, "launch_arguments.json"), "w", encoding="utf-8") as f: - json.dump(dataclasses.asdict(gsprocessing_args), f, indent=4) - f.flush() - - # In SageMaker execution, all files under `local_output_path` get automatically - # uploaded to S3 at the end of the job. Otherwise, we need to upload - # all output files manually. 
- if not is_sagemaker_execution and filesystem_type == "s3": - output_bucket, output_s3_prefix = s3_utils.extract_bucket_and_key( - gsprocessing_args.output_prefix - ) - s3 = boto3.resource("s3") - s3.meta.client.upload_file( - os.path.join(local_output_path, "launch_arguments.json"), - output_bucket, - f"{output_s3_prefix}/launch_arguments.json", - ) - if __name__ == "__main__": main() diff --git a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py index 8afe45e97c..79fc14d454 100644 --- a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py +++ b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py @@ -16,9 +16,9 @@ import json import logging +import math import numbers import os -import math from collections import Counter, defaultdict from time import perf_counter from typing import Any, Dict, Mapping, Optional, Sequence, Set, Tuple @@ -145,7 +145,7 @@ def __init__( def process_and_write_graph_data( self, data_configs: Mapping[str, Sequence[StructureConfig]] - ) -> Dict: + ) -> tuple[dict, dict]: """Process and encode all graph data. Extracts and encodes graph structure before writing to storage, then applies pre-processing @@ -162,19 +162,22 @@ def process_and_write_graph_data( Returns ------- - metadata_dict : Dict - Dictionary of metadata for the graph, in "chunked-graph" + tuple[dict, dict] + A tuple with two dictionaries: + The first is the dictionary of metadata for the graph, in "chunked-graph" format, with additional keys. For chunked graph format see https://docs.dgl.ai/guide/distributed-preprocessing.html#specification The dict also contains a "raw_id_mappings" key, which is a dict of dicts, one for each node type. Each entry contains files information - about the raw-to-integet ID mapping for each node. + about the raw-to-integer ID mapping for each node. The returned value also contains an additional dict of dicts, "graph_info" which contains additional information about the graph in a more readable format. + + The second is a dict of duration values for each part of the execution. 
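Because `load()` now hands the timing information back to the caller, the executor is responsible for writing `perf_counters.json`. A self-contained sketch of that write (timer values are made up; the sorting and dump mirror the block added to `distributed_executor.py` above):

```python
import json
import os

# Hypothetical timings, as returned in the second element of loader.load()
timers_dict = {"process_and_write_graph_data": 412.8, "repartition": 35.2}

local_output_path = "/tmp/gsprocessing"  # placeholder output location
os.makedirs(local_output_path, exist_ok=True)
with open(os.path.join(local_output_path, "perf_counters.json"), "w", encoding="utf-8") as f:
    # Largest timers first, as in the diff
    sorted_timers = dict(sorted(timers_dict.items(), key=lambda x: x[1], reverse=True))
    json.dump(sorted_timers, f, indent=4)
```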
""" # TODO: See if it's better to return some data structure # for the followup steps instead of just have side-effects @@ -243,8 +246,7 @@ def process_and_write_graph_data( metadata_dict["graph_info"] = self._finalize_graphinfo_dict(metadata_dict) - # The metadata dict is written to disk as a JSON file, - # SageMaker takes care of uploading it to S3 + # The metadata dict is written to disk as a JSON file with open(os.path.join(self.output_path, "metadata.json"), "w", encoding="utf-8") as f: json.dump(metadata_dict, f, indent=4) @@ -256,13 +258,9 @@ def process_and_write_graph_data( self.timers["process_and_write_graph_data"] = perf_counter() - process_start_time - with open(os.path.join(self.output_path, "perf_counters.json"), "w", encoding="utf-8") as f: - sorted_timers = dict(sorted(self.timers.items(), key=lambda x: x[1], reverse=True)) - json.dump(sorted_timers, f, indent=4) - logging.info("Finished Distributed Graph Processing ...") - return metadata_dict + return metadata_dict, self.timers @staticmethod def _at_least_one_label_exists(data_configs: Mapping[str, Sequence[StructureConfig]]) -> bool: @@ -877,6 +875,9 @@ def process_node_data(self, node_configs: Sequence[NodeConfig]) -> Dict: if self.enable_assertions: nodes_df_count = nodes_df.count() mapping_df_count = mapping_df.count() + logging.warning( + "Node mapping count for node type %s: %d", node_type, mapping_df_count + ) assert nodes_df_count == mapping_df_count, ( f"Nodes df count ({nodes_df_count}) does not match " f"mapping df count ({mapping_df_count})" @@ -959,6 +960,7 @@ def _process_node_features( transformer = DistFeatureTransformer(feat_conf) transformed_feature_df = transformer.apply_transformation(nodes_df) + transformed_feature_df.cache() def write_processed_feature(feat_name, single_feature_df, node_type, transformer_name): feature_output_path = os.path.join( @@ -1011,6 +1013,10 @@ def write_processed_feature(feat_name, single_feature_df, node_type, transformer node_type, transformer.get_transformation_name(), ) + + # Unpersist and move on to next feature + transformed_feature_df.unpersist() + return node_type_feature_metadata, ntype_feat_sizes def _process_node_labels( @@ -1151,14 +1157,20 @@ def write_edge_structure( incoming_edge_count = edge_df.count() intermediate_edge_count = edge_df_with_int_src.count() if incoming_edge_count != intermediate_edge_count: + distinct_nodes_src = edge_df.select(src_col).distinct().count() logging.fatal( - "Incoming and outgoing edge counts do not match for " - "%s when joining %s with src_str_id! " - "%d in != %d out", + ( + "Incoming and outgoing edge counts do not match for " + "%s when joining %s with src_str_id! 
" + "%d in != %d out" + "Edge had %d distinct src nodes of type %s" + ), edge_type, src_col, incoming_edge_count, intermediate_edge_count, + distinct_nodes_src, + src_ntype, ) if src_ntype == dst_ntype: @@ -1196,7 +1208,9 @@ def write_edge_structure( edge_df_with_int_ids_and_all_features = edge_df_with_int_ids edge_df_with_only_int_ids = edge_df_with_int_ids.select(["src_int_id", "dst_int_id"]) - edge_structure_path = os.path.join(self.output_prefix, f"edges/{edge_type}") + edge_structure_path = os.path.join( + self.output_prefix, f"edges/{edge_type.replace(':', '_')}" + ) logging.info("Writing edge structure for edge type %s...", edge_type) if self.add_reverse_edges: edge_df_with_only_int_ids.cache() @@ -1205,7 +1219,7 @@ def write_edge_structure( if self.add_reverse_edges: reversed_edges = edge_df_with_only_int_ids.select("dst_int_id", "src_int_id") reversed_edge_structure_path = os.path.join( - self.output_prefix, f"edges/{rev_edge_type}" + self.output_prefix, f"edges/{rev_edge_type.replace(':', '_')}" ) logging.info("Writing edge structure for reverse edge type %s...", rev_edge_type) reverse_path_list = self._write_df(reversed_edges, reversed_edge_structure_path) @@ -1216,11 +1230,17 @@ def write_edge_structure( if self.enable_assertions: outgoing_edge_count = edge_df_with_only_int_ids.count() if incoming_edge_count != outgoing_edge_count: + distinct_nodes_dst = edge_df.select(dst_col).distinct().count() logging.fatal( - "Incoming and outgoing edge counts do not match for '%s'! %d in != %d out", + ( + "Incoming and outgoing edge counts do not match for '%s'! %d in != %d out" + "Edge had %d distinct dst nodes of type %s" + ), edge_type, incoming_edge_count, outgoing_edge_count, + distinct_nodes_dst, + dst_ntype, ) return edge_df_with_int_ids_and_all_features, path_list, reverse_path_list @@ -1399,10 +1419,11 @@ def _process_edge_features( transformer = DistFeatureTransformer(feat_conf) transformed_feature_df = transformer.apply_transformation(edges_df) + transformed_feature_df.cache() - def process_feature(self, feat_name, single_feature_df, edge_type, transformer_name): + def write_feature(self, feat_name, single_feature_df, edge_type, transformer_name): feature_output_path = os.path.join( - self.output_prefix, f"edge_data/{edge_type}-{feat_name}" + self.output_prefix, f"edge_data/{edge_type.replace(':', '_')}-{feat_name}" ) logging.info( "Writing output for feat_name: '%s' to %s", feat_name, feature_output_path @@ -1435,7 +1456,7 @@ def process_feature(self, feat_name, single_feature_df, edge_type, transformer_n ): for bert_feat_name in ["input_ids", "attention_mask", "token_type_ids"]: single_feature_df = transformed_feature_df.select(bert_feat_name) - process_feature( + write_feature( self, bert_feat_name, single_feature_df, @@ -1446,7 +1467,7 @@ def process_feature(self, feat_name, single_feature_df, edge_type, transformer_n single_feature_df = transformed_feature_df.select(feat_col).withColumnRenamed( feat_col, feat_name ) - process_feature( + write_feature( self, feat_name, single_feature_df, @@ -1454,6 +1475,9 @@ def process_feature(self, feat_name, single_feature_df, edge_type, transformer_n transformer.get_transformation_name(), ) + # Unpersist and move on to next feature + transformed_feature_df.unpersist() + return edge_feature_metadata_dicts, etype_feat_sizes def _process_edge_labels( @@ -1509,7 +1533,8 @@ def _process_edge_labels( transformed_label = edge_label_loader.process_label(edges_df) label_output_path = os.path.join( - self.output_prefix, 
f"edge_data/{edge_type}-label-{rel_type_prefix}" + self.output_prefix, + f"edge_data/{edge_type.replace(':', '_')}-label-{rel_type_prefix}", ) path_list = self._write_df(transformed_label, label_output_path) @@ -1528,7 +1553,9 @@ def _process_edge_labels( rel_type_prefix, ) - split_masks_output_prefix = os.path.join(self.output_prefix, f"edge_data/{edge_type}") + split_masks_output_prefix = os.path.join( + self.output_prefix, f"edge_data/{edge_type.replace(':', '_')}" + ) logging.info("Creating train/test/val split for edge type %s...", edge_type) if label_conf.split_rate: split_rates = SplitRates( @@ -1758,6 +1785,8 @@ def multinomial_sample(label_col: str) -> Sequence[int]: # to create one-hot vector indicating train/test/val membership input_col = F.col(label_column).astype("string") if label_column else F.lit("dummy") int_group_df = input_df.select(split_group(input_col).alias(group_col_name)) + + # We cache because we re-use this DF 3 times int_group_df.cache() train_mask_df = int_group_df.select(F.col(group_col_name)[0].alias("train_mask")) val_mask_df = int_group_df.select(F.col(group_col_name)[1].alias("val_mask")) @@ -1847,5 +1876,5 @@ def process_custom_mask_df(input_df, split_file, mask_type): ) return train_mask_df, val_mask_df, test_mask_df - def load(self) -> Dict: + def load(self) -> tuple[dict, dict]: return self.process_and_write_graph_data(self._data_configs) diff --git a/graphstorm-processing/graphstorm_processing/repartition_files.py b/graphstorm-processing/graphstorm_processing/repartition_files.py index f74fb3b1b4..d74fb00065 100644 --- a/graphstorm-processing/graphstorm_processing/repartition_files.py +++ b/graphstorm-processing/graphstorm_processing/repartition_files.py @@ -53,6 +53,7 @@ from graphstorm_processing.data_transformations import s3_utils from graphstorm_processing.graph_loaders.row_count_utils import ParquetRowCounter +from graphstorm_processing.constants import FilesystemType NUM_WRITE_THREADS = 16 @@ -64,8 +65,8 @@ class ParquetRepartitioner: ---------- input_prefix : str Prefix for the input files. - filesystem_type : str - The type of the filesystem being used. Should be 's3' or 'local'. + filesystem_type : FilesystemType + The type of the filesystem being used. Should be S3 or LOCAL. region : Optional[str] Region to be used for S3 interactions, by default None. 
verify_outputs : bool, optional @@ -79,20 +80,15 @@ class ParquetRepartitioner: def __init__( self, input_prefix: str, - filesystem_type: str, + filesystem_type: FilesystemType, region: Optional[str] = None, verify_outputs: bool = True, streaming_repartitioning=False, ): - assert filesystem_type in [ - "local", - "s3", - ], f"filesystem_type must be one of 'local' or 's3', got {filesystem_type}" - # Pyarrow expects paths of the form "bucket/path/to/file", so we strip the s3:// prefix self.input_prefix = input_prefix[5:] if input_prefix.startswith("s3://") else input_prefix self.filesystem_type = filesystem_type - if self.filesystem_type == "s3": + if self.filesystem_type == FilesystemType.S3: self.bucket = self.input_prefix.split("/")[1] self.pyarrow_fs = fs.S3FileSystem( region=region, retry_strategy=fs.AwsDefaultS3RetryStrategy(max_attempts=10) @@ -121,7 +117,7 @@ def read_dataset_from_relative_path(self, relative_path: str) -> ds.Dataset: dataset_relative_path = relative_path dataset_location = os.path.join(self.input_prefix, dataset_relative_path) - return ds.dataset(dataset_location, filesystem=self.pyarrow_fs) + return ds.dataset(dataset_location, filesystem=self.pyarrow_fs, exclude_invalid_files=False) def read_parquet_from_relative_path(self, relative_path: str) -> pyarrow.Table: """ @@ -140,7 +136,7 @@ def write_parquet_to_relative_path( # TODO: Might be easier to update the output file list every time # this is called to ensure consistency? file_path = os.path.join(self.input_prefix, relative_path) - if self.filesystem_type == "local": + if self.filesystem_type == FilesystemType.LOCAL: os.makedirs(Path(file_path).parent, exist_ok=True) pq.write_table(table, file_path, filesystem=self.pyarrow_fs, compression="snappy") if self.verify_outputs: @@ -1167,13 +1163,13 @@ def main(): logging.basicConfig(level=getattr(logging, repartition_config.log_level.upper(), None)) if repartition_config.input_prefix.startswith("s3://"): - filesystem_type = "s3" + filesystem_type = FilesystemType.S3 else: input_prefix = str(Path(repartition_config.input_prefix).resolve(strict=True)) - filesystem_type = "local" + filesystem_type = FilesystemType.LOCAL # Trim trailing '/' from S3 URI - if filesystem_type == "s3": + if filesystem_type == FilesystemType.S3: input_prefix = s3_utils.s3_path_remove_trailing(repartition_config.input_prefix) logging.info( @@ -1190,7 +1186,7 @@ def main(): # Get the metadata file region = None - if filesystem_type == "s3": + if filesystem_type == FilesystemType.S3: bucket = input_prefix.split("/")[2] s3_key_prefix = input_prefix.split("/", 3)[3] region = s3_utils.get_bucket_region(bucket) @@ -1224,7 +1220,7 @@ def main(): metafile.flush() # Upload the updated metadata file to S3 - if filesystem_type == "s3": + if filesystem_type == FilesystemType.S3: s3_client.upload_file( metafile.name, bucket, diff --git a/graphstorm-processing/tests/test_repartition_files.py b/graphstorm-processing/tests/test_repartition_files.py index 98e3e084fb..23aa8b4d71 100644 --- a/graphstorm-processing/tests/test_repartition_files.py +++ b/graphstorm-processing/tests/test_repartition_files.py @@ -29,6 +29,7 @@ from graphstorm_processing.repartition_files import ParquetRepartitioner from graphstorm_processing import repartition_files +from graphstorm_processing.constants import FilesystemType _ROOT = os.path.abspath(os.path.dirname(__file__)) DUMMY_PREFIX = "s3://dummy_bucket/dummy_prefix" @@ -186,7 +187,7 @@ def test_repartition_functions(desired_counts: List[int], partition_function_nam """Test the 
repartition functions, streaming and in-memory""" assert sum(desired_counts) == 50 - my_partitioner = ParquetRepartitioner(TEMP_DATA_PREFIX, filesystem_type="local") + my_partitioner = ParquetRepartitioner(TEMP_DATA_PREFIX, filesystem_type=FilesystemType.LOCAL) metadata_path = os.path.join(TEMP_DATA_PREFIX, "partitioned_metadata.json")
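
For callers updating their own use of `ParquetRepartitioner` after the enum change, a minimal local example (the path is a placeholder; the constructor arguments follow the signature shown in this diff):

```python
from graphstorm_processing.constants import FilesystemType
from graphstorm_processing.repartition_files import ParquetRepartitioner

# Local filesystem repartitioner; for S3, pass an "s3://bucket/prefix" input prefix,
# FilesystemType.S3, and optionally a region string.
repartitioner = ParquetRepartitioner(
    "/tmp/gsprocessing-output",
    filesystem_type=FilesystemType.LOCAL,
    verify_outputs=True,
)
```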