Skip to content

Commit

Permalink
Add checks for existence of fused file per partition
Browse files Browse the repository at this point in the history
  • Loading branch information
thvasilo committed Dec 19, 2024
1 parent b04fe58 commit c9e647a
Showing 1 changed file with 35 additions and 18 deletions.
53 changes: 35 additions & 18 deletions python/graphstorm/sagemaker/sagemaker_gb_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import os
import time
from collections import defaultdict
from packaging import version

import boto3
Expand Down Expand Up @@ -59,7 +60,7 @@ def run_gb_convert(s3_output_path: str, local_dist_part_config: str, njobs: int)
f"but DGL version was {dgl_version}. "
)

boto_session = boto3.Session(region_name=os.environ['AWS_REGION'])
boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
sagemaker_session = sagemaker.Session(boto_session=boto_session)

# Run the actual conversion, this will create the fused_csc_sampling_graph.pt
Expand All @@ -71,25 +72,41 @@ def run_gb_convert(s3_output_path: str, local_dist_part_config: str, njobs: int)
# Iterate through the partition data and upload only the modified/new
# files to the corresponding path on S3
upload_start = time.time()
fused_files_exist = defaultdict(lambda: False)
for root, _, files in os.walk(os.path.dirname(local_dist_part_config)):
for file in files:
if file.endswith(("fused_csc_sampling_graph.pt", ".json")):
if file.endswith("fused_csc_sampling_graph.pt"):
partition_id = root.split("/")[-1]
# Set fused file existence to true for this partition
fused_files_exist[partition_id] = True
# Partition data need to be uploaded to partition-id dirs
s3_path = os.path.join(s3_output_path, f"{partition_id}")
elif file.endswith(".json"):
# Partition output metadata file needs to be uploaded to root dir
s3_path = s3_output_path
else:
# We skip other files
partition_id = root.split("/")[-1]
if "part" in partition_id:
# Partition data need to be uploaded to partition-id dirs
assert file.endswith("fused_csc_sampling_graph.pt")
s3_path = os.path.join(s3_output_path, f"{partition_id}")
else:
# Metadata file needs to be uploaded to root dir
assert file.endswith(".json")
s3_path = s3_output_path
logging.info(
"Uploading local %s to %s",
os.path.join(root, file),
s3_path)
S3Uploader.upload(
local_path=os.path.join(root, file),
desired_s3_uri=s3_path,
sagemaker_session=sagemaker_session
)
# Set file existence to False only if
# we haven't encountered a fused file already
fused_files_exist[partition_id] = (
False or fused_files_exist[partition_id]
)
continue

logging.info("Uploading local %s to %s", os.path.join(root, file), s3_path)
S3Uploader.upload(
local_path=os.path.join(root, file),
desired_s3_uri=s3_path,
sagemaker_session=sagemaker_session,
)

for partition_id, fused_file_exists in fused_files_exist.items():
if not fused_file_exists:
raise RuntimeError(
f"Partition {partition_id} did not have "
"a fused_csc_sampling_graph.pt file."
)

logging.info("Uploading took %f sec.", time.time() - upload_start)

0 comments on commit c9e647a

Please sign in to comment.