diff --git a/examples/customized_models/HGT/hgt_nc.py b/examples/customized_models/HGT/hgt_nc.py index 7164266def..6da88e6870 100644 --- a/examples/customized_models/HGT/hgt_nc.py +++ b/examples/customized_models/HGT/hgt_nc.py @@ -393,7 +393,8 @@ def main(args): argparser.add_argument("--local_rank", type=int, help="The rank of the trainer. \ For customized models, MUST have this argument!!") - args = argparser.parse_args() + # Ignore unknown args to make script more robust to input arguments + args, _ = argparser.parse_known_args() print(args) main(args) diff --git a/examples/temporal_graph_learning/main_nc.py b/examples/temporal_graph_learning/main_nc.py index cf3f360874..20fe18889b 100644 --- a/examples/temporal_graph_learning/main_nc.py +++ b/examples/temporal_graph_learning/main_nc.py @@ -106,8 +106,8 @@ def generate_parser(): if __name__ == "__main__": arg_parser = generate_parser() - args = arg_parser.parse_args() - print(args) + # Ignore unknown args to make script more robust to input arguments + args, _ = arg_parser.parse_known_args() main(args) diff --git a/python/graphstorm/gconstruct/remap_result.py b/python/graphstorm/gconstruct/remap_result.py index ec3f97da48..2fc0ea606f 100644 --- a/python/graphstorm/gconstruct/remap_result.py +++ b/python/graphstorm/gconstruct/remap_result.py @@ -79,7 +79,14 @@ def write_data_parquet_file(data, file_prefix, col_name_map=None): A mapping from builtin column name to user defined column name. """ if col_name_map is not None: - data = {col_name_map[key]: val for key, val in data.items()} + updated_data = {} + for key, val in data.items(): + if key in col_name_map: + updated_data[col_name_map[key]] = val + else: + updated_data[key] = val + data = updated_data + output_fname = f"{file_prefix}.parquet" write_data_parquet(data, output_fname) @@ -107,7 +114,13 @@ def write_data_csv_file(data, file_prefix, delimiter=",", col_name_map=None): A mapping from builtin column name to user defined column name. 
""" if col_name_map is not None: - data = {col_name_map[key]: val for key, val in data.items()} + updated_data = {} + for key, val in data.items(): + if key in col_name_map: + updated_data[col_name_map[key]] = val + else: + updated_data[key] = val + data = updated_data output_fname = f"{file_prefix}.csv" csv_data = {} diff --git a/python/graphstorm/run/gsgnn_dt/distill_gnn.py b/python/graphstorm/run/gsgnn_dt/distill_gnn.py index 17c9b3becd..556ad9940e 100644 --- a/python/graphstorm/run/gsgnn_dt/distill_gnn.py +++ b/python/graphstorm/run/gsgnn_dt/distill_gnn.py @@ -92,6 +92,6 @@ def generate_parser(): if __name__ == '__main__': arg_parser=generate_parser() - args = arg_parser.parse_args() - print(args) - main(args) + # Ignore unknown args to make script more robust to input arguments + gs_args, _ = arg_parser.parse_known_args() + main(gs_args) diff --git a/python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py b/python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py index 4f402600b0..a4d6eee6a3 100644 --- a/python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py +++ b/python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py @@ -103,5 +103,6 @@ def generate_parser(): if __name__ == '__main__': arg_parser = generate_parser() - args = arg_parser.parse_args() - main(args) + # Ignore unknown args to make script more robust to input arguments + gs_args, _ = arg_parser.parse_known_args() + main(gs_args) diff --git a/python/graphstorm/run/gsgnn_ep/ep_infer_gnn.py b/python/graphstorm/run/gsgnn_ep/ep_infer_gnn.py index 7ae09cab29..8a6bd6b98b 100644 --- a/python/graphstorm/run/gsgnn_ep/ep_infer_gnn.py +++ b/python/graphstorm/run/gsgnn_ep/ep_infer_gnn.py @@ -99,5 +99,6 @@ def generate_parser(): if __name__ == '__main__': arg_parser=generate_parser() - args = arg_parser.parse_args() - main(args) + # Ignore unknown args to make script more robust to input arguments + gs_args, _ = arg_parser.parse_known_args() + main(gs_args) diff --git a/python/graphstorm/run/gsgnn_ep/ep_infer_lm.py 
b/python/graphstorm/run/gsgnn_ep/ep_infer_lm.py index 7e43ff18cb..e26d8a8118 100644 --- a/python/graphstorm/run/gsgnn_ep/ep_infer_lm.py +++ b/python/graphstorm/run/gsgnn_ep/ep_infer_lm.py @@ -89,5 +89,6 @@ def generate_parser(): if __name__ == '__main__': arg_parser=generate_parser() - args = arg_parser.parse_args() - main(args) + # Ignore unknown args to make script more robust to input arguments + gs_args, _ = arg_parser.parse_known_args() + main(gs_args) diff --git a/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py b/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py index aedfd5c83b..661c9e14d3 100644 --- a/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py +++ b/python/graphstorm/run/gsgnn_ep/gsgnn_ep.py @@ -171,5 +171,6 @@ def generate_parser(): if __name__ == '__main__': arg_parser=generate_parser() - args = arg_parser.parse_args() - main(args) + # Ignore unknown args to make script more robust to input arguments + gs_args, _ = arg_parser.parse_known_args() + main(gs_args) diff --git a/python/graphstorm/run/gsgnn_ep/gsgnn_lm_ep.py b/python/graphstorm/run/gsgnn_ep/gsgnn_lm_ep.py index 2a85a49011..39c54c995c 100644 --- a/python/graphstorm/run/gsgnn_ep/gsgnn_lm_ep.py +++ b/python/graphstorm/run/gsgnn_ep/gsgnn_lm_ep.py @@ -148,5 +148,6 @@ def generate_parser(): if __name__ == '__main__': arg_parser=generate_parser() - args = arg_parser.parse_args() - main(args) + # Ignore unknown args to make script more robust to input arguments + gs_args, _ = arg_parser.parse_known_args() + main(gs_args) diff --git a/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py b/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py index 143fd58e06..cfd3d5fe3e 100644 --- a/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py +++ b/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py @@ -197,5 +197,6 @@ def generate_parser(): if __name__ == '__main__': arg_parser=generate_parser() - args = arg_parser.parse_args() - main(args) + # Ignore unknown args to make script more robust to input arguments + gs_args, _ = 
arg_parser.parse_known_args() + main(gs_args) diff --git a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py index 92a40dc737..acc0c2fdbc 100644 --- a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py +++ b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py @@ -223,5 +223,6 @@ def generate_parser(): if __name__ == '__main__': arg_parser=generate_parser() - args = arg_parser.parse_args() - main(args) + # Ignore unknown args to make script more robust to input arguments + gs_args, _ = arg_parser.parse_known_args() + main(gs_args) diff --git a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py index 50a2e97acc..0e95a89841 100644 --- a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py +++ b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py @@ -88,5 +88,6 @@ def generate_parser(): if __name__ == '__main__': arg_parser=generate_parser() - args = arg_parser.parse_args() - main(args) + # Ignore unknown args to make script more robust to input arguments + gs_args, _ = arg_parser.parse_known_args() + main(gs_args) diff --git a/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py b/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py index e196d3fd83..aa7c051b90 100644 --- a/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py +++ b/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py @@ -90,5 +90,6 @@ def generate_parser(): if __name__ == '__main__': arg_parser=generate_parser() - args = arg_parser.parse_args() - main(args) + # Ignore unknown args to make script more robust to input arguments + gs_args, _ = arg_parser.parse_known_args() + main(gs_args) diff --git a/python/graphstorm/run/gsgnn_np/gsgnn_np.py b/python/graphstorm/run/gsgnn_np/gsgnn_np.py index 4306dac1b0..2292f65955 100644 --- a/python/graphstorm/run/gsgnn_np/gsgnn_np.py +++ b/python/graphstorm/run/gsgnn_np/gsgnn_np.py @@ -182,5 +182,6 @@ def generate_parser(): if __name__ == '__main__': arg_parser=generate_parser() - args = arg_parser.parse_args() - main(args) + # Ignore 
unknown args to make script more robust to input arguments + gs_args, _ = arg_parser.parse_known_args() + main(gs_args) diff --git a/python/graphstorm/run/gsgnn_np/np_infer_gnn.py b/python/graphstorm/run/gsgnn_np/np_infer_gnn.py index 3c58f33a6b..10a84e9108 100644 --- a/python/graphstorm/run/gsgnn_np/np_infer_gnn.py +++ b/python/graphstorm/run/gsgnn_np/np_infer_gnn.py @@ -94,5 +94,6 @@ def generate_parser(): if __name__ == '__main__': arg_parser=generate_parser() - args = arg_parser.parse_args() - main(args) + # Ignore unknown args to make script more robust to input arguments + gs_args, _ = arg_parser.parse_known_args() + main(gs_args) diff --git a/python/graphstorm/sagemaker/utils.py b/python/graphstorm/sagemaker/utils.py index a58e07b0aa..2e7795e64e 100644 --- a/python/graphstorm/sagemaker/utils.py +++ b/python/graphstorm/sagemaker/utils.py @@ -272,9 +272,10 @@ def download_graph(graph_data_s3, graph_name, part_id, world_size, graph_path, sagemaker_session=sagemaker_session) try: logging.info("Download graph from %s to %s", - os.path.join(graph_data_s3, graph_part), + os.path.join(os.path.join(graph_data_s3, graph_part), ""), graph_part_path) - S3Downloader.download(os.path.join(graph_data_s3, graph_part), + # add trailing / to s3://xxxx/partN + S3Downloader.download(os.path.join(os.path.join(graph_data_s3, graph_part), ""), graph_part_path, sagemaker_session=sagemaker_session) except Exception as err: # pylint: disable=broad-except logging.error("Can not download graph_data from %s, %s.", diff --git a/sagemaker/launch/launch_infer.py b/sagemaker/launch/launch_infer.py index 725d23da57..d47e588e51 100644 --- a/sagemaker/launch/launch_infer.py +++ b/sagemaker/launch/launch_infer.py @@ -92,6 +92,7 @@ def run_job(input_args, image, unknownargs): # We must handle cases like # --target-etype query,clicks,asin query,search,asin # --feat-name ntype0:feat0 ntype1:feat1 + # --column-names nid,~id emb,embedding unknow_idx = 0 while unknow_idx < len(unknownargs): 
print(unknownargs[unknow_idx]) diff --git a/tests/unit-tests/gconstruct/test_remap_result.py b/tests/unit-tests/gconstruct/test_remap_result.py index 1421c8f5b5..6d6c3f59af 100644 --- a/tests/unit-tests/gconstruct/test_remap_result.py +++ b/tests/unit-tests/gconstruct/test_remap_result.py @@ -278,7 +278,116 @@ def test__get_file_range(): assert start == 7 assert end == 10 +def test_write_data_parquet_file(): + data = {"emb": np.random.rand(10, 10), + "nid": np.arange(10), + "pred": np.random.rand(10, 10)} + + def check_write_content(fname, col_names): + # col_names should in order of emb, nid and pred + parq_data = read_data_parquet(fname, col_names) + assert_almost_equal(data["emb"], parq_data[col_names[0]]) + assert_equal(data["nid"], parq_data[col_names[1]]) + assert_almost_equal(data["pred"], parq_data[col_names[2]]) + + # without renaming columns + with tempfile.TemporaryDirectory() as tmpdirname: + file_prefix = os.path.join(tmpdirname, "test") + write_data_parquet_file(data, file_prefix, None) + output_fname = f"{file_prefix}.parquet" + + check_write_content(output_fname, ["emb", "nid", "pred"]) + + # rename all column names + with tempfile.TemporaryDirectory() as tmpdirname: + col_name_map = { + "emb": "new_emb", + "nid": "new_nid", + "pred": "new_pred" + } + file_prefix = os.path.join(tmpdirname, "test") + write_data_parquet_file(data, file_prefix, col_name_map) + output_fname = f"{file_prefix}.parquet" + + check_write_content(output_fname, ["new_emb", "new_nid", "new_pred"]) + + # rename part of column names + with tempfile.TemporaryDirectory() as tmpdirname: + col_name_map = { + "emb": "new_emb", + "nid": "new_nid", + } + file_prefix = os.path.join(tmpdirname, "test") + write_data_parquet_file(data, file_prefix, col_name_map) + output_fname = f"{file_prefix}.parquet" + + check_write_content(output_fname, ["new_emb", "new_nid", "pred"]) + +def test_write_data_csv_file(): + data = {"emb": np.random.rand(10, 10), + "nid": np.arange(10), + "pred": 
np.random.rand(10, 10)} + + def check_write_content(fname, col_names): + # col_names should in order of emb, nid and pred + csv_data = pd.read_csv(fname, delimiter=",") + # emb + assert col_names[0] in csv_data + csv_emb_data = csv_data[col_names[0]].values.tolist() + csv_emb_data = [d.split(";") for d in csv_emb_data] + csv_emb_data = np.array(csv_emb_data, dtype=np.float32) + assert_almost_equal(data["emb"], csv_emb_data) + + # nid + assert col_names[1] in csv_data + csv_nid_data = csv_data[col_names[1]].values.tolist() + csv_nid_data = np.array(csv_nid_data, dtype=np.int32) + assert_equal(data["nid"], csv_nid_data) + + # pred + assert col_names[2] in csv_data + csv_pred_data = csv_data[col_names[2]].values.tolist() + csv_pred_data = [d.split(";") for d in csv_pred_data] + csv_pred_data = np.array(csv_pred_data, dtype=np.float32) + assert_almost_equal(data["pred"], csv_pred_data) + + # without renaming columns + with tempfile.TemporaryDirectory() as tmpdirname: + file_prefix = os.path.join(tmpdirname, "test") + write_data_csv_file(data, file_prefix, col_name_map=None) + output_fname = f"{file_prefix}.csv" + + check_write_content(output_fname, ["emb", "nid", "pred"]) + + # rename all column names + with tempfile.TemporaryDirectory() as tmpdirname: + col_name_map = { + "emb": "new_emb", + "nid": "new_nid", + "pred": "new_pred" + } + file_prefix = os.path.join(tmpdirname, "test") + write_data_csv_file(data, file_prefix, col_name_map=col_name_map) + output_fname = f"{file_prefix}.csv" + + check_write_content(output_fname, ["new_emb", "new_nid", "new_pred"]) + + # rename part of column names + with tempfile.TemporaryDirectory() as tmpdirname: + col_name_map = { + "emb": "new_emb", + "nid": "new_nid", + } + file_prefix = os.path.join(tmpdirname, "test") + write_data_csv_file(data, file_prefix, col_name_map=col_name_map) + output_fname = f"{file_prefix}.csv" + + check_write_content(output_fname, ["new_emb", "new_nid", "pred"]) + + if __name__ == '__main__': + 
test_write_data_csv_file() + test_write_data_parquet_file() test__get_file_range() test_worker_remap_edge_pred() test_worker_remap_node_data("pred") diff --git a/tests/unit-tests/test_model_save_load.py b/tests/unit-tests/test_model_save_load.py index 7887c2c304..6be8716999 100644 --- a/tests/unit-tests/test_model_save_load.py +++ b/tests/unit-tests/test_model_save_load.py @@ -146,7 +146,7 @@ def check_sparse_emb(mock_get_world_size, mock_get_rank): for i in range(infer_world_size): mock_get_rank.side_effect = [i] * 2 - mock_get_world_size.side_effect = [train_world_size] * 2 + mock_get_world_size.side_effect = [infer_world_size] * 2 load_sparse_embeds(model_path, embed_layer) load_sparse_embs = \ {ntype: sparse_emb._tensor[th.arange(embed_layer.g.number_of_nodes(ntype))] \