diff --git a/python/graphstorm/sagemaker/sagemaker_gb_convert.py b/python/graphstorm/sagemaker/sagemaker_gb_convert.py index 9716b90c2..337a9400d 100644 --- a/python/graphstorm/sagemaker/sagemaker_gb_convert.py +++ b/python/graphstorm/sagemaker/sagemaker_gb_convert.py @@ -19,6 +19,7 @@ import logging import os import time +from collections import defaultdict from packaging import version import boto3 @@ -59,7 +60,7 @@ def run_gb_convert(s3_output_path: str, local_dist_part_config: str, njobs: int) f"but DGL version was {dgl_version}. " ) - boto_session = boto3.Session(region_name=os.environ['AWS_REGION']) + boto_session = boto3.Session(region_name=os.environ["AWS_REGION"]) sagemaker_session = sagemaker.Session(boto_session=boto_session) # Run the actual conversion, this will create the fused_csc_sampling_graph.pt @@ -71,25 +72,41 @@ def run_gb_convert(s3_output_path: str, local_dist_part_config: str, njobs: int) # Iterate through the partition data and upload only the modified/new # files to the corresponding path on S3 upload_start = time.time() + fused_files_exist = defaultdict(lambda: False) for root, _, files in os.walk(os.path.dirname(local_dist_part_config)): for file in files: - if file.endswith(("fused_csc_sampling_graph.pt", ".json")): + if file.endswith("fused_csc_sampling_graph.pt"): + partition_id = root.split("/")[-1] + # Set fused file existence to true for this partition + fused_files_exist[partition_id] = True + # Partition data need to be uploaded to partition-id dirs + s3_path = os.path.join(s3_output_path, f"{partition_id}") + elif file.endswith(".json"): + # Partition output metadata file needs to be uploaded to root dir + s3_path = s3_output_path + else: + # We skip other files partition_id = root.split("/")[-1] if "part" in partition_id: - # Partition data need to be uploaded to partition-id dirs - assert file.endswith("fused_csc_sampling_graph.pt") - s3_path = os.path.join(s3_output_path, f"{partition_id}") - else: - # Metadata file needs to be uploaded to root dir - assert file.endswith(".json") - s3_path = s3_output_path - logging.info( - "Uploading local %s to %s", - os.path.join(root, file), - s3_path) - S3Uploader.upload( - local_path=os.path.join(root, file), - desired_s3_uri=s3_path, - sagemaker_session=sagemaker_session - ) + # Set file existence to False only if + # we haven't encountered a fused file already + fused_files_exist[partition_id] = ( + False or fused_files_exist[partition_id] + ) + continue + + logging.info("Uploading local %s to %s", os.path.join(root, file), s3_path) + S3Uploader.upload( + local_path=os.path.join(root, file), + desired_s3_uri=s3_path, + sagemaker_session=sagemaker_session, + ) + + for partition_id, fused_file_exists in fused_files_exist.items(): + if not fused_file_exists: + raise RuntimeError( + f"Partition {partition_id} did not have " + "a fused_csc_sampling_graph.pt file." + ) + logging.info("Uploading took %f sec.", time.time() - upload_start)