From 6ccfa642ed30235122e92e7b4d915f942d11521d Mon Sep 17 00:00:00 2001 From: "xiang song(charlie.song)" Date: Sat, 27 Apr 2024 23:17:36 -0700 Subject: [PATCH] Update GSgnnData for examples (#805) *Issue #, if available:* #755 #756 *Description of changes:* Update examples: - [x] HGT - [x] GPeft - [x] TGAT By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: Xiang Song --- docs/source/api/graphstorm.dataloading.rst | 5 +-- examples/customized_models/HGT/README.md | 2 +- examples/customized_models/HGT/hgt_nc.py | 41 +++++++++++-------- examples/peft_llm_gnn/main_lp.py | 22 ++++++---- examples/peft_llm_gnn/main_nc.py | 27 ++++++------ .../graphstorm_train_script_nc_config.yaml | 4 +- examples/temporal_graph_learning/main_nc.py | 29 +++++++------ python/graphstorm/model/embed.py | 4 +- python/graphstorm/model/gat_encoder.py | 4 +- python/graphstorm/model/gatv2_encoder.py | 4 +- python/graphstorm/model/hgt_encoder.py | 4 +- python/graphstorm/model/lm_embed.py | 8 ++-- python/graphstorm/model/rgat_encoder.py | 4 +- python/graphstorm/model/rgcn_encoder.py | 4 +- python/graphstorm/model/sage_encoder.py | 4 +- .../graphstorm/run/gsgnn_lp/lp_infer_gnn.py | 10 +++-- python/graphstorm/run/gsgnn_lp/lp_infer_lm.py | 10 +++-- tests/unit-tests/test_dataloading.py | 2 +- tests/unit-tests/test_gnn.py | 6 +-- 19 files changed, 106 insertions(+), 88 deletions(-) diff --git a/docs/source/api/graphstorm.dataloading.rst b/docs/source/api/graphstorm.dataloading.rst index b782602eb9..2ce324888c 100644 --- a/docs/source/api/graphstorm.dataloading.rst +++ b/docs/source/api/graphstorm.dataloading.rst @@ -33,10 +33,7 @@ DataSets :nosignatures: :template: datasettemplate.rst - GSgnnNodeTrainData - GSgnnNodeInferData - GSgnnEdgeTrainData - GSgnnEdgeInferData + GSgnnData DataLoaders ------------ diff --git a/examples/customized_models/HGT/README.md b/examples/customized_models/HGT/README.md index b2b6af248b..65b2a136e1 100644 --- a/examples/customized_models/HGT/README.md +++ b/examples/customized_models/HGT/README.md @@ -17,7 +17,7 @@ In order to plus users' own GNN models into the GraphStorm Framework, users need - Define your own loss function, or use GraphStorm's built-in loss functions that can handel common classification, regression, and link predictioin tasks. - In case having unused weights problem, modify the loss computation to include a regulation computation of all parameters -3. Use the GraphStorm's dataset, e.g., [GSgnnNodeTrainData](https://github.com/awslabs/graphstorm/blob/main/python/graphstorm/dataloading/dataset.py#L469) and dataloader, e.g., [GSgnnNodeDataLoader](https://github.com/awslabs/graphstorm/blob/main/python/graphstorm/dataloading/dataloading.py#L544) to construct distributed graph loading and mini-batch sampling. +3. Use the GraphStorm's dataset, e.g., [GSgnnData](https://github.com/awslabs/graphstorm/blob/main/python/graphstorm/dataloading/dataset.py#L157) and dataloader, e.g., [GSgnnNodeDataLoader](https://github.com/awslabs/graphstorm/blob/main/python/graphstorm/dataloading/dataloading.py#L544) to construct distributed graph loading and mini-batch sampling. 4. Wrap your model in a GraphStorm trainer, e.g., [GSgnnNodePredictionTrainer](https://github.com/awslabs/graphstorm/blob/main/python/graphstorm/trainer/np_trainer.py), which will handle the training process with its fit() method. diff --git a/examples/customized_models/HGT/hgt_nc.py b/examples/customized_models/HGT/hgt_nc.py index d00f469285..284b4b7eff 100644 --- a/examples/customized_models/HGT/hgt_nc.py +++ b/examples/customized_models/HGT/hgt_nc.py @@ -12,7 +12,7 @@ from graphstorm import model as gsmodel from graphstorm.trainer import GSgnnNodePredictionTrainer from graphstorm.inference import GSgnnNodePredictionInferrer -from graphstorm.dataloading import GSgnnNodeTrainData, GSgnnNodeInferData +from graphstorm.dataloading import GSgnnData from graphstorm.dataloading import GSgnnNodeDataLoader from graphstorm.eval import GSgnnAccEvaluator from graphstorm.tracker import GSSageMakerTaskTracker @@ -272,11 +272,7 @@ def main(args): node_feat_fields[node_type] = feat_names.split(',') # Define the GraphStorm training dataset - train_data = GSgnnNodeTrainData(config.graph_name, - config.part_config, - train_ntypes=config.target_ntype, - node_feat_field=node_feat_fields, - label_field=config.label_field) + train_data = GSgnnData(config.part_config) # Create input arguments for the HGT model node_dict = {} @@ -311,18 +307,29 @@ def main(args): trainer = GSgnnNodePredictionTrainer(model, topk_model_to_save=config.topk_model_to_save) trainer.setup_device(device=get_device()) + train_idxs = train_data.get_node_train_set(config.target_ntype) # Define the GraphStorm train dataloader - dataloader = GSgnnNodeDataLoader(train_data, train_data.train_idxs, fanout=config.fanout, - batch_size=config.batch_size, train_task=True) - + dataloader = GSgnnNodeDataLoader(train_data, train_idxs, fanout=config.fanout, + batch_size=config.batch_size, + node_feats=node_feat_fields, + label_field=config.label_field, + train_task=True) + + eval_ntype = config.eval_target_ntype + val_idxs = train_data.get_node_val_set(eval_ntype) + test_idxs = train_data.get_node_test_set(eval_ntype) # Optional: Define the evaluation dataloader - eval_dataloader = GSgnnNodeDataLoader(train_data, train_data.val_idxs,fanout=config.fanout, + eval_dataloader = GSgnnNodeDataLoader(train_data, val_idxs, fanout=config.fanout, batch_size=config.eval_batch_size, + node_feats=node_feat_fields, + label_field=config.label_field, train_task=False) # Optional: Define the evaluation dataloader - test_dataloader = GSgnnNodeDataLoader(train_data, train_data.test_idxs,fanout=config.fanout, + test_dataloader = GSgnnNodeDataLoader(train_data, test_idxs, fanout=config.fanout, batch_size=config.eval_batch_size, + node_feats=node_feat_fields, + label_field=config.label_field, train_task=False) # Optional: set up a evaluator @@ -351,18 +358,18 @@ def main(args): model.restore_model(best_model_path) # Create a dataset for inference. - infer_data = GSgnnNodeInferData(config.graph_name, config.part_config, - eval_ntypes=config.target_ntype, - node_feat_field=node_feat_fields, - label_field=config.label_field) + infer_data = GSgnnData(config.part_config) # Create an inference for a node task. infer = GSgnnNodePredictionInferrer(model) infer.setup_device(device=get_device()) infer.setup_evaluator(evaluator) infer.setup_task_tracker(tracker) - dataloader = GSgnnNodeDataLoader(infer_data, infer_data.test_idxs, + infer_idxs = infer_data.get_node_infer_set(eval_ntype) + dataloader = GSgnnNodeDataLoader(infer_data,infer_idxs, fanout=config.fanout, batch_size=100, + node_feats=node_feat_fields, + label_field=config.label_field, train_task=False) # Run inference on the inference dataset and save the GNN embeddings in the specified path. @@ -390,7 +397,7 @@ def main(args): default=argparse.SUPPRESS, help="Print more information. \ For customized models, MUST have this argument!!") - argparser.add_argument("--local_rank", type=int, + argparser.add_argument("--local-rank", type=int, help="The rank of the trainer. \ For customized models, MUST have this argument!!") diff --git a/examples/peft_llm_gnn/main_lp.py b/examples/peft_llm_gnn/main_lp.py index afafa5378b..8599ce6c41 100644 --- a/examples/peft_llm_gnn/main_lp.py +++ b/examples/peft_llm_gnn/main_lp.py @@ -5,7 +5,7 @@ from graphstorm.config import GSConfig from graphstorm.dataloading import GSgnnLinkPredictionDataLoader, GSgnnLinkPredictionTestDataLoader from graphstorm.eval import GSgnnMrrLPEvaluator -from graphstorm.dataloading import GSgnnLPTrainData +from graphstorm.dataloading import GSgnnData from graphstorm.utils import get_device from graphstorm.inference import GSgnnLinkPredictionInferrer from graphstorm.trainer import GSgnnLinkPredictionTrainer @@ -20,14 +20,12 @@ def main(config_args): gs.initialize(ip_config=config.ip_config, backend=config.backend, local_rank=config.local_rank) # Define the training dataset - train_data = GSgnnLPTrainData( - config.graph_name, + train_data = GSgnnData( config.part_config, - train_etypes=config.train_etype, - eval_etypes=config.eval_etype, - label_field=None, node_feat_field=config.node_feat_name, ) + train_etypes=config.train_etype + eval_etypes=config.eval_etype model = GNNLLM_LP( g=train_data.g, @@ -69,23 +67,27 @@ def main(config_args): trainer.setup_task_tracker(tracker) # create train loader with uniform negative sampling + train_idxs = train_data.get_edge_train_set(train_etypes) dataloader = GSgnnLinkPredictionDataLoader( train_data, - train_data.train_idxs, + train_idxs, fanout=config.fanout, batch_size=config.batch_size, num_negative_edges=config.num_negative_edges, + node_feats=config.node_feat_name, train_task=True, reverse_edge_types_map=config.reverse_edge_types_map, exclude_training_targets=config.exclude_training_targets, ) # create val loader + val_idxs = train_data.get_edge_val_set(eval_etypes) val_dataloader = GSgnnLinkPredictionTestDataLoader( train_data, - train_data.val_idxs, + val_idxs, batch_size=config.eval_batch_size, num_negative_edges=config.num_negative_edges, + node_feats=config.node_feat_name, fanout=config.fanout, ) @@ -112,11 +114,13 @@ def main(config_args): infer.setup_evaluator(evaluator) infer.setup_task_tracker(tracker) # Create test loader + infer_idxs = train_data.get_edge_infer_set(eval_etypes) test_dataloader = GSgnnLinkPredictionTestDataLoader( train_data, - train_data.test_idxs, + infer_idxs, batch_size=config.eval_batch_size, num_negative_edges=config.num_negative_edges_eval, + node_feats=config.node_feat_name, fanout=config.fanout, ) # Run inference on the inference dataset and save the GNN embeddings in the specified path. diff --git a/examples/peft_llm_gnn/main_nc.py b/examples/peft_llm_gnn/main_nc.py index 5325160556..23c40e3f41 100644 --- a/examples/peft_llm_gnn/main_nc.py +++ b/examples/peft_llm_gnn/main_nc.py @@ -4,7 +4,7 @@ from graphstorm.config import GSConfig from graphstorm.dataloading import GSgnnNodeDataLoader from graphstorm.eval import GSgnnAccEvaluator -from graphstorm.dataloading import GSgnnNodeTrainData +from graphstorm.dataloading import GSgnnData from graphstorm.utils import get_device from graphstorm.inference import GSgnnNodePredictionInferrer from graphstorm.trainer import GSgnnNodePredictionTrainer @@ -18,14 +18,8 @@ def main(config_args): gs.initialize(ip_config=config.ip_config, backend=config.backend, local_rank=config.local_rank) # Define the training dataset - train_data = GSgnnNodeTrainData( - config.graph_name, - config.part_config, - train_ntypes=config.target_ntype, - eval_ntypes=config.eval_target_ntype, - label_field=config.label_field, - node_feat_field=config.node_feat_name, - ) + train_data = GSgnnData( + config.part_config) model = GNNLLM_NC( g=train_data.g, @@ -66,20 +60,26 @@ def main(config_args): trainer.setup_task_tracker(tracker) # create train loader + train_idxs = train_data.get_node_train_set(config.target_ntype) dataloader = GSgnnNodeDataLoader( train_data, - train_data.train_idxs, + train_idxs, fanout=config.fanout, batch_size=config.batch_size, + node_feats=config.node_feat_name, + label_field=config.label_field, train_task=True, ) # create val loader + val_idxs = train_data.get_node_val_set(config.eval_target_ntype) val_dataloader = GSgnnNodeDataLoader( train_data, - train_data.val_idxs, + val_idxs, fanout=config.fanout, batch_size=config.eval_batch_size, + node_feats=config.node_feat_name, + label_field=config.label_field, train_task=False, ) @@ -106,11 +106,14 @@ def main(config_args): infer.setup_evaluator(evaluator) infer.setup_task_tracker(tracker) # Create test loader + test_idxs = train_data.get_node_test_set(config.eval_target_ntype) test_dataloader = GSgnnNodeDataLoader( train_data, - train_data.test_idxs, + test_idxs, fanout=config.fanout, batch_size=config.eval_batch_size, + node_feats=config.node_feat_name, + label_field=config.label_field, train_task=False, ) # Run inference on the inference dataset and save the GNN embeddings in the specified path. diff --git a/examples/temporal_graph_learning/graphstorm_train_script_nc_config.yaml b/examples/temporal_graph_learning/graphstorm_train_script_nc_config.yaml index aff5d00a23..6191de4dfb 100644 --- a/examples/temporal_graph_learning/graphstorm_train_script_nc_config.yaml +++ b/examples/temporal_graph_learning/graphstorm_train_script_nc_config.yaml @@ -41,6 +41,6 @@ gsf: use_self_loop: true udf: save_result_path: tgat_nc_gpu - eval_target_ntype: - - paper + eval_target_ntypes: + - paper version: 1.0 diff --git a/examples/temporal_graph_learning/main_nc.py b/examples/temporal_graph_learning/main_nc.py index 77ea4a54a3..e5a5469838 100644 --- a/examples/temporal_graph_learning/main_nc.py +++ b/examples/temporal_graph_learning/main_nc.py @@ -4,7 +4,7 @@ from graphstorm.config import GSConfig from graphstorm.dataloading import GSgnnNodeDataLoader from graphstorm.eval import GSgnnAccEvaluator -from graphstorm.dataloading import GSgnnNodeTrainData +from graphstorm.dataloading import GSgnnData from graphstorm.utils import get_device from graphstorm.trainer import GSgnnNodePredictionTrainer @@ -18,14 +18,8 @@ def main(config_args): local_rank=config.local_rank) # Define the training dataset - train_data = GSgnnNodeTrainData( - config.graph_name, - config.part_config, - train_ntypes=config.target_ntype, - eval_ntypes=config.eval_target_ntype, - label_field=config.label_field, - node_feat_field=config.node_feat_name, - ) + train_data = GSgnnData( + config.part_config) # Define TGAT model model = create_rgcn_model_for_nc(train_data.g, config) @@ -33,7 +27,7 @@ def main(config_args): # Create a trainer for NC tasks. trainer = GSgnnNodePredictionTrainer( - model, gs.get_rank(), topk_model_to_save=config.topk_model_to_save + model, topk_model_to_save=config.topk_model_to_save ) if config.restore_model_path is not None: @@ -57,29 +51,38 @@ def main(config_args): trainer.setup_evaluator(evaluator) # create train loader + train_idxs = train_data.get_node_train_set(config.target_ntype) dataloader = GSgnnNodeDataLoader( train_data, - train_data.train_idxs, + train_idxs, fanout=config.fanout, batch_size=config.batch_size, + node_feats=config.node_feat_name, + label_field=config.label_field, train_task=True, ) # create val loader + val_idxs = train_data.get_node_val_set(config.eval_target_ntypes) val_dataloader = GSgnnNodeDataLoader( train_data, - train_data.val_idxs, + val_idxs, fanout=config.fanout, batch_size=config.eval_batch_size, + node_feats=config.node_feat_name, + label_field=config.label_field, train_task=False, ) # create test loader + test_idxs = train_data.get_node_test_set(config.eval_target_ntypes) test_dataloader = GSgnnNodeDataLoader( train_data, - train_data.test_idxs, + test_idxs, fanout=config.fanout, batch_size=config.eval_batch_size, + node_feats=config.node_feat_name, + label_field=config.label_field, train_task=False, ) diff --git a/python/graphstorm/model/embed.py b/python/graphstorm/model/embed.py index fd12b49759..078e5d9915 100644 --- a/python/graphstorm/model/embed.py +++ b/python/graphstorm/model/embed.py @@ -205,9 +205,9 @@ class GSNodeEncoderInputLayer(GSNodeInputLayer): from graphstorm import get_node_feat_size from graphstorm.model import GSgnnNodeModel, GSNodeEncoderInputLayer - from graphstorm.dataloading import GSgnnNodeTrainData + from graphstorm.dataloading import GSgnnData - np_data = GSgnnNodeTrainData(...) + np_data = GSgnnData(...) model = GSgnnEdgeModel(alpha_l2norm=0) feat_size = get_node_feat_size(np_data.g, 'feat') diff --git a/python/graphstorm/model/gat_encoder.py b/python/graphstorm/model/gat_encoder.py index 1b69d7ec99..fbd063b253 100644 --- a/python/graphstorm/model/gat_encoder.py +++ b/python/graphstorm/model/gat_encoder.py @@ -137,10 +137,10 @@ class GATEncoder(GraphConvEncoder): from graphstorm.model.gat_encoder import GATEncoder from graphstorm.model.node_decoder import EntityClassifier from graphstorm.model import GSgnnNodeModel, GSNodeEncoderInputLayer - from graphstorm.dataloading import GSgnnNodeTrainData + from graphstorm.dataloading import GSgnnData from graphstorm.model import do_full_graph_inference - np_data = GSgnnNodeTrainData(...) + np_data = GSgnnData(...) model = GSgnnNodeModel(alpha_l2norm=0) feat_size = get_node_feat_size(np_data.g, 'feat') diff --git a/python/graphstorm/model/gatv2_encoder.py b/python/graphstorm/model/gatv2_encoder.py index 6918e1347e..c30fb8070b 100644 --- a/python/graphstorm/model/gatv2_encoder.py +++ b/python/graphstorm/model/gatv2_encoder.py @@ -141,10 +141,10 @@ class GATv2Encoder(GraphConvEncoder): from graphstorm.model.gat_encoder import GATv2Encoder from graphstorm.model.node_decoder import EntityClassifier from graphstorm.model import GSgnnNodeModel, GSNodeEncoderInputLayer - from graphstorm.dataloading import GSgnnNodeTrainData + from graphstorm.dataloading import GSgnnData from graphstorm.model import do_full_graph_inference - np_data = GSgnnNodeTrainData(...) + np_data = GSgnnData(...) model = GSgnnNodeModel(alpha_l2norm=0) feat_size = get_node_feat_size(np_data.g, 'feat') diff --git a/python/graphstorm/model/hgt_encoder.py b/python/graphstorm/model/hgt_encoder.py index 1a248ba69b..de203335fb 100644 --- a/python/graphstorm/model/hgt_encoder.py +++ b/python/graphstorm/model/hgt_encoder.py @@ -314,10 +314,10 @@ class HGTEncoder(GraphConvEncoder): from graphstorm.model.hgt_encoder import HGTEncoder from graphstorm.model.edge_decoder import MLPEdgeDecoder from graphstorm.model import GSgnnEdgeModel, GSNodeEncoderInputLayer - from graphstorm.dataloading import GSgnnNodeTrainData + from graphstorm.dataloading import GSgnnData from graphstorm.model import do_full_graph_inference - np_data = GSgnnNodeTrainData(...) + np_data = GSgnnData(...) model = GSgnnEdgeModel(alpha_l2norm=0) feat_size = get_node_feat_size(np_data.g, 'feat') diff --git a/python/graphstorm/model/lm_embed.py b/python/graphstorm/model/lm_embed.py index 2eb7178919..8e32c019bf 100644 --- a/python/graphstorm/model/lm_embed.py +++ b/python/graphstorm/model/lm_embed.py @@ -492,7 +492,7 @@ class GSPureLMNodeInputLayer(GSNodeInputLayer): .. code:: python from graphstorm.model import GSgnnNodeModel, GSPureLMNodeInputLayer - from graphstorm.dataloading import GSgnnNodeTrainData + from graphstorm.dataloading import GSgnnData node_lm_configs = [ { @@ -502,7 +502,7 @@ class GSPureLMNodeInputLayer(GSNodeInputLayer): "node_types": ['a'] } ] - np_data = GSgnnNodeTrainData(...) + np_data = GSgnnData(...) model = GSgnnNodeModel(...) lm_train_nodes=10 encoder = GSPureLMNodeInputLayer(g=np_data.g, node_lm_configs=node_lm_configs, @@ -687,8 +687,8 @@ class GSLMNodeEncoderInputLayer(GSNodeEncoderInputLayer): from graphstorm import get_node_feat_size from graphstorm.model import GSgnnNodeModel, GSLMNodeEncoderInputLayer - from graphstorm.dataloading import GSgnnNodeTrainData - np_data = GSgnnNodeTrainData(...) + from graphstorm.dataloading import GSgnnData + np_data = GSgnnData(...) model = GSgnnNodeModel(...) feat_size = get_node_feat_size(np_data.g, 'feat') node_lm_configs = [{"lm_type": "bert", diff --git a/python/graphstorm/model/rgat_encoder.py b/python/graphstorm/model/rgat_encoder.py index 23728eac7e..14c1ecc395 100644 --- a/python/graphstorm/model/rgat_encoder.py +++ b/python/graphstorm/model/rgat_encoder.py @@ -246,10 +246,10 @@ class RelationalGATEncoder(GraphConvEncoder): from graphstorm.model.rgat_encoder import RelationalGATEncoder from graphstorm.model.node_decoder import EntityClassifier from graphstorm.model import GSgnnNodeModel, GSNodeEncoderInputLayer - from graphstorm.dataloading import GSgnnNodeTrainData + from graphstorm.dataloading import GSgnnData from graphstorm.model import do_full_graph_inference - np_data = GSgnnNodeTrainData(...) + np_data = GSgnnData(...) model = GSgnnNodeModel(alpha_l2norm=0) feat_size = get_node_feat_size(np_data.g, 'feat') diff --git a/python/graphstorm/model/rgcn_encoder.py b/python/graphstorm/model/rgcn_encoder.py index ae0db75024..a708dfe289 100644 --- a/python/graphstorm/model/rgcn_encoder.py +++ b/python/graphstorm/model/rgcn_encoder.py @@ -293,10 +293,10 @@ class RelationalGCNEncoder(GraphConvEncoder): from graphstorm.model.rgcn_encoder import RelationalGCNEncoder from graphstorm.model.node_decoder import EntityClassifier from graphstorm.model import GSgnnNodeModel, GSNodeEncoderInputLayer - from graphstorm.dataloading import GSgnnNodeTrainData + from graphstorm.dataloading import GSgnnData from graphstorm.model import do_full_graph_inference - np_data = GSgnnNodeTrainData(...) + np_data = GSgnnData(...) model = GSgnnNodeModel(alpha_l2norm=0) feat_size = get_node_feat_size(np_data.g, 'feat') diff --git a/python/graphstorm/model/sage_encoder.py b/python/graphstorm/model/sage_encoder.py index 3b2da1de43..3fddbdd9e0 100644 --- a/python/graphstorm/model/sage_encoder.py +++ b/python/graphstorm/model/sage_encoder.py @@ -177,10 +177,10 @@ class SAGEEncoder(GraphConvEncoder): from graphstorm.model.sage_encoder import SAGEEncoder from graphstorm.model.node_decoder import EntityClassifier from graphstorm.model import GSgnnNodeModel, GSNodeEncoderInputLayer - from graphstorm.dataloading import GSgnnNodeTrainData + from graphstorm.dataloading import GSgnnData from graphstorm.model import do_full_graph_inference - np_data = GSgnnNodeTrainData(...) + np_data = GSgnnData(...) model = GSgnnNodeModel(alpha_l2norm=0) feat_size = get_node_feat_size(np_data.g, 'feat') diff --git a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py index 11c01ce2b0..591a78609e 100644 --- a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py +++ b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py @@ -53,21 +53,23 @@ def main(config_args): model_layer_to_load=config.restore_model_layers) infer = GSgnnLinkPredictionInferrer(model) infer.setup_device(device=get_device()) - test_idxs = infer_data.get_edge_test_set(config.eval_etype) if not config.no_validation: + infer_idxs = infer_data.get_edge_test_set(config.eval_etype) infer.setup_evaluator( GSgnnMrrLPEvaluator(config.eval_frequency, infer_data, config.num_negative_edges_eval, config.lp_decoder_type)) - assert len(test_idxs) > 0, "There is not test data for evaluation." + assert len(infer_idxs) > 0, "There is not test data for evaluation." + else: + infer_idxs = infer_data.get_edge_infer_set(config.eval_etype) tracker = gs.create_builtin_task_tracker(config) infer.setup_task_tracker(tracker) # We only support full-graph inference for now. if config.eval_etypes_negative_dstnode is not None: # The negatives used in evaluation is fixed. dataloader = GSgnnLinkPredictionPredefinedTestDataLoader( - infer_data, test_idxs, + infer_data, infer_idxs, batch_size=config.eval_batch_size, fixed_edge_dst_negative_field=config.eval_etypes_negative_dstnode, fanout=config.eval_fanout, @@ -82,7 +84,7 @@ def main(config_args): 'Supported test negative samplers include ' f'[{BUILTIN_LP_UNIFORM_NEG_SAMPLER}, {BUILTIN_LP_JOINT_NEG_SAMPLER}]') - dataloader = test_dataloader_cls(infer_data, test_idxs, + dataloader = test_dataloader_cls(infer_data, infer_idxs, batch_size=config.eval_batch_size, num_negative_edges=config.num_negative_edges_eval, fanout=config.eval_fanout, diff --git a/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py b/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py index d40a58d6a1..af07ab72c0 100644 --- a/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py +++ b/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py @@ -46,21 +46,23 @@ def main(config_args): model_layer_to_load=config.restore_model_layers) infer = GSgnnLinkPredictionInferrer(model) infer.setup_device(device=get_device()) - test_idxs = infer_data.get_edge_test_set(config.eval_etype) if not config.no_validation: + infer_idxs = infer_data.get_edge_test_set(config.eval_etype) infer.setup_evaluator( GSgnnMrrLPEvaluator(config.eval_frequency, infer_data, config.num_negative_edges_eval, config.lp_decoder_type)) - assert len(test_idxs) > 0, "There is not test data for evaluation." + assert len(infer_idxs) > 0, "There is not test data for evaluation." + else: + infer_idxs = infer_data.get_edge_infer_set(config.eval_etype) tracker = gs.create_builtin_task_tracker(config) infer.setup_task_tracker(tracker) # We only support full-graph inference for now. if config.eval_etypes_negative_dstnode is not None: # The negatives used in evaluation is fixed. dataloader = GSgnnLinkPredictionPredefinedTestDataLoader( - infer_data, test_idxs, + infer_data, infer_idxs, batch_size=config.eval_batch_size, fixed_edge_dst_negative_field=config.eval_etypes_negative_dstnode, node_feats=config.node_feat_name) @@ -74,7 +76,7 @@ def main(config_args): 'Supported test negative samplers include ' f'[{BUILTIN_LP_UNIFORM_NEG_SAMPLER}, {BUILTIN_LP_JOINT_NEG_SAMPLER}]') - dataloader = test_dataloader_cls(infer_data, test_idxs, + dataloader = test_dataloader_cls(infer_data, infer_idxs, batch_size=config.eval_batch_size, num_negative_edges=config.num_negative_edges_eval, node_feats=config.node_feat_name) diff --git a/tests/unit-tests/test_dataloading.py b/tests/unit-tests/test_dataloading.py index d1c0f352a0..10bd705847 100644 --- a/tests/unit-tests/test_dataloading.py +++ b/tests/unit-tests/test_dataloading.py @@ -2099,7 +2099,7 @@ def test_GSgnnTrainData_homogeneous(): with tempfile.TemporaryDirectory() as tmpdirname: # generate the test dummy homogeneous distributed graph and - # test if it is possible to create GSgnnNodeTrainData on homogeneous graph + # test if it is possible to create GSgnnData on homogeneous graph _, part_config = generate_dummy_dist_graph(graph_name='dummy', dirname=tmpdirname, is_homo=True) diff --git a/tests/unit-tests/test_gnn.py b/tests/unit-tests/test_gnn.py index 1d4b119061..f5c5771085 100644 --- a/tests/unit-tests/test_gnn.py +++ b/tests/unit-tests/test_gnn.py @@ -235,7 +235,7 @@ def check_node_prediction(model, data, is_homo=False): ---------- model: GSgnnNodeModel Node model - data: GSgnnNodeTrainData + data: GSgnnData Train data """ g = data.g @@ -311,7 +311,7 @@ def check_node_prediction_with_reconstruct(model, data, construct_feat_ntype, tr ---------- model: GSgnnNodeModel Node model - data: GSgnnNodeTrainData + data: GSgnnData Train data """ target_ntype = train_ntypes[0] @@ -390,7 +390,7 @@ def check_mlp_node_prediction(model, data): ---------- model: GSgnnNodeModel Node model - data: GSgnnNodeTrainData + data: GSgnnData Train data """ g = data.g