From 9c74756e3b2849d76845b26f3fa32bd4fd372527 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Wed, 14 Feb 2024 16:19:25 -0800 Subject: [PATCH] Enable wholegraph sparse embedding for inference --- python/graphstorm/run/gsgnn_ep/ep_infer_gnn.py | 6 ++++-- python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py | 10 ++++++++-- python/graphstorm/run/gsgnn_np/np_infer_gnn.py | 6 ++++-- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/python/graphstorm/run/gsgnn_ep/ep_infer_gnn.py b/python/graphstorm/run/gsgnn_ep/ep_infer_gnn.py index 8a6bd6b98b..25f663781d 100644 --- a/python/graphstorm/run/gsgnn_ep/ep_infer_gnn.py +++ b/python/graphstorm/run/gsgnn_ep/ep_infer_gnn.py @@ -22,7 +22,7 @@ from graphstorm.inference import GSgnnEdgePredictionInferrer from graphstorm.eval import GSgnnAccEvaluator, GSgnnRegressionEvaluator from graphstorm.dataloading import GSgnnEdgeInferData, GSgnnEdgeDataLoader -from graphstorm.utils import setup_device, get_lm_ntypes +from graphstorm.utils import setup_device, get_lm_ntypes, use_wholegraph def get_evaluator(config): # pylint: disable=unused-argument """ Get evaluator class @@ -43,7 +43,9 @@ def main(config_args): config = GSConfig(config_args) config.verify_arguments(False) - gs.initialize(ip_config=config.ip_config, backend=config.backend) + use_wg_feats = use_wholegraph(config.part_config) + gs.initialize(ip_config=config.ip_config, backend=config.backend, + use_wholegraph=config.use_wholegraph_sparse_emb or use_wg_feats) device = setup_device(config.local_rank) infer_data = GSgnnEdgeInferData(config.graph_name, diff --git a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py index 3683e2d284..073c529dbf 100644 --- a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py +++ b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py @@ -27,7 +27,11 @@ GSgnnLinkPredictionPredefinedTestDataLoader) from graphstorm.dataloading import BUILTIN_LP_UNIFORM_NEG_SAMPLER from graphstorm.dataloading import BUILTIN_LP_JOINT_NEG_SAMPLER -from graphstorm.utils import setup_device, get_lm_ntypes +from graphstorm.utils import ( + setup_device, + get_lm_ntypes, + use_wholegraph, +) def main(config_args): """ main function @@ -35,7 +39,9 @@ def main(config_args): config = GSConfig(config_args) config.verify_arguments(False) - gs.initialize(ip_config=config.ip_config, backend=config.backend) + use_wg_feats = use_wholegraph(config.part_config) + gs.initialize(ip_config=config.ip_config, backend=config.backend, + use_wholegraph=config.use_wholegraph_sparse_emb or use_wg_feats) device = setup_device(config.local_rank) infer_data = GSgnnEdgeInferData(config.graph_name, diff --git a/python/graphstorm/run/gsgnn_np/np_infer_gnn.py b/python/graphstorm/run/gsgnn_np/np_infer_gnn.py index 10a84e9108..4d84a2ebda 100644 --- a/python/graphstorm/run/gsgnn_np/np_infer_gnn.py +++ b/python/graphstorm/run/gsgnn_np/np_infer_gnn.py @@ -21,7 +21,7 @@ from graphstorm.inference import GSgnnNodePredictionInferrer from graphstorm.eval import GSgnnAccEvaluator, GSgnnRegressionEvaluator from graphstorm.dataloading import GSgnnNodeInferData, GSgnnNodeDataLoader -from graphstorm.utils import setup_device, get_lm_ntypes +from graphstorm.utils import setup_device, get_lm_ntypes, use_wholegraph def get_evaluator(config): # pylint: disable=unused-argument """ Get evaluator class @@ -42,7 +42,9 @@ def main(config_args): config = GSConfig(config_args) config.verify_arguments(False) - gs.initialize(ip_config=config.ip_config, backend=config.backend) + use_wg_feats = use_wholegraph(config.part_config) + gs.initialize(ip_config=config.ip_config, backend=config.backend, + use_wholegraph=config.use_wholegraph_sparse_emb or use_wg_feats) device = setup_device(config.local_rank) infer_data = GSgnnNodeInferData(config.graph_name,