Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor graphstorm.sagemaker logging printing #650

Merged
merged 9 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/graphstorm/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,8 @@ def _load_id_mapping(self, g, ntype, id_mappings):
# Shuffled node ID: 0, 1, 2
id_mapping = id_mappings[ntype] if isinstance(id_mappings, dict) else id_mappings
assert id_mapping.shape[0] == num_nodes, \
"id mapping should have the same size of num_nodes"
"Id mapping should have the same size as num_nodes. " \
f"Expected {num_nodes}, but got {id_mapping.shape[0]}"
# Save ID mapping into dist tensor
id_mapping_info[th.arange(num_nodes)] = id_mapping
barrier()
Expand Down
29 changes: 19 additions & 10 deletions python/graphstorm/sagemaker/sagemaker_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""
# Install additional requirements
import os
import logging
import socket
import time
import json
Expand Down Expand Up @@ -117,13 +118,14 @@ def launch_infer_task(task_type, num_gpus, graph_config,
launch_cmd += ["--cf", f"{yaml_path}",
"--restore-model-path", f"{load_model_path}",
"--save-embed-path", f"{save_emb_path}"] + extra_args
logging.debug("Launch inference %s", launch_cmd)

def run(launch_cmd, state_q):
try:
subprocess.check_call(launch_cmd, shell=False)
state_q.put(0)
except subprocess.CalledProcessError as err:
print(f"Called process error {err}")
logging.error("Called process error %s", err)
state_q.put(err.returncode)
except Exception: # pylint: disable=broad-except
state_q.put(-1)
Expand Down Expand Up @@ -174,8 +176,8 @@ def run_infer(args, unknownargs):
# start the ssh server
subprocess.run(["service", "ssh", "start"], check=True)

print(f"Know args {args}")
print(f"Unknow args {unknownargs}")
logging.info("Known args %s", args)
logging.info("Unknown args %s", unknownargs)

train_env = json.loads(args.sm_dist_env)
hosts = train_env['hosts']
Expand All @@ -184,9 +186,15 @@ def run_infer(args, unknownargs):
os.environ['WORLD_SIZE'] = str(world_size)
host_rank = hosts.index(current_host)

# NOTE: Ensure no logging has been done before setting logging configuration
logging.basicConfig(
level=getattr(logging, args.log_level.upper(), None),
format=f'{current_host}: %(asctime)s - %(levelname)s - %(message)s',
force=True)

try:
for host in hosts:
print(f"The {host} IP is {socket.gethostbyname(host)}")
logging.info("The %s IP is %s", host, socket.gethostbyname(host))
except:
raise RuntimeError(f"Can not get host name of {hosts}")

Expand All @@ -210,9 +218,9 @@ def run_infer(args, unknownargs):
sock.connect((master_addr, 12345))
break
except: # pylint: disable=bare-except
print(f"Try to connect {master_addr}")
logging.info("Try to connect %s", master_addr)
time.sleep(10)
print("Connected")
logging.info("Connected")

# write ip list info into disk
ip_list_path = os.path.join(data_path, 'ip_list.txt')
Expand Down Expand Up @@ -247,7 +255,8 @@ def run_infer(args, unknownargs):

# Download Saved model
download_model(model_artifact_s3, model_path, sagemaker_session)
print(f"{model_path} {os.listdir(model_path)}")
logging.info("Successfully downloaded the model into %s.\n The model files are: %s.",
model_path, os.listdir(model_path))

err_code = 0
if host_rank == 0:
Expand Down Expand Up @@ -281,7 +290,7 @@ def run_infer(args, unknownargs):
err_code = -1

terminate_workers(client_list, world_size, task_end)
print("Master End")
logging.info("Master End")
if err_code != -1:
upload_embs(output_emb_s3, emb_path, sagemaker_session)
# clean embs, so SageMaker does not need to upload embs again
Expand All @@ -295,12 +304,12 @@ def run_infer(args, unknownargs):
upload_embs(output_emb_s3, emb_path, sagemaker_session)
# clean embs, so SageMaker does not need to upload embs again
remove_embs(emb_path)
print("Worker End")
logging.info("Worker End")

sock.close()
if err_code != 0:
# Report an error
print("Task failed")
logging.error("Task failed")
sys.exit(-1)

if args.output_prediction_s3 is not None:
Expand Down
32 changes: 19 additions & 13 deletions python/graphstorm/sagemaker/sagemaker_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""
# Install additional requirements
import os
import logging
import socket
import time
import json
Expand Down Expand Up @@ -108,15 +109,14 @@ def launch_train_task(task_type, num_gpus, graph_config,
launch_cmd += ["--restore-model-path", f"{restore_model_path}"] \
if restore_model_path is not None else []
launch_cmd += extra_args

print(launch_cmd)
logging.debug("Launch training %s", launch_cmd)

def run(launch_cmd, state_q):
try:
subprocess.check_call(launch_cmd, shell=False)
state_q.put(0)
except subprocess.CalledProcessError as err:
print(f"Called process error {err}")
logging.error("Called process error %s", err)
state_q.put(err.returncode)
except Exception: # pylint: disable=broad-except
state_q.put(-1)
Expand Down Expand Up @@ -167,8 +167,8 @@ def run_train(args, unknownargs):
# start the ssh server
subprocess.run(["service", "ssh", "start"], check=True)

print(f"Know args {args}")
print(f"Unknow args {unknownargs}")
logging.info("Known args %s", args)
logging.info("Unknown args %s", unknownargs)

save_model_path = os.path.join(output_path, "model_checkpoint")

Expand All @@ -179,9 +179,15 @@ def run_train(args, unknownargs):
os.environ['WORLD_SIZE'] = str(world_size)
host_rank = hosts.index(current_host)

# NOTE: Ensure no logging has been done before setting logging configuration
logging.basicConfig(
level=getattr(logging, args.log_level.upper(), None),
format=f'{current_host}: %(asctime)s - %(levelname)s - %(message)s',
force=True)

try:
for host in hosts:
print(f"The {host} IP is {socket.gethostbyname(host)}")
logging.info("The %s IP is %s", host, socket.gethostbyname(host))
except:
raise RuntimeError(f"Can not get host name of {hosts}")

Expand All @@ -205,9 +211,9 @@ def run_train(args, unknownargs):
sock.connect((master_addr, 12345))
break
except: # pylint: disable=bare-except
print(f"Try to connect {master_addr}")
logging.info("Try to connect %s", master_addr)
time.sleep(10)
print("Connected")
logging.info("Connected")

# write ip list info into disk
ip_list_path = os.path.join(data_path, 'ip_list.txt')
Expand All @@ -232,8 +238,8 @@ def run_train(args, unknownargs):
if model_checkpoint_s3 is not None:
# Download Saved model checkpoint to resume
download_model(model_checkpoint_s3, restore_model_path, sagemaker_session)
print(f"{restore_model_path} {os.listdir(restore_model_path)}")

logging.info("Successfully downloaded the model into %s.\n The model files are: %s.",
restore_model_path, os.listdir(restore_model_path))

err_code = 0
if host_rank == 0:
Expand Down Expand Up @@ -265,18 +271,18 @@ def run_train(args, unknownargs):
print(e)
err_code = -1
terminate_workers(client_list, world_size, task_end)
print("Master End")
logging.info("Master End")
else:
barrier(sock)
# Block util training finished
# Listen to end command
wait_for_exit(sock)
print("Worker End")
logging.info("Worker End")

sock.close()
if err_code != 0:
# Report an error
print("Task failed")
logging.error("Task failed")
sys.exit(-1)

# If there are saved models
Expand Down
41 changes: 27 additions & 14 deletions python/graphstorm/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ def download_yaml_config(yaml_s3, local_path, sagemaker_session):
try:
S3Downloader.download(yaml_s3, local_path,
sagemaker_session=sagemaker_session)
except Exception: # pylint: disable=broad-except
raise RuntimeError(f"Fail to download yaml file {yaml_s3}")
except Exception as err: # pylint: disable=broad-except
raise RuntimeError(f"Fail to download yaml file {yaml_s3}: {err}")

return yaml_path

Expand All @@ -206,9 +206,10 @@ def download_model(model_artifact_s3, model_path, sagemaker_session):
try:
S3Downloader.download(model_artifact_s3,
model_path, sagemaker_session=sagemaker_session)
except Exception: # pylint: disable=broad-except
except Exception as err: # pylint: disable=broad-except
raise RuntimeError("Can not download saved model artifact " \
f"model.bin from {model_artifact_s3}. " \
f"{err}")

def download_graph(graph_data_s3, graph_name, part_id, world_size,
local_path, sagemaker_session):
Expand Down Expand Up @@ -273,19 +274,29 @@ def download_graph(graph_data_s3, graph_name, part_id, world_size,
S3Downloader.download(os.path.join(graph_data_s3, graph_config),
graph_path, sagemaker_session=sagemaker_session)
try:
logging.info("Download graph from %s to %s",
os.path.join(graph_data_s3, graph_part),
graph_part_path)
S3Downloader.download(os.path.join(graph_data_s3, graph_part),
graph_part_path, sagemaker_session=sagemaker_session)
except Exception: # pylint: disable=broad-except
print(f"Can not download graph_data from {graph_data_s3}.")
raise RuntimeError(f"Can not download graph_data from {graph_data_s3}.")
except Exception as err: # pylint: disable=broad-except
logging.error("Can not download graph_data from %s, %s.",
graph_data_s3, str(err))
raise RuntimeError(f"Can not download graph_data from {graph_data_s3}, {err}.")
thvasilo marked this conversation as resolved.
Show resolved Hide resolved

node_id_mapping = "node_mapping.pt"
# Try to download node id mapping file if any
try:
logging.info("Download graph id mapping from %s to %s",
os.path.join(graph_data_s3, node_id_mapping),
graph_path)
S3Downloader.download(os.path.join(graph_data_s3, node_id_mapping),
graph_path, sagemaker_session=sagemaker_session)
except Exception: # pylint: disable=broad-except
print("node id mapping file does not exist")
logging.warning("Node id mapping file does not exist. "
"If you are running GraphStorm on a graph with "
"more than 1 partition, it is recommended to provide "
"the node id mapping file created by gconstruct or gsprocessing.")

if part_id == 0:
# It is possible that id mappings are generated by
Expand All @@ -309,12 +320,14 @@ def download_graph(graph_data_s3, graph_name, part_id, world_size,
id_map_files = [file for file in files if file.endswith("id_remap.parquet")]
for file in id_map_files:
try:
logging.info("Download graph remap from %s to %s",
file, graph_path)
S3Downloader.download(file, graph_path,
sagemaker_session=sagemaker_session)
except Exception: # pylint: disable=broad-except
print(f"node id remap file {file} does not exist")
logging.warning("node id remap file %s does not exist", file)

print(f"Finish download graph data from {graph_data_s3}")
logging.info("Finished downloading graph data from %s", graph_data_s3)
return os.path.join(graph_path, graph_config)


Expand All @@ -333,9 +346,9 @@ def upload_data_to_s3(s3_path, data_path, sagemaker_session):
try:
ret = S3Uploader.upload(data_path, s3_path,
sagemaker_session=sagemaker_session)
except Exception: # pylint: disable=broad-except
print(f"Can not upload data into {s3_path}")
raise RuntimeError(f"Can not upload data into {s3_path}")
except Exception as err: # pylint: disable=broad-except
logging.error("Can not upload data into %s", s3_path)
raise RuntimeError(f"Can not upload data into {s3_path}. {err}")
thvasilo marked this conversation as resolved.
Show resolved Hide resolved
return ret

def upload_model_artifacts(model_s3_path, model_path, sagemaker_session):
Expand All @@ -354,7 +367,7 @@ def upload_model_artifacts(model_s3_path, model_path, sagemaker_session):
sagemaker_session: sagemaker.session.Session
sagemaker_session to run download
"""
print(f"Upload model artifacts to {model_s3_path}")
logging.info("Uploading model artifacts to %s", model_s3_path)
# Rank0 will upload both dense models and learnable embeddings owned by Rank0.
# Other ranks will only upload learnable embeddings owned by themselves.
return upload_data_to_s3(model_s3_path, model_path, sagemaker_session)
Expand Down
9 changes: 7 additions & 2 deletions sagemaker/launch/launch_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def run_job(input_args, image, unknownargs):
output_predict_s3_path = input_args.output_prediction_s3 # S3 location to save prediction results
model_artifact_s3 = input_args.model_artifact_s3 # S3 location of saved model artifacts
output_chunk_size = input_args.output_chunk_size # Number of rows per chunked prediction result or node embedding file.
log_level = input_args.log_level # SageMaker runner logging level

boto_session = boto3.session.Session(region_name=region)
sagemaker_client = boto_session.client(service_name="sagemaker", region_name=region)
Expand All @@ -76,7 +77,8 @@ def run_job(input_args, image, unknownargs):
"infer-yaml-s3": infer_yaml_s3,
"output-emb-s3": output_emb_s3_path,
"model-artifact-s3": model_artifact_s3,
"output-chunk-size": output_chunk_size}
"output-chunk-size": output_chunk_size,
"log-level": log_level}
else:
params = {"task-type": task_type,
"graph-name": graph_name,
Expand All @@ -85,7 +87,8 @@ def run_job(input_args, image, unknownargs):
"output-emb-s3": output_emb_s3_path,
"output-prediction-s3": output_predict_s3_path,
"model-artifact-s3": model_artifact_s3,
"output-chunk-size": output_chunk_size}
"output-chunk-size": output_chunk_size,
"log-level": log_level}
# We must handle cases like
# --target-etype query,clicks,asin query,search,asin
# --feat-name ntype0:feat0 ntype1:feat1
Expand Down Expand Up @@ -170,6 +173,8 @@ def get_inference_parser():
help="Relative path to the trained model under <model_artifact_s3>."
"There can be multiple model checkpoints under"
"<model_artifact_s3>, this argument is used to choose one.")
inference_args.add_argument('--log-level', default='INFO',
type=str, choices=['DEBUG', 'INFO', 'WARNING', 'CRITICAL', 'FATAL'])

return parser

Expand Down
6 changes: 5 additions & 1 deletion sagemaker/launch/launch_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def run_job(input_args, image, unknowargs):
model_artifact_s3 = input_args.model_artifact_s3 # Where to store model artifacts
model_checkpoint_to_load = input_args.model_checkpoint_to_load # S3 location of a saved model.
custom_script = input_args.custom_script # custom_script if any
log_level = input_args.log_level # SageMaker runner logging level

boto_session = boto3.session.Session(region_name=region)
sagemaker_client = boto_session.client(service_name="sagemaker", region_name=region)
Expand All @@ -66,7 +67,8 @@ def run_job(input_args, image, unknowargs):
"graph-name": graph_name,
"graph-data-s3": graph_data_s3,
"train-yaml-s3": train_yaml_s3,
"model-artifact-s3": model_artifact_s3}
"model-artifact-s3": model_artifact_s3,
"log-level": log_level}
if custom_script is not None:
params["custom-script"] = custom_script
if model_checkpoint_to_load is not None:
Expand Down Expand Up @@ -143,6 +145,8 @@ def get_train_parser():
training_args.add_argument("--custom-script", type=str, default=None,
help="Custom training script provided by a customer to run customer training logic. \
Please provide the path of the script within the docker image")
training_args.add_argument('--log-level', default='INFO',
type=str, choices=['DEBUG', 'INFO', 'WARNING', 'CRITICAL', 'FATAL'])

return parser

Expand Down
2 changes: 2 additions & 0 deletions sagemaker/run/infer_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def parse_train_args():
Please provide the path of the script within the docker image")
parser.add_argument("--output-chunk-size", type=int, default=100000,
help="Number of rows per chunked prediction result or node embedding file.")
parser.add_argument('--log-level', default='INFO',
type=str, choices=['DEBUG', 'INFO', 'WARNING', 'CRITICAL', 'FATAL'])

# following arguments are required to launch a distributed GraphStorm training task
parser.add_argument('--data-path', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
Expand Down
2 changes: 2 additions & 0 deletions sagemaker/run/train_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def parse_train_args():
parser.add_argument("--custom-script", type=str, default=None,
help="Custom training script provided by a customer to run customer training logic. \
Please provide the path of the script within the docker image")
parser.add_argument('--log-level', default='INFO',
type=str, choices=['DEBUG', 'INFO', 'WARNING', 'CRITICAL', 'FATAL'])

# following arguments are required to launch a distributed GraphStorm training task
parser.add_argument('--data-path', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
Expand Down
Loading