Skip to content

Commit

Permalink
Update builtin train/inference entries to accept unknown arguments (#681
Browse files Browse the repository at this point in the history
)

*Issue #, if available:*
As we chain train/inference with node id remapping in the train and
inference pipeline, both train/inference entries and remap entry should
be more robust to unknown arguments.

Related issue: #674 

*Description of changes:*
Update builtin train/inference entries to accept unknown arguments.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

Co-authored-by: Xiang Song <[email protected]>
  • Loading branch information
classicsong and Xiang Song authored Dec 11, 2023
1 parent 79ca346 commit ca0ed1f
Show file tree
Hide file tree
Showing 15 changed files with 41 additions and 28 deletions.
3 changes: 2 additions & 1 deletion examples/customized_models/HGT/hgt_nc.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,8 @@ def main(args):
argparser.add_argument("--local_rank", type=int,
help="The rank of the trainer. \
For customized models, MUST have this argument!!")
args = argparser.parse_args()

# Ignore unknown args to make script more robust to input arguments
args, _ = argparser.parse_known_args()
print(args)
main(args)
4 changes: 2 additions & 2 deletions examples/temporal_graph_learning/main_nc.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def generate_parser():
if __name__ == "__main__":
arg_parser = generate_parser()

args = arg_parser.parse_args()
print(args)
# Ignore unknown args to make script more robust to input arguments
args, _ = arg_parser.parse_known_args()
main(args)


6 changes: 3 additions & 3 deletions python/graphstorm/run/gsgnn_dt/distill_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,6 @@ def generate_parser():
if __name__ == '__main__':
arg_parser=generate_parser()

args = arg_parser.parse_args()
print(args)
main(args)
# Ignore unknown args to make script more robust to input arguments
gs_args, _ = arg_parser.parse_known_args()
main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,6 @@ def generate_parser():
if __name__ == '__main__':
arg_parser = generate_parser()

args = arg_parser.parse_args()
main(args)
# Ignore unknown args to make script more robust to input arguments
gs_args, _ = arg_parser.parse_known_args()
main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_ep/ep_infer_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,6 @@ def generate_parser():
if __name__ == '__main__':
arg_parser=generate_parser()

args = arg_parser.parse_args()
main(args)
# Ignore unknown args to make script more robust to input arguments
gs_args, _ = arg_parser.parse_known_args()
main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_ep/ep_infer_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,6 @@ def generate_parser():
if __name__ == '__main__':
arg_parser=generate_parser()

args = arg_parser.parse_args()
main(args)
# Ignore unknown args to make script more robust to input arguments
gs_args, _ = arg_parser.parse_known_args()
main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_ep/gsgnn_ep.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,5 +171,6 @@ def generate_parser():
if __name__ == '__main__':
arg_parser=generate_parser()

args = arg_parser.parse_args()
main(args)
# Ignore unknown args to make script more robust to input arguments
gs_args, _ = arg_parser.parse_known_args()
main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_ep/gsgnn_lm_ep.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,6 @@ def generate_parser():
if __name__ == '__main__':
arg_parser=generate_parser()

args = arg_parser.parse_args()
main(args)
# Ignore unknown args to make script more robust to input arguments
gs_args, _ = arg_parser.parse_known_args()
main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,5 +197,6 @@ def generate_parser():
if __name__ == '__main__':
arg_parser=generate_parser()

args = arg_parser.parse_args()
main(args)
# Ignore unknown args to make script more robust to input arguments
gs_args, _ = arg_parser.parse_known_args()
main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_lp/gsgnn_lp.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,5 +223,6 @@ def generate_parser():
if __name__ == '__main__':
arg_parser=generate_parser()

args = arg_parser.parse_args()
main(args)
# Ignore unknown args to make script more robust to input arguments
gs_args, _ = arg_parser.parse_known_args()
main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,6 @@ def generate_parser():
if __name__ == '__main__':
arg_parser=generate_parser()

args = arg_parser.parse_args()
main(args)
# Ignore unknown args to make script more robust to input arguments
gs_args, _ = arg_parser.parse_known_args()
main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_lp/lp_infer_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,5 +90,6 @@ def generate_parser():
if __name__ == '__main__':
arg_parser=generate_parser()

args = arg_parser.parse_args()
main(args)
# Ignore unknown args to make script more robust to input arguments
gs_args, _ = arg_parser.parse_known_args()
main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_np/gsgnn_np.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,5 +182,6 @@ def generate_parser():
if __name__ == '__main__':
arg_parser=generate_parser()

args = arg_parser.parse_args()
main(args)
# Ignore unknown args to make script more robust to input arguments
gs_args, _ = arg_parser.parse_known_args()
main(gs_args)
5 changes: 3 additions & 2 deletions python/graphstorm/run/gsgnn_np/np_infer_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,6 @@ def generate_parser():
if __name__ == '__main__':
arg_parser=generate_parser()

args = arg_parser.parse_args()
main(args)
# Ignore unknown args to make script more robust to input arguments
gs_args, _ = arg_parser.parse_known_args()
main(gs_args)
1 change: 1 addition & 0 deletions sagemaker/launch/launch_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def run_job(input_args, image, unknownargs):
# We must handle cases like
# --target-etype query,clicks,asin query,search,asin
# --feat-name ntype0:feat0 ntype1:feat1
# --column-names nid,~id emb,embedding
unknow_idx = 0
while unknow_idx < len(unknownargs):
print(unknownargs[unknow_idx])
Expand Down

0 comments on commit ca0ed1f

Please sign in to comment.