
Commit

Merge branch 'awslabs:main' into llmgnn
GentleZhu authored Dec 12, 2023
2 parents 6999c02 + 5e7cbb8 commit 93ece35
Showing 19 changed files with 169 additions and 33 deletions.
3 changes: 2 additions & 1 deletion examples/customized_models/HGT/hgt_nc.py
@@ -393,7 +393,8 @@ def main(args):
     argparser.add_argument("--local_rank", type=int,
                            help="The rank of the trainer. \
                            For customized models, MUST have this argument!!")
-    args = argparser.parse_args()
+    # Ignore unknown args to make the script more robust to input arguments
+    args, _ = argparser.parse_known_args()
     print(args)
     main(args)
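
Note: this parse_args() → parse_known_args() substitution recurs in every entry-point script touched below. A minimal standalone sketch (not part of the diff) of the behavioral difference:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)

    # parse_args() aborts with SystemExit on flags it does not recognize:
    #   parser.parse_args(["--local_rank", "0", "--extra-flag"])
    # parse_known_args() returns the parsed namespace plus the leftover tokens:
    args, unknown = parser.parse_known_args(["--local_rank", "0", "--extra-flag"])
    print(args)     # Namespace(local_rank=0)
    print(unknown)  # ['--extra-flag']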
4 changes: 2 additions & 2 deletions examples/temporal_graph_learning/main_nc.py
@@ -106,8 +106,8 @@ def generate_parser():
 if __name__ == "__main__":
     arg_parser = generate_parser()

-    args = arg_parser.parse_args()
-    print(args)
+    # Ignore unknown args to make the script more robust to input arguments
+    args, _ = arg_parser.parse_known_args()
     main(args)


17 changes: 15 additions & 2 deletions python/graphstorm/gconstruct/remap_result.py
@@ -79,7 +79,14 @@ def write_data_parquet_file(data, file_prefix, col_name_map=None):
         A mapping from builtin column name to user defined column name.
     """
     if col_name_map is not None:
-        data = {col_name_map[key]: val for key, val in data.items()}
+        updated_data = {}
+        for key, val in data.items():
+            if key in col_name_map:
+                updated_data[col_name_map[key]] = val
+            else:
+                updated_data[key] = val
+        data = updated_data

     output_fname = f"{file_prefix}.parquet"
     write_data_parquet(data, output_fname)

@@ -107,7 +114,13 @@ def write_data_csv_file(data, file_prefix, delimiter=",", col_name_map=None):
         A mapping from builtin column name to user defined column name.
     """
    if col_name_map is not None:
-        data = {col_name_map[key]: val for key, val in data.items()}
+        updated_data = {}
+        for key, val in data.items():
+            if key in col_name_map:
+                updated_data[col_name_map[key]] = val
+            else:
+                updated_data[key] = val
+        data = updated_data

     output_fname = f"{file_prefix}.csv"
     csv_data = {}
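
Note: both hunks above replace a dict comprehension that assumed every column appears in col_name_map. A small sketch (not from the commit) of the difference when the mapping is partial:

    col_name_map = {"emb": "new_emb"}          # partial mapping
    data = {"emb": [0.1, 0.2], "nid": [0, 1]}

    # The old one-liner raised KeyError on any unmapped column:
    #   {col_name_map[key]: val for key, val in data.items()}  # KeyError: 'nid'

    # The new loop renames mapped columns and passes the rest through,
    # which is equivalent to:
    data = {col_name_map.get(key, key): val for key, val in data.items()}
    print(data)  # {'new_emb': [0.1, 0.2], 'nid': [0, 1]}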
6 changes: 3 additions & 3 deletions python/graphstorm/run/gsgnn_dt/distill_gnn.py
@@ -92,6 +92,6 @@ def generate_parser():
 if __name__ == '__main__':
     arg_parser=generate_parser()

-    args = arg_parser.parse_args()
-    print(args)
-    main(args)
+    # Ignore unknown args to make the script more robust to input arguments
+    gs_args, _ = arg_parser.parse_known_args()
+    main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py
@@ -103,5 +103,6 @@ def generate_parser():
 if __name__ == '__main__':
     arg_parser = generate_parser()

-    args = arg_parser.parse_args()
-    main(args)
+    # Ignore unknown args to make the script more robust to input arguments
+    gs_args, _ = arg_parser.parse_known_args()
+    main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_ep/ep_infer_gnn.py
@@ -99,5 +99,6 @@ def generate_parser():
 if __name__ == '__main__':
     arg_parser=generate_parser()

-    args = arg_parser.parse_args()
-    main(args)
+    # Ignore unknown args to make the script more robust to input arguments
+    gs_args, _ = arg_parser.parse_known_args()
+    main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_ep/ep_infer_lm.py
@@ -89,5 +89,6 @@ def generate_parser():
 if __name__ == '__main__':
     arg_parser=generate_parser()

-    args = arg_parser.parse_args()
-    main(args)
+    # Ignore unknown args to make the script more robust to input arguments
+    gs_args, _ = arg_parser.parse_known_args()
+    main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_ep/gsgnn_ep.py
@@ -171,5 +171,6 @@ def generate_parser():
 if __name__ == '__main__':
     arg_parser=generate_parser()

-    args = arg_parser.parse_args()
-    main(args)
+    # Ignore unknown args to make the script more robust to input arguments
+    gs_args, _ = arg_parser.parse_known_args()
+    main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_ep/gsgnn_lm_ep.py
@@ -148,5 +148,6 @@ def generate_parser():
 if __name__ == '__main__':
     arg_parser=generate_parser()

-    args = arg_parser.parse_args()
-    main(args)
+    # Ignore unknown args to make the script more robust to input arguments
+    gs_args, _ = arg_parser.parse_known_args()
+    main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py
@@ -197,5 +197,6 @@ def generate_parser():
 if __name__ == '__main__':
     arg_parser=generate_parser()

-    args = arg_parser.parse_args()
-    main(args)
+    # Ignore unknown args to make the script more robust to input arguments
+    gs_args, _ = arg_parser.parse_known_args()
+    main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_lp/gsgnn_lp.py
@@ -223,5 +223,6 @@ def generate_parser():
 if __name__ == '__main__':
     arg_parser=generate_parser()

-    args = arg_parser.parse_args()
-    main(args)
+    # Ignore unknown args to make the script more robust to input arguments
+    gs_args, _ = arg_parser.parse_known_args()
+    main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py
@@ -88,5 +88,6 @@ def generate_parser():
 if __name__ == '__main__':
     arg_parser=generate_parser()

-    args = arg_parser.parse_args()
-    main(args)
+    # Ignore unknown args to make the script more robust to input arguments
+    gs_args, _ = arg_parser.parse_known_args()
+    main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_lp/lp_infer_lm.py
@@ -90,5 +90,6 @@ def generate_parser():
 if __name__ == '__main__':
     arg_parser=generate_parser()

-    args = arg_parser.parse_args()
-    main(args)
+    # Ignore unknown args to make the script more robust to input arguments
+    gs_args, _ = arg_parser.parse_known_args()
+    main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_np/gsgnn_np.py
@@ -182,5 +182,6 @@ def generate_parser():
 if __name__ == '__main__':
     arg_parser=generate_parser()

-    args = arg_parser.parse_args()
-    main(args)
+    # Ignore unknown args to make the script more robust to input arguments
+    gs_args, _ = arg_parser.parse_known_args()
+    main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_np/np_infer_gnn.py
@@ -94,5 +94,6 @@ def generate_parser():
 if __name__ == '__main__':
     arg_parser=generate_parser()

-    args = arg_parser.parse_args()
-    main(args)
+    # Ignore unknown args to make the script more robust to input arguments
+    gs_args, _ = arg_parser.parse_known_args()
+    main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/sagemaker/utils.py
@@ -272,9 +272,10 @@ def download_graph(graph_data_s3, graph_name, part_id, world_size,
                               graph_path, sagemaker_session=sagemaker_session)
     try:
         logging.info("Download graph from %s to %s",
-                     os.path.join(graph_data_s3, graph_part),
+                     os.path.join(os.path.join(graph_data_s3, graph_part), ""),
                      graph_part_path)
-        S3Downloader.download(os.path.join(graph_data_s3, graph_part),
+        # add trailing / to s3://xxxx/partN
+        S3Downloader.download(os.path.join(os.path.join(graph_data_s3, graph_part), ""),
                               graph_part_path, sagemaker_session=sagemaker_session)
     except Exception as err: # pylint: disable=broad-except
         logging.error("Can not download graph_data from %s, %s.",
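
Note: the fix relies on os.path.join(path, "") appending a path separator. A sketch (assuming POSIX separators and standard S3 prefix matching; the bucket name is made up) of why the trailing slash matters:

    import os

    prefix = os.path.join("s3://bucket/graph", "part0")
    print(prefix)                    # s3://bucket/graph/part0
    print(os.path.join(prefix, ""))  # s3://bucket/graph/part0/

    # Without the trailing "/", an S3 prefix match on "part0" can also pick
    # up sibling prefixes such as "part01/" or "part0_old/", so a worker
    # could download partitions that belong to other workers.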
1 change: 1 addition & 0 deletions sagemaker/launch/launch_infer.py
@@ -92,6 +92,7 @@ def run_job(input_args, image, unknownargs):
     # We must handle cases like
     # --target-etype query,clicks,asin query,search,asin
     # --feat-name ntype0:feat0 ntype1:feat1
+    # --column-names nid,~id emb,embedding
     unknow_idx = 0
     while unknow_idx < len(unknownargs):
         print(unknownargs[unknow_idx])
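
Note: the loop body is cut off above; it has to keep multi-token values attached to their flag. A hypothetical sketch (not the commit's actual code) of that grouping logic:

    unknownargs = ["--target-etype", "query,clicks,asin", "query,search,asin",
                   "--column-names", "nid,~id", "emb,embedding"]

    groups = []
    for token in unknownargs:
        if token.startswith("--"):
            groups.append([token])      # a new flag starts a new group
        else:
            groups[-1].append(token)    # values attach to the current flag
    print(groups)
    # [['--target-etype', 'query,clicks,asin', 'query,search,asin'],
    #  ['--column-names', 'nid,~id', 'emb,embedding']]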
109 changes: 109 additions & 0 deletions tests/unit-tests/gconstruct/test_remap_result.py
@@ -278,7 +278,116 @@ def test__get_file_range():
     assert start == 7
     assert end == 10

+def test_write_data_parquet_file():
+    data = {"emb": np.random.rand(10, 10),
+            "nid": np.arange(10),
+            "pred": np.random.rand(10, 10)}
+
+    def check_write_content(fname, col_names):
+        # col_names should be in the order of emb, nid and pred
+        parq_data = read_data_parquet(fname, col_names)
+        assert_almost_equal(data["emb"], parq_data[col_names[0]])
+        assert_equal(data["nid"], parq_data[col_names[1]])
+        assert_almost_equal(data["pred"], parq_data[col_names[2]])
+
+    # without renaming columns
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_parquet_file(data, file_prefix, None)
+        output_fname = f"{file_prefix}.parquet"
+
+        check_write_content(output_fname, ["emb", "nid", "pred"])
+
+    # rename all column names
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        col_name_map = {
+            "emb": "new_emb",
+            "nid": "new_nid",
+            "pred": "new_pred"
+        }
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_parquet_file(data, file_prefix, col_name_map)
+        output_fname = f"{file_prefix}.parquet"
+
+        check_write_content(output_fname, ["new_emb", "new_nid", "new_pred"])
+
+    # rename part of the column names
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        col_name_map = {
+            "emb": "new_emb",
+            "nid": "new_nid",
+        }
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_parquet_file(data, file_prefix, col_name_map)
+        output_fname = f"{file_prefix}.parquet"
+
+        check_write_content(output_fname, ["new_emb", "new_nid", "pred"])
+
+def test_write_data_csv_file():
+    data = {"emb": np.random.rand(10, 10),
+            "nid": np.arange(10),
+            "pred": np.random.rand(10, 10)}
+
+    def check_write_content(fname, col_names):
+        # col_names should be in the order of emb, nid and pred
+        csv_data = pd.read_csv(fname, delimiter=",")
+        # emb
+        assert col_names[0] in csv_data
+        csv_emb_data = csv_data[col_names[0]].values.tolist()
+        csv_emb_data = [d.split(";") for d in csv_emb_data]
+        csv_emb_data = np.array(csv_emb_data, dtype=np.float32)
+        assert_almost_equal(data["emb"], csv_emb_data)
+
+        # nid
+        assert col_names[1] in csv_data
+        csv_nid_data = csv_data[col_names[1]].values.tolist()
+        csv_nid_data = np.array(csv_nid_data, dtype=np.int32)
+        assert_equal(data["nid"], csv_nid_data)
+
+        # pred
+        assert col_names[2] in csv_data
+        csv_pred_data = csv_data[col_names[2]].values.tolist()
+        csv_pred_data = [d.split(";") for d in csv_pred_data]
+        csv_pred_data = np.array(csv_pred_data, dtype=np.float32)
+        assert_almost_equal(data["pred"], csv_pred_data)
+
+    # without renaming columns
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_csv_file(data, file_prefix, col_name_map=None)
+        output_fname = f"{file_prefix}.csv"
+
+        check_write_content(output_fname, ["emb", "nid", "pred"])
+
+    # rename all column names
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        col_name_map = {
+            "emb": "new_emb",
+            "nid": "new_nid",
+            "pred": "new_pred"
+        }
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_csv_file(data, file_prefix, col_name_map=col_name_map)
+        output_fname = f"{file_prefix}.csv"
+
+        check_write_content(output_fname, ["new_emb", "new_nid", "new_pred"])
+
+    # rename part of the column names
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        col_name_map = {
+            "emb": "new_emb",
+            "nid": "new_nid",
+        }
+        file_prefix = os.path.join(tmpdirname, "test")
+        write_data_csv_file(data, file_prefix, col_name_map=col_name_map)
+        output_fname = f"{file_prefix}.csv"
+
+        check_write_content(output_fname, ["new_emb", "new_nid", "pred"])
+
+
 if __name__ == '__main__':
+    test_write_data_csv_file()
+    test_write_data_parquet_file()
     test__get_file_range()
     test_worker_remap_edge_pred()
     test_worker_remap_node_data("pred")
2 changes: 1 addition & 1 deletion tests/unit-tests/test_model_save_load.py
@@ -146,7 +146,7 @@ def check_sparse_emb(mock_get_world_size, mock_get_rank):

     for i in range(infer_world_size):
         mock_get_rank.side_effect = [i] * 2
-        mock_get_world_size.side_effect = [train_world_size] * 2
+        mock_get_world_size.side_effect = [infer_world_size] * 2
         load_sparse_embeds(model_path, embed_layer)
         load_sparse_embs = \
             {ntype: sparse_emb._tensor[th.arange(embed_layer.g.number_of_nodes(ntype))] \
