
Commit

Use argument-specific path to facilitate step caching, pass partition algorithm to gconstruct

thvasilo committed Dec 17, 2024
1 parent d078d0b commit 1e08286
Showing 3 changed files with 160 additions and 49 deletions.
31 changes: 13 additions & 18 deletions sagemaker/pipeline/create_sm_pipeline.py
@@ -71,17 +71,15 @@ def __init__(
)

# Build up the output prefix
# TODO: Using PIPELINE_EXECUTION_ID in the output path invalidates cached results,
# maybe have the output path be static between executions (but unique per pipeline)?
# One option might be to use a hash of the execution parameters dict and
# add that to the prefix?
# Could be passed as another parameter to the pipeline
# We use a hash of the execution parameters dict and
# add that to the prefix to have consistent intermediate paths between executions
# that share all the same parameters.
self.output_subpath = Join(
on="/",
values=[
self.output_prefix_param,
self._get_pipeline_name(args),
ExecutionVariables.PIPELINE_EXECUTION_ID,
self.execution_subpath_param,
],
)
self.train_infer_instance = (
@@ -92,9 +90,7 @@ def __init__(
self.train_infer_image = (
args.aws_config.graphstorm_pytorch_cpu_image_url
if self.args.instance_config.train_on_cpu
else
args.aws_config.graphstorm_pytorch_gpu_image_url

else args.aws_config.graphstorm_pytorch_gpu_image_url
)

def _get_or_create_pipeline_session(
@@ -195,6 +191,9 @@ def _create_pipeline_parameters(self, args: PipelineArgs):
"InstanceVolumeSizeGB",
args.instance_config.volume_size_gb,
)
self.execution_subpath_param = self._create_string_parameter(
"ExecutionSubpath", args.get_hash_hex()
)
self.graphconstruct_config_param = self._create_string_parameter(
"GraphConstructConfigFile", args.graph_construction_config.config_filename
)
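The new ExecutionSubpath parameter defaults to a hash of the pipeline arguments, so executions that share all parameters write to the same intermediate paths and SageMaker step caching can reuse earlier results. The get_hash_hex() implementation itself is not shown in this diff; the following is only a rough sketch, under the assumption that the pipeline arguments can be serialized to a dictionary, of how such a hash could be derived (all names here are illustrative):

import hashlib
import json
from dataclasses import asdict, dataclass

@dataclass
class PipelineArgsSketch:
    """Illustrative stand-in for the real PipelineArgs object."""
    graph_name: str
    instance_count: int
    partition_algorithm: str

    def get_hash_hex(self) -> str:
        # Serialize deterministically (sorted keys) before hashing so that
        # identical configurations always map to the same output subpath.
        payload = json.dumps(asdict(self), sort_keys=True).encode("utf-8")
        return hashlib.sha256(payload).hexdigest()[:16]

print(PipelineArgsSketch("my-graph", 4, "random").get_hash_hex())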
@@ -310,15 +309,13 @@ def _create_gconstruct_step(self, args: PipelineArgs) -> ProcessingStep:
gc_local_input_path = "/opt/ml/processing/input"
# GConstruct should always be the first step and start with the source data
gc_proc_input = ProcessingInput(
source=self.input_data_param,
destination=gc_local_input_path,
s3_input_mode='File',
source=self.input_data_param, destination=gc_local_input_path
)
gc_local_output_path = "/opt/ml/processing/output"
gc_proc_output = ProcessingOutput(
source=gc_local_output_path,
destination=gconstruct_s3_output,
output_name=self.graph_name_param,
output_name=f"{self.graph_name_param}-gconstruct",
)

gconstruct_arguments = [
@@ -332,6 +329,8 @@ def _create_gconstruct_step(self, args: PipelineArgs) -> ProcessingStep:
self.graph_name_param,
"--num-parts",
self.instance_count_param.to_string(),
"--part-method",
self.partition_algorithm_param,
]

# TODO: Make this a pipeline parameter?
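For reference, a standalone sketch, not the repository's code, of the pattern used above: a ParameterString defined on the pipeline can be placed directly in the processing job's argument list, and SageMaker resolves it to its runtime value, which is how the partition algorithm now reaches GConstruct's --part-method flag:

from sagemaker.workflow.parameters import ParameterString

# Illustrative parameter; "random" is assumed to be a valid GConstruct partition method.
partition_algorithm_param = ParameterString(
    name="PartitionAlgorithm", default_value="random"
)

gconstruct_arguments = [
    "--graph-name", "my-graph",
    "--num-parts", "4",
    # Pipeline parameters are substituted with their string values
    # when the pipeline execution runs this step.
    "--part-method", partition_algorithm_param,
]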
@@ -360,7 +359,6 @@ def _create_gconstruct_step(self, args: PipelineArgs) -> ProcessingStep:

def _create_gsprocessing_step(self, args: PipelineArgs) -> ProcessingStep:
# Implementation for GSProcessing step
# TODO: Add volume size
pyspark_processor = PySparkProcessor(
role=args.aws_config.role,
instance_type=args.instance_config.graph_construction_instance_type,
@@ -400,8 +398,6 @@ def _create_gsprocessing_step(self, args: PipelineArgs) -> ProcessingStep:
gsprocessing_output,
"--do-repartition",
"True",
"--add-reverse-edges",
"True",
"--log-level",
args.task_config.log_level,
]
@@ -417,7 +413,7 @@ def _create_gsprocessing_step(self, args: PipelineArgs) -> ProcessingStep:
destination="/opt/ml/processing/input/data",
)
gsprocessing_meta_output = ProcessingOutput(
output_name="metadata",
output_name="partition-input-metadata",
destination=gsprocessing_output,
source="/opt/ml/processing/output",
)
@@ -522,7 +518,6 @@ def _create_gb_convert_step(self, args: PipelineArgs) -> ProcessingStep:
input_name="dist_graph_s3_input",
destination="/opt/ml/processing/dist_graph/",
source=self.next_step_data_input,
# GraphBolt conversion requires File mode
s3_input_mode="File",
)
],
88 changes: 74 additions & 14 deletions sagemaker/pipeline/execute_sm_pipeline.py
@@ -17,7 +17,9 @@
"""

import argparse
import os
import sys
import warnings

import boto3
import psutil
@@ -40,11 +42,16 @@ def parse_args():
required=True,
help="Name of the pipeline to execute. Required.",
)
parser.add_argument("--region", type=str, required=False,
help="AWS region. Required for SageMaker execution.")
parser.add_argument(
"--async-execution", action="store_true",
help="Run pipeline asynchronously on SageMaker, return after printing execution ARN."
"--region",
type=str,
required=False,
help="AWS region. Required for SageMaker execution.",
)
parser.add_argument(
"--async-execution",
action="store_true",
help="Run pipeline asynchronously on SageMaker, return after printing execution ARN.",
)
parser.add_argument(
"--local-execution",
@@ -60,10 +67,9 @@ def parse_args():
),
)


overrides = parser.add_argument_group(
"Pipeline overrides",
"Override default pipeline parameters at execution time.")
"Pipeline overrides", "Override default pipeline parameters at execution time."
)

# Optional override parameters
overrides.add_argument("--instance-count", type=int, help="Override instance count")
@@ -90,7 +96,9 @@ def parse_args():
help="Override partition algorithm",
)
overrides.add_argument("--graph-name", type=str, help="Override graph name")
overrides.add_argument("--num-trainers", type=int, help="Override number of trainers")
overrides.add_argument(
"--num-trainers", type=int, help="Override number of trainers"
)
overrides.add_argument(
"--use-graphbolt",
type=str,
@@ -110,6 +118,14 @@ def parse_args():
overrides.add_argument(
"--inference-model-snapshot", type=str, help="Override inference model snapshot"
)
overrides.add_argument(
"--execution-subpath",
type=str,
help=(
"Override execution subpath. "
"By default it's derived from a hash of the input arguments"
),
)

return parser.parse_args()

@@ -118,15 +134,16 @@ def main():
"""Execute GraphStorm SageMaker pipeline"""
args = parse_args()

pipeline_deploy_args = load_pipeline_args(
args.pipeline_args_json_file or f"{args.pipeline_name}-pipeline-args.json"
)
deploy_time_hash = pipeline_deploy_args.get_hash_hex()

if args.local_execution:
# Use local pipeline and session
pipeline_args = load_pipeline_args(
args.pipeline_args_json_file or f"{args.pipeline_name}-pipeline-args.json"
)

local_session = LocalPipelineSession()
pipeline_generator = GraphStormPipelineGenerator(
pipeline_args, input_session=local_session
pipeline_deploy_args, input_session=local_session
)
# Set shared memory to half the host's size, as SM does
instance_mem_mb = int(psutil.virtual_memory().total // (1024 * 1024))
@@ -135,7 +152,7 @@ def main():
}
pipeline = pipeline_generator.create_pipeline()
pipeline.sagemaker_session = local_session
pipeline.create(role_arn=pipeline_args.aws_config.role)
pipeline.create(role_arn=pipeline_deploy_args.aws_config.role)
else:
assert args.region, "Need to provide --region for remote SageMaker execution"
boto_session = boto3.Session(region_name=args.region)
@@ -147,34 +164,77 @@ def main():
execution_params = {}
if args.instance_count is not None:
execution_params["InstanceCount"] = args.instance_count
pipeline_deploy_args.instance_config.train_infer_instance_count = (
args.instance_count
)
if args.cpu_instance_type:
execution_params["CPUInstanceType"] = args.cpu_instance_type
pipeline_deploy_args.instance_config.cpu_instance_type = args.cpu_instance_type
if args.gpu_instance_type:
execution_params["GPUInstanceType"] = args.gpu_instance_type
pipeline_deploy_args.instance_config.gpu_instance_type = args.gpu_instance_type
if args.graphconstruct_instance_type:
execution_params["GraphConstructInstanceType"] = (
args.graphconstruct_instance_type
)
pipeline_deploy_args.instance_config.graph_construction_instance_type = (
args.graphconstruct_instance_type
)
if args.graphconstruct_config_file:
execution_params["GraphConstructConfigFile"] = args.graphconstruct_config_file
pipeline_deploy_args.graph_construction_config.config_filename = (
args.graphconstruct_config_file
)
if args.partition_algorithm:
execution_params["PartitionAlgorithm"] = args.partition_algorithm
pipeline_deploy_args.partition_config.partition_algorithm = (
args.partition_algorithm
)
if args.graph_name:
execution_params["GraphName"] = args.graph_name
pipeline_deploy_args.task_config.graph_name = args.graph_name
if args.num_trainers is not None:
execution_params["NumTrainers"] = args.num_trainers
pipeline_deploy_args.training_config.num_trainers = args.num_trainers
if args.use_graphbolt:
execution_params["UseGraphBolt"] = args.use_graphbolt
pipeline_deploy_args.training_config.use_graphbolt_str = args.use_graphbolt
if args.input_data:
execution_params["InputData"] = args.input_data
pipeline_deploy_args.task_config.input_data_s3 = args.input_data
if args.output_prefix:
execution_params["OutputPrefix"] = args.output_prefix
pipeline_deploy_args.task_config.output_prefix = args.output_prefix
if args.train_yaml_file:
execution_params["TrainConfigFile"] = args.train_yaml_file
pipeline_deploy_args.training_config.train_yaml_file = args.train_yaml_file
if args.inference_yaml_file:
execution_params["InferenceConfigFile"] = args.inference_yaml_file
pipeline_deploy_args.inference_config.inference_yaml_file = (
args.inference_yaml_file
)
if args.inference_model_snapshot:
execution_params["InferenceModelSnapshot"] = args.inference_model_snapshot
pipeline_deploy_args.inference_config.inference_model_snapshot = (
args.inference_model_snapshot
)
# If the user specified a subpath use that, otherwise derive it from a hash of the execution arguments
if args.execution_subpath:
execution_params["ExecutionSubpath"] = args.execution_subpath
else:
execution_params["ExecutionSubpath"] = pipeline_deploy_args.get_hash_hex()

if pipeline_deploy_args.get_hash_hex() != deploy_time_hash:
new_prefix = os.path.join(
pipeline_deploy_args.task_config.output_prefix,
args.pipeline_name,
pipeline_deploy_args.get_hash_hex(),
)
warnings.warn(
"The pipeline execution arguments have been modified "
"compared to the deployment parameters. "
f"This execution will use a new unique output prefix: {new_prefix}."
)

# If no parameters are provided, use an empty dict to use all defaults
execution = pipeline.start(
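To illustrate the effect of the changes above, here is a condensed sketch, separate from this commit's code, of starting the deployed pipeline with overrides: each override is also applied to the loaded deploy-time arguments, so the recomputed hash, and therefore ExecutionSubpath, only changes when the effective configuration changes, which is what keeps step caching usable across identical runs. The region, pipeline name, and hash value below are illustrative:

import boto3
from sagemaker.session import Session
from sagemaker.workflow.pipeline import Pipeline

boto_session = boto3.Session(region_name="us-east-1")
pipeline = Pipeline(
    name="my-graphstorm-pipeline",
    sagemaker_session=Session(boto_session=boto_session),
)

execution_params = {
    # Override a deploy-time default; the same value would also be written
    # back onto the loaded pipeline arguments before recomputing their hash.
    "PartitionAlgorithm": "metis",
    # Falls back to pipeline_deploy_args.get_hash_hex() when not overridden.
    "ExecutionSubpath": "3f6c2a9d0b1e4f57",
}

# Identical parameters produce an identical ExecutionSubpath and thus identical
# S3 prefixes, allowing SageMaker to reuse cached step results.
execution = pipeline.start(parameters=execution_params)
print(execution.arn)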
