diff --git a/python/graphstorm/dataloading/__init__.py b/python/graphstorm/dataloading/__init__.py index 1aed57bb12..20abce548d 100644 --- a/python/graphstorm/dataloading/__init__.py +++ b/python/graphstorm/dataloading/__init__.py @@ -24,8 +24,9 @@ from .dataloading import GSgnnAllEtypeLinkPredictionDataLoader from .dataloading import GSgnnEdgeDataLoader from .dataloading import GSgnnNodeDataLoader, GSgnnNodeSemiSupDataLoader -from .dataloading import GSgnnLinkPredictionTestDataLoader -from .dataloading import GSgnnLinkPredictionJointTestDataLoader +from .dataloading import (GSgnnLinkPredictionTestDataLoader, + GSgnnLinkPredictionJointTestDataLoader, + GSgnnLinkPredictionPredefinedTestDataLoader) from .dataloading import (FastGSgnnLinkPredictionDataLoader, FastGSgnnLPLocalJointNegDataLoader, FastGSgnnLPJointNegDataLoader, @@ -43,7 +44,8 @@ BUILTIN_LP_JOINT_NEG_SAMPLER, BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER, BUILTIN_LP_LOCALUNIFORM_NEG_SAMPLER, - BUILTIN_LP_LOCALJOINT_NEG_SAMPLER) + BUILTIN_LP_LOCALJOINT_NEG_SAMPLER, + BUILTIN_LP_FIXED_NEG_SAMPLER) from .dataloading import BUILTIN_LP_ALL_ETYPE_UNIFORM_NEG_SAMPLER from .dataloading import BUILTIN_LP_ALL_ETYPE_JOINT_NEG_SAMPLER from .dataloading import (BUILTIN_FAST_LP_UNIFORM_NEG_SAMPLER, diff --git a/python/graphstorm/dataloading/dataloading.py b/python/graphstorm/dataloading/dataloading.py index afa1650d28..517ec236aa 100644 --- a/python/graphstorm/dataloading/dataloading.py +++ b/python/graphstorm/dataloading/dataloading.py @@ -488,6 +488,10 @@ class GSgnnLinkPredictionDataLoader(GSgnnLinkPredictionDataLoaderBase): The node types that requires to construct node features. construct_feat_fanout : int The fanout required to construct node features. + edge_dst_negative_field: str or dict of str + The feature field(s) that store the hard negative edges for each edge type. 
+ num_hard_negs: int or dict of int + The number of hard negatives per positive edge for each edge type Examples ------------ @@ -510,7 +514,9 @@ class GSgnnLinkPredictionDataLoader(GSgnnLinkPredictionDataLoaderBase): def __init__(self, dataset, target_idx, fanout, batch_size, num_negative_edges, device='cpu', train_task=True, reverse_edge_types_map=None, exclude_training_targets=False, edge_mask_for_gnn_embeddings='train_mask', - construct_feat_ntype=None, construct_feat_fanout=5): + construct_feat_ntype=None, construct_feat_fanout=5, + edge_dst_negative_field=None, + num_hard_negs=None): super().__init__(dataset, target_idx, fanout) self._device = device for etype in target_idx: @@ -524,7 +530,9 @@ def __init__(self, dataset, target_idx, fanout, batch_size, num_negative_edges, reverse_edge_types_map=reverse_edge_types_map, edge_mask_for_gnn_embeddings=edge_mask_for_gnn_embeddings, construct_feat_ntype=construct_feat_ntype, - construct_feat_fanout=construct_feat_fanout) + construct_feat_fanout=construct_feat_fanout, + edge_dst_negative_field=edge_dst_negative_field, + num_hard_negs=num_hard_negs) def _prepare_negative_sampler(self, num_negative_edges): # the default negative sampler is uniform sampler @@ -535,7 +543,8 @@ def _prepare_dataloader(self, dataset, target_idxs, fanout, num_negative_edges, batch_size, device, train_task=True, exclude_training_targets=False, reverse_edge_types_map=None, edge_mask_for_gnn_embeddings=None, construct_feat_ntype=None, - construct_feat_fanout=5): + construct_feat_fanout=5, edge_dst_negative_field=None, + num_hard_negs=None): g = dataset.g if construct_feat_ntype is None: construct_feat_ntype = [] @@ -556,6 +565,11 @@ def _prepare_dataloader(self, dataset, target_idxs, fanout, sampler = MultiLayerNeighborSamplerForReconstruct(sampler, dataset, construct_feat_ntype, construct_feat_fanout) negative_sampler = self._prepare_negative_sampler(num_negative_edges) + if edge_dst_negative_field is not None: + negative_sampler = GSHardEdgeDstNegativeSampler(num_negative_edges, + edge_dst_negative_field, + negative_sampler, + num_hard_negs) # edge loader if train_task: @@ -642,7 +656,8 @@ def _prepare_dataloader(self, dataset, target_idxs, fanout, num_negative_edges, batch_size, device, train_task=True, exclude_training_targets=False, reverse_edge_types_map=None, edge_mask_for_gnn_embeddings=None, construct_feat_ntype=None, - construct_feat_fanout=5): + construct_feat_fanout=5, edge_dst_negative_field=None, + num_hard_negs=None): g = dataset.g if construct_feat_ntype is None: construct_feat_ntype = [] @@ -663,6 +678,11 @@ def _prepare_dataloader(self, dataset, target_idxs, fanout, sampler = MultiLayerNeighborSamplerForReconstruct(sampler, dataset, construct_feat_ntype, construct_feat_fanout) negative_sampler = self._prepare_negative_sampler(num_negative_edges) + if edge_dst_negative_field is not None: + negative_sampler = GSHardEdgeDstNegativeSampler(num_negative_edges, + edge_dst_negative_field, + negative_sampler, + num_hard_negs) # edge loader if train_task: @@ -975,11 +995,9 @@ class GSgnnLinkPredictionTestDataLoader(): When test is huge, using fixed_test_size can save validation and test time. Default: None. - fixed_edge_dst_negative_field: str or list of str - The feature field(s) that store the fixed negative set for each edge. 
""" def __init__(self, dataset, target_idx, batch_size, num_negative_edges, - fanout=None, fixed_test_size=None, fixed_edge_dst_negative_field=None): + fanout=None, fixed_test_size=None): self._data = dataset self._fanout = fanout for etype in target_idx: @@ -996,7 +1014,6 @@ def __init__(self, dataset, target_idx, batch_size, num_negative_edges, "is %d, which is smaller than the expected" "test size %d, force it to %d", etype, len(t_idx), self._fixed_test_size[etype], len(t_idx)) - self._fixed_edge_dst_negative_field = fixed_edge_dst_negative_field self._negative_sampler = self._prepare_negative_sampler(num_negative_edges) self._reinit_dataset() @@ -1014,13 +1031,7 @@ def _reinit_dataset(self): def _prepare_negative_sampler(self, num_negative_edges): # the default negative sampler is uniform sampler self._neg_sample_type = BUILTIN_LP_UNIFORM_NEG_SAMPLER - - if self._fixed_edge_dst_negative_field: - negative_sampler = GSFixedEdgeDstNegativeSampler(self._fixed_edge_dst_negative_field) - self._neg_sample_type = BUILTIN_LP_FIXED_NEG_SAMPLER - else: - negative_sampler = GlobalUniform(num_negative_edges) - self._neg_sample_type = BUILTIN_LP_UNIFORM_NEG_SAMPLER + negative_sampler = GlobalUniform(num_negative_edges) return negative_sampler def __iter__(self): @@ -1069,10 +1080,41 @@ def _prepare_negative_sampler(self, num_negative_edges): # the default negative sampler is uniform sampler negative_sampler = JointUniform(num_negative_edges) self._neg_sample_type = BUILTIN_LP_JOINT_NEG_SAMPLER - if self._fixed_edge_dst_negative_field: - negative_sampler = GSHardEdgeDstNegativeSampler(num_negative_edges, - self._fixed_edge_dst_negative_field, - negative_sampler) + return negative_sampler + +class GSgnnLinkPredictionPredefinedTestDataLoader(GSgnnLinkPredictionTestDataLoader): + """ Link prediction minibatch dataloader for validation and test + with predefined negatives. + + Parameters + ----------- + dataset: GSgnnEdgeData + The GraphStorm edge dataset + target_idx : dict of Tensors + The target edges for prediction + batch_size: int + Batch size + fanout: int + Evaluation fanout for computing node embedding + fixed_test_size: int + Fixed number of test data used in evaluation. + If it is none, use the whole testset. + When test is huge, using fixed_test_size + can save validation and test time. + Default: None. + fixed_edge_dst_negative_field: str or list of str + The feature field(s) that store the fixed negative set for each edge. 
+ """ + def __init__(self, dataset, target_idx, batch_size, fixed_edge_dst_negative_field, + fanout=None, fixed_test_size=None): + self._fixed_edge_dst_negative_field = fixed_edge_dst_negative_field + super().__init__(dataset, target_idx, batch_size, + num_negative_edges=0, # num_negative_edges is not used + fanout=fanout, fixed_test_size=fixed_test_size) + + def _prepare_negative_sampler(self, _): + negative_sampler = GSFixedEdgeDstNegativeSampler(self._fixed_edge_dst_negative_field) + self._neg_sample_type = BUILTIN_LP_FIXED_NEG_SAMPLER return negative_sampler ################ Minibatch DataLoader (Node classification) ####################### diff --git a/python/graphstorm/dataloading/sampler.py b/python/graphstorm/dataloading/sampler.py index 836f563126..fc474d85e9 100644 --- a/python/graphstorm/dataloading/sampler.py +++ b/python/graphstorm/dataloading/sampler.py @@ -179,7 +179,7 @@ def _generate(self, g, eids, canonical_etype): hard_negative[:num_hard_neg if num_hard_neg < self._k else self._k] return src, neg - def gen_neg_pairs(self, g, pos_pairs): + def gen_neg_pairs(self, _): """ TODO: Do not support generating negative pairs for evaluation in the same way as generating negative pairs for training now. Please use GSFixedEdgeDstNegativeSampler instead. diff --git a/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py b/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py index 143fd58e06..de18c5cc60 100644 --- a/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py +++ b/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py @@ -30,8 +30,9 @@ GSgnnLPInBatchJointNegDataLoader) from graphstorm.dataloading import GSgnnAllEtypeLPJointNegDataLoader from graphstorm.dataloading import GSgnnAllEtypeLinkPredictionDataLoader -from graphstorm.dataloading import GSgnnLinkPredictionTestDataLoader -from graphstorm.dataloading import GSgnnLinkPredictionJointTestDataLoader +from graphstorm.dataloading import (GSgnnLinkPredictionTestDataLoader, + GSgnnLinkPredictionJointTestDataLoader, + GSgnnLinkPredictionPredefinedTestDataLoader) from graphstorm.dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER, BUILTIN_LP_JOINT_NEG_SAMPLER, BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER, @@ -129,10 +130,14 @@ def main(config_args): raise ValueError('Unknown negative sampler') dataloader = dataloader_cls(train_data, train_data.train_idxs, [], config.batch_size, config.num_negative_edges, device, - train_task=True) + train_task=True, + edge_dst_negative_field=config.train_etypes_negative_dstnode, + num_hard_negs=config.num_train_hard_negatives) # TODO(zhengda) let's use full-graph inference for now. 
- if config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: + if config.eval_etypes_negative_dstnode is not None: + test_dataloader_cls = GSgnnLinkPredictionPredefinedTestDataLoader + elif config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: test_dataloader_cls = GSgnnLinkPredictionTestDataLoader elif config.eval_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER: test_dataloader_cls = GSgnnLinkPredictionJointTestDataLoader @@ -143,11 +148,19 @@ def main(config_args): val_dataloader = None test_dataloader = None if len(train_data.val_idxs) > 0: - val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs, - config.eval_batch_size, config.num_negative_edges_eval) + if config.eval_etypes_negative_dstnode is not None: + val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs, + config.eval_batch_size, config.eval_etypes_negative_dstnode) + else: + val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs, + config.eval_batch_size, config.num_negative_edges_eval) if len(train_data.test_idxs) > 0: - test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs, - config.eval_batch_size, config.num_negative_edges_eval) + if config.eval_etypes_negative_dstnode is not None: + test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs, + config.eval_batch_size, config.eval_etypes_negative_dstnode) + else: + test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs, + config.eval_batch_size, config.num_negative_edges_eval) # Preparing input layer for training or inference. # The input layer can pre-compute node features in the preparing step if needed. diff --git a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py index 92a40dc737..1a843e18fa 100644 --- a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py +++ b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py @@ -30,8 +30,9 @@ GSgnnLPInBatchJointNegDataLoader) from graphstorm.dataloading import GSgnnAllEtypeLPJointNegDataLoader from graphstorm.dataloading import GSgnnAllEtypeLinkPredictionDataLoader -from graphstorm.dataloading import GSgnnLinkPredictionTestDataLoader -from graphstorm.dataloading import GSgnnLinkPredictionJointTestDataLoader +from graphstorm.dataloading import (GSgnnLinkPredictionTestDataLoader, + GSgnnLinkPredictionJointTestDataLoader, + GSgnnLinkPredictionPredefinedTestDataLoader) from graphstorm.dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER, BUILTIN_LP_JOINT_NEG_SAMPLER, BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER, @@ -152,10 +153,14 @@ def main(config_args): reverse_edge_types_map=config.reverse_edge_types_map, exclude_training_targets=config.exclude_training_targets, construct_feat_ntype=config.construct_feat_ntype, - construct_feat_fanout=config.construct_feat_fanout) + construct_feat_fanout=config.construct_feat_fanout, + edge_dst_negative_field=config.train_etypes_negative_dstnode, + num_hard_negs=config.num_train_hard_negatives) # TODO(zhengda) let's use full-graph inference for now. 
- if config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: + if config.eval_etypes_negative_dstnode is not None: + test_dataloader_cls = GSgnnLinkPredictionPredefinedTestDataLoader + elif config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: test_dataloader_cls = GSgnnLinkPredictionTestDataLoader elif config.eval_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER: test_dataloader_cls = GSgnnLinkPredictionJointTestDataLoader @@ -166,13 +171,27 @@ def main(config_args): val_dataloader = None test_dataloader = None if len(train_data.val_idxs) > 0: - val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs, - config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout, - fixed_test_size=config.fixed_test_size) + if config.eval_etypes_negative_dstnode is not None: + val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs, + config.eval_batch_size, + fixed_edge_dst_negative_field=config.eval_etypes_negative_dstnode, + fanout=config.eval_fanout, + fixed_test_size=config.fixed_test_size) + else: + val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs, + config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout, + fixed_test_size=config.fixed_test_size) if len(train_data.test_idxs) > 0: - test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs, - config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout, - fixed_test_size=config.fixed_test_size) + if config.eval_etypes_negative_dstnode is not None: + test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs, + config.eval_batch_size, + fixed_edge_dst_negative_field=config.eval_etypes_negative_dstnode, + fanout=config.eval_fanout, + fixed_test_size=config.fixed_test_size) + else: + test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs, + config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout, + fixed_test_size=config.fixed_test_size) # Preparing input layer for training or inference. # The input layer can pre-compute node features in the preparing step if needed. diff --git a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py index 50a2e97acc..3d6a854999 100644 --- a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py +++ b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py @@ -22,8 +22,9 @@ from graphstorm.inference import GSgnnLinkPredictionInferrer from graphstorm.eval import GSgnnMrrLPEvaluator from graphstorm.dataloading import GSgnnEdgeInferData -from graphstorm.dataloading import GSgnnLinkPredictionTestDataLoader -from graphstorm.dataloading import GSgnnLinkPredictionJointTestDataLoader +from graphstorm.dataloading import (GSgnnLinkPredictionTestDataLoader, + GSgnnLinkPredictionJointTestDataLoader, + GSgnnLinkPredictionPredefinedTestDataLoader) from graphstorm.dataloading import BUILTIN_LP_UNIFORM_NEG_SAMPLER from graphstorm.dataloading import BUILTIN_LP_JOINT_NEG_SAMPLER from graphstorm.utils import setup_device, get_lm_ntypes @@ -58,7 +59,9 @@ def main(config_args): tracker = gs.create_builtin_task_tracker(config) infer.setup_task_tracker(tracker) # We only support full-graph inference for now. 
- if config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: + if config.eval_etypes_negative_dstnode is not None: + test_dataloader_cls = GSgnnLinkPredictionPredefinedTestDataLoader + elif config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: test_dataloader_cls = GSgnnLinkPredictionTestDataLoader elif config.eval_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER: test_dataloader_cls = GSgnnLinkPredictionJointTestDataLoader @@ -67,10 +70,16 @@ def main(config_args): 'Supported test negative samplers include ' f'[{BUILTIN_LP_UNIFORM_NEG_SAMPLER}, {BUILTIN_LP_JOINT_NEG_SAMPLER}]') - dataloader = test_dataloader_cls(infer_data, infer_data.test_idxs, - batch_size=config.eval_batch_size, - num_negative_edges=config.num_negative_edges_eval, - fanout=config.eval_fanout) + if config.eval_etypes_negative_dstnode is not None: + dataloader = test_dataloader_cls(infer_data, infer_data.test_idxs, + batch_size=config.eval_batch_size, + fixed_edge_dst_negative_field=config.eval_etypes_negative_dstnode, + fanout=config.eval_fanout) + else: + dataloader = test_dataloader_cls(infer_data, infer_data.test_idxs, + batch_size=config.eval_batch_size, + num_negative_edges=config.num_negative_edges_eval, + fanout=config.eval_fanout) infer.infer(infer_data, dataloader, save_embed_path=config.save_embed_path, edge_mask_for_gnn_embeddings=None if config.no_validation else \ diff --git a/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py b/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py index e196d3fd83..baec5082e6 100644 --- a/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py +++ b/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py @@ -23,8 +23,9 @@ from graphstorm.inference import GSgnnLinkPredictionInferrer from graphstorm.eval import GSgnnMrrLPEvaluator from graphstorm.dataloading import GSgnnEdgeInferData -from graphstorm.dataloading import GSgnnLinkPredictionTestDataLoader -from graphstorm.dataloading import GSgnnLinkPredictionJointTestDataLoader +from graphstorm.dataloading import (GSgnnLinkPredictionTestDataLoader, + GSgnnLinkPredictionJointTestDataLoader, + GSgnnLinkPredictionPredefinedTestDataLoader) from graphstorm.dataloading import BUILTIN_LP_UNIFORM_NEG_SAMPLER from graphstorm.dataloading import BUILTIN_LP_JOINT_NEG_SAMPLER from graphstorm.utils import setup_device @@ -58,7 +59,9 @@ def main(config_args): tracker = gs.create_builtin_task_tracker(config) infer.setup_task_tracker(tracker) # We only support full-graph inference for now. 
- if config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: + if config.eval_etypes_negative_dstnode is not None: + test_dataloader_cls = GSgnnLinkPredictionPredefinedTestDataLoader + elif config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: test_dataloader_cls = GSgnnLinkPredictionTestDataLoader elif config.eval_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER: test_dataloader_cls = GSgnnLinkPredictionJointTestDataLoader @@ -67,9 +70,14 @@ def main(config_args): 'Supported test negative samplers include ' f'[{BUILTIN_LP_UNIFORM_NEG_SAMPLER}, {BUILTIN_LP_JOINT_NEG_SAMPLER}]') - dataloader = test_dataloader_cls(infer_data, infer_data.test_idxs, - batch_size=config.eval_batch_size, - num_negative_edges=config.num_negative_edges_eval) + if config.eval_etypes_negative_dstnode is not None: + dataloader = test_dataloader_cls(infer_data, infer_data.test_idxs, + batch_size=config.eval_batch_size, + fixed_edge_dst_negative_field=config.eval_etypes_negative_dstnode) + else: + dataloader = test_dataloader_cls(infer_data, infer_data.test_idxs, + batch_size=config.eval_batch_size, + num_negative_edges=config.num_negative_edges_eval) # Preparing input layer for training or inference. # The input layer can pre-compute node features in the preparing step if needed. # For example pre-compute all BERT embeddings diff --git a/tests/unit-tests/test_dataloading.py b/tests/unit-tests/test_dataloading.py index 8c7194b37f..f965858115 100644 --- a/tests/unit-tests/test_dataloading.py +++ b/tests/unit-tests/test_dataloading.py @@ -49,12 +49,14 @@ FastGSgnnLPJointNegDataLoader, FastGSgnnLPLocalUniformNegDataLoader, FastGSgnnLPLocalJointNegDataLoader) -from graphstorm.dataloading import GSgnnLinkPredictionTestDataLoader -from graphstorm.dataloading import GSgnnLinkPredictionJointTestDataLoader +from graphstorm.dataloading import (GSgnnLinkPredictionTestDataLoader, + GSgnnLinkPredictionJointTestDataLoader, + GSgnnLinkPredictionPredefinedTestDataLoader) from graphstorm.dataloading import DistillDataloaderGenerator, DistillDataManager from graphstorm.dataloading import DistributedFileSampler -from graphstorm.dataloading import BUILTIN_LP_UNIFORM_NEG_SAMPLER -from graphstorm.dataloading import BUILTIN_LP_JOINT_NEG_SAMPLER +from graphstorm.dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER, + BUILTIN_LP_JOINT_NEG_SAMPLER, + BUILTIN_LP_FIXED_NEG_SAMPLER) from graphstorm.dataloading.sampler import InbatchJointUniform from graphstorm.dataloading.sampler import GlobalUniform @@ -664,6 +666,59 @@ def test_GSgnnLinkPredictionTestDataLoader(batch_size, num_negative_edges): # after test pass, destroy all process group th.distributed.destroy_process_group() +def test_GSgnnLinkPredictionPredefinedTestDataLoader(batch_size): + th.distributed.init_process_group(backend='gloo', + init_method='tcp://127.0.0.1:23456', + rank=0, + world_size=1) + test_etypes = [("n0", "r1", "n1"), ("n0", "r0", "n1")] + with tempfile.TemporaryDirectory() as tmpdirname: + # get the test dummy distributed graph + _, part_config = generate_dummy_dist_graph(graph_name='dummy', dirname=tmpdirname) + lp_data = GSgnnEdgeInferData(graph_name='dummy', part_config=part_config, + eval_etypes=test_etypes) + g = lp_data.g + g.edges[("n0", "r1", "n1")].data["neg"] = th.randint(g.num_nodes("n1"), + (g.num_edges(("n0", "r1", "n1")), 10)) + g.edges[("n0", "r0", "n1")].data["neg"] = th.randint(g.num_nodes("n1"), + (g.num_edges(("n0", "r0", "n1")), 10)) + + dataloader = GSgnnLinkPredictionPredefinedTestDataLoader( + lp_data, + 
target_idx=lp_data.infer_idxs, # use the inference edges as test edges
+            batch_size=batch_size,
+            fixed_edge_dst_negative_field="neg")
+
+        total_edges = {etype: len(lp_data.infer_idxs[etype]) for etype in test_etypes}
+        num_pos_edges = {etype: 0 for etype in test_etypes}
+        for pos_neg_tuple, sample_type in dataloader:
+            assert sample_type == BUILTIN_LP_FIXED_NEG_SAMPLER
+            assert isinstance(pos_neg_tuple, dict)
+            assert len(pos_neg_tuple) == 2
+            for canonical_etype, pos_neg in pos_neg_tuple.items():
+                assert len(pos_neg) == 4
+                pos_src, _, pos_dst, neg_dst = pos_neg
+                assert pos_src.shape == pos_dst.shape
+                assert pos_src.shape[0] == (batch_size
+                    if num_pos_edges[canonical_etype] + batch_size < total_edges[canonical_etype]
+                    else total_edges[canonical_etype] - num_pos_edges[canonical_etype])
+                eid = lp_data.infer_idxs[canonical_etype][num_pos_edges[canonical_etype]: \
+                        num_pos_edges[canonical_etype]+batch_size] \
+                    if num_pos_edges[canonical_etype]+batch_size < total_edges[canonical_etype] \
+                    else lp_data.infer_idxs[canonical_etype] \
+                        [num_pos_edges[canonical_etype]:]
+                src, dst = g.find_edges(eid, etype=canonical_etype)
+                assert_equal(pos_src.numpy(), src.numpy())
+                assert_equal(pos_dst.numpy(), dst.numpy())
+                assert len(neg_dst.shape) == 2
+                assert neg_dst.shape[1] == 10
+                assert_equal(neg_dst.numpy(), g.edges[canonical_etype].data["neg"][eid].numpy())
+
+                num_pos_edges[canonical_etype] += batch_size
+    # after the test passes, destroy all process groups
+    th.distributed.destroy_process_group()
+
 # initialize the torch distributed environment
 @pytest.mark.parametrize("batch_size", [1, 10, 128])
 @pytest.mark.parametrize("num_negative_edges", [1, 16, 128])
@@ -1788,6 +1843,8 @@ def test_inbatch_joint_neg_sampler(num_pos, num_neg):

 if __name__ == '__main__':
+    test_GSgnnLinkPredictionPredefinedTestDataLoader(1)
+    test_GSgnnLinkPredictionPredefinedTestDataLoader(10)
     test_edge_fixed_dst_negative_sample_gen_neg_pairs()
     test_hard_edge_dst_negative_sample_generate_complex_case()
     test_hard_edge_dst_negative_sample_generate()
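
The training-side change can be summarized as follows: when edge_dst_negative_field is given, GSgnnLinkPredictionDataLoader wraps its default negative sampler in GSHardEdgeDstNegativeSampler, so up to num_hard_negs of the num_negative_edges negatives per positive edge are read from a predefined edge feature instead of being sampled; gsgnn_lp.py and gsgnn_lm_lp.py populate these arguments from config.train_etypes_negative_dstnode and config.num_train_hard_negatives. A minimal usage sketch, assuming an existing GSgnnEdgeTrainData named train_data whose target edges carry a hypothetical "hard_neg" feature listing hard-negative destination node IDs:

from graphstorm.dataloading import GSgnnLinkPredictionDataLoader

# Sketch only: `train_data` and the "hard_neg" edge feature are assumed to exist.
train_dataloader = GSgnnLinkPredictionDataLoader(
    train_data,                          # GSgnnEdgeTrainData with link prediction target edges
    train_data.train_idxs,               # target edges per edge type
    fanout=[10, 10],
    batch_size=1024,
    num_negative_edges=16,               # negatives drawn by the base uniform sampler
    train_task=True,
    edge_dst_negative_field="hard_neg",  # hypothetical edge feature holding hard negatives
    num_hard_negs=4)                     # at most 4 of the 16 negatives come from "hard_neg"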
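
For evaluation with predefined negatives, the new GSgnnLinkPredictionPredefinedTestDataLoader skips negative sampling entirely and ranks each test edge against the destination nodes stored in fixed_edge_dst_negative_field (via GSFixedEdgeDstNegativeSampler), reporting BUILTIN_LP_FIXED_NEG_SAMPLER as the sample type; the run scripts select it whenever config.eval_etypes_negative_dstnode is set. A minimal sketch under the same assumptions, with a hypothetical "fixed_neg" feature on the test edges:

from graphstorm.dataloading import GSgnnLinkPredictionPredefinedTestDataLoader

# Sketch only: "fixed_neg" is an assumed edge feature holding the negative
# destination node IDs that each test edge should be ranked against.
test_dataloader = GSgnnLinkPredictionPredefinedTestDataLoader(
    train_data,
    train_data.test_idxs,
    batch_size=1024,
    fixed_edge_dst_negative_field="fixed_neg",
    fanout=[10, 10])

for pos_neg_tuple, neg_sample_type in test_dataloader:
    # neg_sample_type == BUILTIN_LP_FIXED_NEG_SAMPLER; pos_neg_tuple maps each
    # canonical edge type to a 4-tuple whose first, third and fourth entries are
    # the positive sources, positive destinations and fixed negative destinations
    # (see the unit test above).
    pass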