Skip to content

Commit

Permalink
Refactor graphstorm.sagemaker logging printing (#650)
Browse files Browse the repository at this point in the history
We refactor the graphstorm.sagemaker module to use the Python logging package to
print information.
The default logging level is INFO.

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Xiang Song <[email protected]>
  • Loading branch information
classicsong and Xiang Song authored Nov 29, 2023
1 parent 1fa6672 commit 60fa3df
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 41 deletions.
3 changes: 2 additions & 1 deletion python/graphstorm/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,8 @@ def _load_id_mapping(self, g, ntype, id_mappings):
# Shuffled node ID: 0, 1, 2
id_mapping = id_mappings[ntype] if isinstance(id_mappings, dict) else id_mappings
assert id_mapping.shape[0] == num_nodes, \
"id mapping should have the same size of num_nodes"
"Id mapping should have the same size of num_nodes." \
f"Expect {id_mapping.shape[0]}, but get {num_nodes}"
# Save ID mapping into dist tensor
id_mapping_info[th.arange(num_nodes)] = id_mapping
barrier()
Expand Down
29 changes: 19 additions & 10 deletions python/graphstorm/sagemaker/sagemaker_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""
# Install additional requirements
import os
import logging
import socket
import time
import json
Expand Down Expand Up @@ -117,13 +118,14 @@ def launch_infer_task(task_type, num_gpus, graph_config,
launch_cmd += ["--cf", f"{yaml_path}",
"--restore-model-path", f"{load_model_path}",
"--save-embed-path", f"{save_emb_path}"] + extra_args
logging.debug("Launch inference %s", launch_cmd)

def run(launch_cmd, state_q):
try:
subprocess.check_call(launch_cmd, shell=False)
state_q.put(0)
except subprocess.CalledProcessError as err:
print(f"Called process error {err}")
logging.error("Called process error %s", err)
state_q.put(err.returncode)
except Exception: # pylint: disable=broad-except
state_q.put(-1)
Expand Down Expand Up @@ -174,8 +176,8 @@ def run_infer(args, unknownargs):
# start the ssh server
subprocess.run(["service", "ssh", "start"], check=True)

print(f"Know args {args}")
print(f"Unknow args {unknownargs}")
logging.info("Known args %s", args)
logging.info("Unknown args %s", unknownargs)

train_env = json.loads(args.sm_dist_env)
hosts = train_env['hosts']
Expand All @@ -184,9 +186,15 @@ def run_infer(args, unknownargs):
os.environ['WORLD_SIZE'] = str(world_size)
host_rank = hosts.index(current_host)

# NOTE: Ensure no logging has been done before setting logging configuration
logging.basicConfig(
level=getattr(logging, args.log_level.upper(), None),
format=f'{current_host}: %(asctime)s - %(levelname)s - %(message)s',
force=True)

try:
for host in hosts:
print(f"The {host} IP is {socket.gethostbyname(host)}")
logging.info("The %s IP is %s", host, {socket.gethostbyname(host)})
except:
raise RuntimeError(f"Can not get host name of {hosts}")

Expand All @@ -210,9 +218,9 @@ def run_infer(args, unknownargs):
sock.connect((master_addr, 12345))
break
except: # pylint: disable=bare-except
print(f"Try to connect {master_addr}")
logging.info("Try to connect %s", master_addr)
time.sleep(10)
print("Connected")
logging.info("Connected")

# write ip list info into disk
ip_list_path = os.path.join(data_path, 'ip_list.txt')
Expand Down Expand Up @@ -247,7 +255,8 @@ def run_infer(args, unknownargs):

# Download Saved model
download_model(model_artifact_s3, model_path, sagemaker_session)
print(f"{model_path} {os.listdir(model_path)}")
logging.info("Successfully downloaded the model into %s.\n The model files are: %s.",
model_path, os.listdir(model_path))

err_code = 0
if host_rank == 0:
Expand Down Expand Up @@ -281,7 +290,7 @@ def run_infer(args, unknownargs):
err_code = -1

terminate_workers(client_list, world_size, task_end)
print("Master End")
logging.info("Master End")
if err_code != -1:
upload_embs(output_emb_s3, emb_path, sagemaker_session)
# clean embs, so SageMaker does not need to upload embs again
Expand All @@ -295,12 +304,12 @@ def run_infer(args, unknownargs):
upload_embs(output_emb_s3, emb_path, sagemaker_session)
# clean embs, so SageMaker does not need to upload embs again
remove_embs(emb_path)
print("Worker End")
logging.info("Worker End")

sock.close()
if err_code != 0:
# Report an error
print("Task failed")
logging.error("Task failed")
sys.exit(-1)

if args.output_prediction_s3 is not None:
Expand Down
32 changes: 19 additions & 13 deletions python/graphstorm/sagemaker/sagemaker_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""
# Install additional requirements
import os
import logging
import socket
import time
import json
Expand Down Expand Up @@ -108,15 +109,14 @@ def launch_train_task(task_type, num_gpus, graph_config,
launch_cmd += ["--restore-model-path", f"{restore_model_path}"] \
if restore_model_path is not None else []
launch_cmd += extra_args

print(launch_cmd)
logging.debug("Launch training %s", launch_cmd)

def run(launch_cmd, state_q):
try:
subprocess.check_call(launch_cmd, shell=False)
state_q.put(0)
except subprocess.CalledProcessError as err:
print(f"Called process error {err}")
logging.error("Called process error %s", err)
state_q.put(err.returncode)
except Exception: # pylint: disable=broad-except
state_q.put(-1)
Expand Down Expand Up @@ -167,8 +167,8 @@ def run_train(args, unknownargs):
# start the ssh server
subprocess.run(["service", "ssh", "start"], check=True)

print(f"Know args {args}")
print(f"Unknow args {unknownargs}")
logging.info("Known args %s", args)
logging.info("Unknown args %s", unknownargs)

save_model_path = os.path.join(output_path, "model_checkpoint")

Expand All @@ -179,9 +179,15 @@ def run_train(args, unknownargs):
os.environ['WORLD_SIZE'] = str(world_size)
host_rank = hosts.index(current_host)

# NOTE: Ensure no logging has been done before setting logging configuration
logging.basicConfig(
level=getattr(logging, args.log_level.upper(), None),
format=f'{current_host}: %(asctime)s - %(levelname)s - %(message)s',
force=True)

try:
for host in hosts:
print(f"The {host} IP is {socket.gethostbyname(host)}")
logging.info("The %s IP is %s", host, socket.gethostbyname(host))
except:
raise RuntimeError(f"Can not get host name of {hosts}")

Expand All @@ -205,9 +211,9 @@ def run_train(args, unknownargs):
sock.connect((master_addr, 12345))
break
except: # pylint: disable=bare-except
print(f"Try to connect {master_addr}")
logging.info("Try to connect %s", master_addr)
time.sleep(10)
print("Connected")
logging.info("Connected")

# write ip list info into disk
ip_list_path = os.path.join(data_path, 'ip_list.txt')
Expand All @@ -232,8 +238,8 @@ def run_train(args, unknownargs):
if model_checkpoint_s3 is not None:
# Download Saved model checkpoint to resume
download_model(model_checkpoint_s3, restore_model_path, sagemaker_session)
print(f"{restore_model_path} {os.listdir(restore_model_path)}")

logging.info("Successfully downloaded the model into %s.\n The model files are: %s.",
restore_model_path, os.listdir(restore_model_path))

err_code = 0
if host_rank == 0:
Expand Down Expand Up @@ -265,18 +271,18 @@ def run_train(args, unknownargs):
print(e)
err_code = -1
terminate_workers(client_list, world_size, task_end)
print("Master End")
logging.info("Master End")
else:
barrier(sock)
# Block util training finished
# Listen to end command
wait_for_exit(sock)
print("Worker End")
logging.info("Worker End")

sock.close()
if err_code != 0:
# Report an error
print("Task failed")
logging.error("Task failed")
sys.exit(-1)

# If there are saved models
Expand Down
41 changes: 27 additions & 14 deletions python/graphstorm/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ def download_yaml_config(yaml_s3, local_path, sagemaker_session):
try:
S3Downloader.download(yaml_s3, local_path,
sagemaker_session=sagemaker_session)
except Exception: # pylint: disable=broad-except
raise RuntimeError(f"Fail to download yaml file {yaml_s3}")
except Exception as err: # pylint: disable=broad-except
raise RuntimeError(f"Fail to download yaml file {yaml_s3}: {err}")

return yaml_path

Expand All @@ -206,9 +206,10 @@ def download_model(model_artifact_s3, model_path, sagemaker_session):
try:
S3Downloader.download(model_artifact_s3,
model_path, sagemaker_session=sagemaker_session)
except Exception: # pylint: disable=broad-except
except Exception as err: # pylint: disable=broad-except
raise RuntimeError("Can not download saved model artifact" \
f"model.bin from {model_artifact_s3}.")
f"model.bin from {model_artifact_s3}." \
f"{err}")

def download_graph(graph_data_s3, graph_name, part_id, world_size,
local_path, sagemaker_session):
Expand Down Expand Up @@ -273,19 +274,29 @@ def download_graph(graph_data_s3, graph_name, part_id, world_size,
S3Downloader.download(os.path.join(graph_data_s3, graph_config),
graph_path, sagemaker_session=sagemaker_session)
try:
logging.info("Download graph from %s to %s",
os.path.join(graph_data_s3, graph_part),
graph_part_path)
S3Downloader.download(os.path.join(graph_data_s3, graph_part),
graph_part_path, sagemaker_session=sagemaker_session)
except Exception: # pylint: disable=broad-except
print(f"Can not download graph_data from {graph_data_s3}.")
raise RuntimeError(f"Can not download graph_data from {graph_data_s3}.")
except Exception as err: # pylint: disable=broad-except
logging.error("Can not download graph_data from %s, %s.",
graph_data_s3, str(err))
raise RuntimeError(f"Can not download graph_data from {graph_data_s3}, {err}.")

node_id_mapping = "node_mapping.pt"
# Try to download node id mapping file if any
try:
logging.info("Download graph id mapping from %s to %s",
os.path.join(graph_data_s3, node_id_mapping),
graph_path)
S3Downloader.download(os.path.join(graph_data_s3, node_id_mapping),
graph_path, sagemaker_session=sagemaker_session)
except Exception: # pylint: disable=broad-except
print("node id mapping file does not exist")
logging.warning("Node id mapping file does not exist."
"If you are running GraphStorm on a graph with "
"more than 1 partition, it is recommended to provide "
"the node id mapping file created by gconstruct or gsprocessing.")

if part_id == 0:
# It is possible that id mappings are generated by
Expand All @@ -309,12 +320,14 @@ def download_graph(graph_data_s3, graph_name, part_id, world_size,
id_map_files = [file for file in files if file.endswith("id_remap.parquet")]
for file in id_map_files:
try:
logging.info("Download graph remap from %s to %s",
file, graph_path)
S3Downloader.download(file, graph_path,
sagemaker_session=sagemaker_session)
except Exception: # pylint: disable=broad-except
print(f"node id remap file {file} does not exist")
logging.warning("node id remap file %s does not exist", file)

print(f"Finish download graph data from {graph_data_s3}")
logging.info("Finished downloading graph data from %s", graph_data_s3)
return os.path.join(graph_path, graph_config)


Expand All @@ -333,9 +346,9 @@ def upload_data_to_s3(s3_path, data_path, sagemaker_session):
try:
ret = S3Uploader.upload(data_path, s3_path,
sagemaker_session=sagemaker_session)
except Exception: # pylint: disable=broad-except
print(f"Can not upload data into {s3_path}")
raise RuntimeError(f"Can not upload data into {s3_path}")
except Exception as err: # pylint: disable=broad-except
logging.error("Can not upload data into %s", s3_path)
raise RuntimeError(f"Can not upload data into {s3_path}. {err}")
return ret

def upload_model_artifacts(model_s3_path, model_path, sagemaker_session):
Expand All @@ -354,7 +367,7 @@ def upload_model_artifacts(model_s3_path, model_path, sagemaker_session):
sagemaker_session: sagemaker.session.Session
sagemaker_session to run download
"""
print(f"Upload model artifacts to {model_s3_path}")
logging.info("Uploading model artifacts to %s", model_s3_path)
# Rank0 will upload both dense models and learnable embeddings owned by Rank0.
# Other ranks will only upload learnable embeddings owned by themselves.
return upload_data_to_s3(model_s3_path, model_path, sagemaker_session)
Expand Down
9 changes: 7 additions & 2 deletions sagemaker/launch/launch_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def run_job(input_args, image, unknownargs):
output_predict_s3_path = input_args.output_prediction_s3 # S3 location to save prediction results
model_artifact_s3 = input_args.model_artifact_s3 # S3 location of saved model artifacts
output_chunk_size = input_args.output_chunk_size # Number of rows per chunked prediction result or node embedding file.
log_level = input_args.log_level # SageMaker runner logging level

boto_session = boto3.session.Session(region_name=region)
sagemaker_client = boto_session.client(service_name="sagemaker", region_name=region)
Expand All @@ -76,7 +77,8 @@ def run_job(input_args, image, unknownargs):
"infer-yaml-s3": infer_yaml_s3,
"output-emb-s3": output_emb_s3_path,
"model-artifact-s3": model_artifact_s3,
"output-chunk-size": output_chunk_size}
"output-chunk-size": output_chunk_size,
"log-level": log_level}
else:
params = {"task-type": task_type,
"graph-name": graph_name,
Expand All @@ -85,7 +87,8 @@ def run_job(input_args, image, unknownargs):
"output-emb-s3": output_emb_s3_path,
"output-prediction-s3": output_predict_s3_path,
"model-artifact-s3": model_artifact_s3,
"output-chunk-size": output_chunk_size}
"output-chunk-size": output_chunk_size,
"log-level": log_level}
# We must handle cases like
# --target-etype query,clicks,asin query,search,asin
# --feat-name ntype0:feat0 ntype1:feat1
Expand Down Expand Up @@ -170,6 +173,8 @@ def get_inference_parser():
help="Relative path to the trained model under <model_artifact_s3>."
"There can be multiple model checkpoints under"
"<model_artifact_s3>, this argument is used to choose one.")
inference_args.add_argument('--log-level', default='INFO',
type=str, choices=['DEBUG', 'INFO', 'WARNING', 'CRITICAL', 'FATAL'])

return parser

Expand Down
6 changes: 5 additions & 1 deletion sagemaker/launch/launch_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def run_job(input_args, image, unknowargs):
model_artifact_s3 = input_args.model_artifact_s3 # Where to store model artifacts
model_checkpoint_to_load = input_args.model_checkpoint_to_load # S3 location of a saved model.
custom_script = input_args.custom_script # custom_script if any
log_level = input_args.log_level # SageMaker runner logging level

boto_session = boto3.session.Session(region_name=region)
sagemaker_client = boto_session.client(service_name="sagemaker", region_name=region)
Expand All @@ -66,7 +67,8 @@ def run_job(input_args, image, unknowargs):
"graph-name": graph_name,
"graph-data-s3": graph_data_s3,
"train-yaml-s3": train_yaml_s3,
"model-artifact-s3": model_artifact_s3}
"model-artifact-s3": model_artifact_s3,
"log-level": log_level}
if custom_script is not None:
params["custom-script"] = custom_script
if model_checkpoint_to_load is not None:
Expand Down Expand Up @@ -143,6 +145,8 @@ def get_train_parser():
training_args.add_argument("--custom-script", type=str, default=None,
help="Custom training script provided by a customer to run customer training logic. \
Please provide the path of the script within the docker image")
training_args.add_argument('--log-level', default='INFO',
type=str, choices=['DEBUG', 'INFO', 'WARNING', 'CRITICAL', 'FATAL'])

return parser

Expand Down
2 changes: 2 additions & 0 deletions sagemaker/run/infer_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def parse_train_args():
Please provide the path of the script within the docker image")
parser.add_argument("--output-chunk-size", type=int, default=100000,
help="Number of rows per chunked prediction result or node embedding file.")
parser.add_argument('--log-level', default='INFO',
type=str, choices=['DEBUG', 'INFO', 'WARNING', 'CRITICAL', 'FATAL'])

# following arguments are required to launch a distributed GraphStorm training task
parser.add_argument('--data-path', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
Expand Down
2 changes: 2 additions & 0 deletions sagemaker/run/train_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def parse_train_args():
parser.add_argument("--custom-script", type=str, default=None,
help="Custom training script provided by a customer to run customer training logic. \
Please provide the path of the script within the docker image")
parser.add_argument('--log-level', default='INFO',
type=str, choices=['DEBUG', 'INFO', 'WARNING', 'CRITICAL', 'FATAL'])

# following arguments are required to launch a distributed GraphStorm training task
parser.add_argument('--data-path', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
Expand Down

0 comments on commit 60fa3df

Please sign in to comment.