Skip to content

Commit

Permalink
Handle empty raw node mappings in launch
Browse files Browse the repository at this point in the history
  • Loading branch information
thvasilo committed Jan 8, 2024
1 parent dda1529 commit 9c7139b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
5 changes: 4 additions & 1 deletion sagemaker/launch/launch_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,16 @@ def run_job(input_args, image, unknownargs):
"graph-name": graph_name,
"infer-yaml-s3": infer_yaml_s3,
"model-artifact-s3": model_artifact_s3,
"raw-node-mappings-s3": input_args.raw_node_mappings_s3,
"output-chunk-size": output_chunk_size,
"output-emb-s3": output_emb_s3_path,
"task-type": task_type,
}
# In Link Prediction, no prediction outputs
if task_type not in ["link_prediction", "compute_emb"]:
params["output-prediction-s3"] = output_predict_s3_path
# If no raw mapping files are provided, remapping is skipped
if input_args.raw_node_mappings_s3 is not None:
params["raw-node-mappings-s3"] = input_args.raw_node_mappings_s3
# We must handle cases like
# --target-etype query,clicks,asin query,search,asin
# --feat-name ntype0:feat0 ntype1:feat1
Expand Down Expand Up @@ -150,6 +152,7 @@ def get_inference_parser():
required=True)
inference_args.add_argument("--raw-node-mappings-s3", type=str,
help="S3 location to load the node id mappings from",
default=None,
required=False)
inference_args.add_argument("--output-emb-s3", type=str,
help="S3 location to store GraphStorm generated node embeddings.",
Expand Down
2 changes: 1 addition & 1 deletion sagemaker/run/infer_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def parse_inference_args():
parser.add_argument("--model-artifact-s3", type=str,
help="S3 bucket to load the saved model artifacts")
parser.add_argument("--raw-node-mappings-s3", type=str, required=False,
default=None, help="S3 location where the node mappings exist.")
default=None, help="S3 location where the original (str to int) node mappings exist.")
parser.add_argument("--custom-script", type=str, default=None,
help="Custom training script provided by a customer to run customer training logic. \
Please provide the path of the script within the docker image")
Expand Down

0 comments on commit 9c7139b

Please sign in to comment.