Skip to content

Commit

Permalink
Enable wholegraph sparse embedding for inference
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiang Song committed Feb 15, 2024
1 parent dd9e48f commit 9c74756
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
6 changes: 4 additions & 2 deletions python/graphstorm/run/gsgnn_ep/ep_infer_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from graphstorm.inference import GSgnnEdgePredictionInferrer
from graphstorm.eval import GSgnnAccEvaluator, GSgnnRegressionEvaluator
from graphstorm.dataloading import GSgnnEdgeInferData, GSgnnEdgeDataLoader
from graphstorm.utils import setup_device, get_lm_ntypes
from graphstorm.utils import setup_device, get_lm_ntypes, use_wholegraph

def get_evaluator(config): # pylint: disable=unused-argument
""" Get evaluator class
Expand All @@ -43,7 +43,9 @@ def main(config_args):
config = GSConfig(config_args)
config.verify_arguments(False)

gs.initialize(ip_config=config.ip_config, backend=config.backend)
use_wg_feats = use_wholegraph(config.part_config)
gs.initialize(ip_config=config.ip_config, backend=config.backend,
use_wholegraph=config.use_wholegraph_sparse_emb or use_wg_feats)
device = setup_device(config.local_rank)

infer_data = GSgnnEdgeInferData(config.graph_name,
Expand Down
10 changes: 8 additions & 2 deletions python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,21 @@
GSgnnLinkPredictionPredefinedTestDataLoader)
from graphstorm.dataloading import BUILTIN_LP_UNIFORM_NEG_SAMPLER
from graphstorm.dataloading import BUILTIN_LP_JOINT_NEG_SAMPLER
from graphstorm.utils import setup_device, get_lm_ntypes
from graphstorm.utils import (
setup_device,
get_lm_ntypes,
use_wholegraph,
)

def main(config_args):
""" main function
"""
config = GSConfig(config_args)
config.verify_arguments(False)

gs.initialize(ip_config=config.ip_config, backend=config.backend)
use_wg_feats = use_wholegraph(config.part_config)
gs.initialize(ip_config=config.ip_config, backend=config.backend,
use_wholegraph=config.use_wholegraph_sparse_emb or use_wg_feats)
device = setup_device(config.local_rank)

infer_data = GSgnnEdgeInferData(config.graph_name,
Expand Down
6 changes: 4 additions & 2 deletions python/graphstorm/run/gsgnn_np/np_infer_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from graphstorm.inference import GSgnnNodePredictionInferrer
from graphstorm.eval import GSgnnAccEvaluator, GSgnnRegressionEvaluator
from graphstorm.dataloading import GSgnnNodeInferData, GSgnnNodeDataLoader
from graphstorm.utils import setup_device, get_lm_ntypes
from graphstorm.utils import setup_device, get_lm_ntypes, use_wholegraph

def get_evaluator(config): # pylint: disable=unused-argument
""" Get evaluator class
Expand All @@ -42,7 +42,9 @@ def main(config_args):
config = GSConfig(config_args)
config.verify_arguments(False)

gs.initialize(ip_config=config.ip_config, backend=config.backend)
use_wg_feats = use_wholegraph(config.part_config)
gs.initialize(ip_config=config.ip_config, backend=config.backend,
use_wholegraph=config.use_wholegraph_sparse_emb or use_wg_feats)
device = setup_device(config.local_rank)

infer_data = GSgnnNodeInferData(config.graph_name,
Expand Down

0 comments on commit 9c74756

Please sign in to comment.