diff --git a/docs/source/configuration/configuration-run.rst b/docs/source/configuration/configuration-run.rst index 7bfd3c13fc..2fd19f76ae 100644 --- a/docs/source/configuration/configuration-run.rst +++ b/docs/source/configuration/configuration-run.rst @@ -287,6 +287,11 @@ GraphStorm provides a set of parameters to control model evaluation. - Yaml: ``no_validation: true`` - Argument: ``--no-validation true`` - Default value: ``false`` +- **fixed_test_size**: Set the number of validation and test data used during link prediction training evaluaiotn. This is useful for reducing the overhead of doing link prediction evaluation when the graph size is large. + + - Yaml: ``fixed_test_size: 100000`` + - Argument: ``--fixed-test-size 100000`` + - Default value: None, Use the full validation and test set. Language Model Specific Configurations --------------------------------------------------- diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index 3403e7d18c..cf28e2941e 100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -264,6 +264,7 @@ def verify_arguments(self, is_train): _ = self.decoder_edge_feat # Evaluation + _ = self.fixed_test_size _ = self.eval_fanout _ = self.use_mini_batch_infer _ = self.eval_batch_size @@ -801,6 +802,25 @@ def eval_fanout(self): # By default use -1 as full neighbor return [-1] * self.num_layers + @property + def fixed_test_size(self): + """ Fixed number of test data used in evaluation + + This is useful for reducing the overhead of doing link prediction evaluation. + + TODO: support fixed_test_size in + node prediction and edge prediction tasks. + """ + # pylint: disable=no-member + if hasattr(self, "_fixed_test_size"): + assert self._fixed_test_size > 0, \ + "fixed_test_size must be larger than 0" + return self._fixed_test_size + + # Use the full test set + return None + + @property def textual_data_path(self): """ distillation textual data path @@ -2339,6 +2359,8 @@ def _add_link_prediction_args(parser): help="Link prediction decoder type.") group.add_argument("--num-negative-edges", type=int, default=argparse.SUPPRESS, help="Number of edges consider for the negative batch of edges.") + group.add_argument("--fixed-test-size", type=int, default=argparse.SUPPRESS, + help="Fixed number of test data used in evaluation.") group.add_argument("--num-negative-edges-eval", type=int, default=argparse.SUPPRESS, help="Number of edges consider for the negative " "batch of edges for the model evaluation. " diff --git a/python/graphstorm/dataloading/dataloading.py b/python/graphstorm/dataloading/dataloading.py index 21c6c19de2..d56dec8227 100644 --- a/python/graphstorm/dataloading/dataloading.py +++ b/python/graphstorm/dataloading/dataloading.py @@ -965,8 +965,15 @@ class GSgnnLinkPredictionTestDataLoader(): The number of negative edges per positive edge fanout: int Evaluation fanout for computing node embedding + fixed_test_size: int + Fixed number of test data used in evaluation. + If it is none, use the whole testset. + When test is huge, using fixed_test_size + can save validation and test time. + Default: None. """ - def __init__(self, dataset, target_idx, batch_size, num_negative_edges, fanout=None): + def __init__(self, dataset, target_idx, batch_size, num_negative_edges, fanout=None, + fixed_test_size=None): self._data = dataset self._fanout = fanout for etype in target_idx: @@ -974,6 +981,15 @@ def __init__(self, dataset, target_idx, batch_size, num_negative_edges, fanout=N "edge type {} does not exist in the graph".format(etype) self._batch_size = batch_size self._target_idx = target_idx + self._fixed_test_size = {} + for etype, t_idx in target_idx.items(): + self._fixed_test_size[etype] = fixed_test_size \ + if fixed_test_size is not None else len(t_idx) + if self._fixed_test_size[etype] > len(t_idx): + logging.warning("The size of the test set of etype %s" \ + "is %d, which is smaller than the expected" + "test size %d, force it to %d", + etype, len(t_idx), self._fixed_test_size[etype], len(t_idx)) self._negative_sampler = self._prepare_negative_sampler(num_negative_edges) self._reinit_dataset() @@ -982,6 +998,11 @@ def _reinit_dataset(self): """ self._current_pos = {etype:0 for etype, _ in self._target_idx.items()} self.remaining_etypes = list(self._target_idx.keys()) + for etype, t_idx in self._target_idx.items(): + # If the expected test size is smaller than the size of test set + # shuffle the test ids + if self._fixed_test_size[etype] < len(t_idx): + self._target_idx[etype] = self._target_idx[etype][th.randperm(len(t_idx))] def _prepare_negative_sampler(self, num_negative_edges): # the default negative sampler is uniform sampler @@ -998,8 +1019,10 @@ def _next_data(self, etype): """ g = self._data.g current_pos = self._current_pos[etype] - end_of_etype = current_pos + self._batch_size >= len(self._target_idx[etype]) - pos_eids = self._target_idx[etype][current_pos:] if end_of_etype \ + end_of_etype = current_pos + self._batch_size >= self._fixed_test_size[etype] + + pos_eids = self._target_idx[etype][current_pos:self._fixed_test_size[etype]] \ + if end_of_etype \ else self._target_idx[etype][current_pos:current_pos+self._batch_size] pos_pairs = g.find_edges(pos_eids, etype=etype) pos_neg_tuple = self._negative_sampler.gen_neg_pairs(g, {etype:pos_pairs}) diff --git a/python/graphstorm/dataloading/sampler.py b/python/graphstorm/dataloading/sampler.py index 095a288a2a..53b9d46794 100644 --- a/python/graphstorm/dataloading/sampler.py +++ b/python/graphstorm/dataloading/sampler.py @@ -409,7 +409,8 @@ def sample_blocks(self, g, seed_nodes, exclude_eids=None): output_device=self.output_device, exclude_edges=exclude_eids, ) - eid = frontier.edata[EID] + eid = {etype: frontier.edges[etype].data[EID] \ + for etype in frontier.canonical_etypes} new_eid = dict(eid) if self.mask is not None: new_edges = {} @@ -436,7 +437,10 @@ def sample_blocks(self, g, seed_nodes, exclude_eids=None): else: new_frontier = frontier block = to_block(new_frontier, seed_nodes) - block.edata[EID] = new_eid + # When there is only one etype + # we can not use block.edata[EID] = new_eid + for etype in block.canonical_etypes: + block.edges[etype].data[EID] = new_eid[etype] seed_nodes = block.srcdata[NID] blocks.insert(0, block) diff --git a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py index 11f14d2575..633af096b7 100644 --- a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py +++ b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py @@ -165,10 +165,12 @@ def main(config_args): test_dataloader = None if len(train_data.val_idxs) > 0: val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs, - config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout) + config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout, + fixed_test_size=config.fixed_test_size) if len(train_data.test_idxs) > 0: test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs, - config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout) + config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout, + fixed_test_size=config.fixed_test_size) # Preparing input layer for training or inference. # The input layer can pre-compute node features in the preparing step if needed. diff --git a/tests/unit-tests/test_dataloading.py b/tests/unit-tests/test_dataloading.py index 7108fe560d..18613e2314 100644 --- a/tests/unit-tests/test_dataloading.py +++ b/tests/unit-tests/test_dataloading.py @@ -641,6 +641,23 @@ def test_GSgnnLinkPredictionTestDataLoader(batch_size, num_negative_edges): assert neg_src.shape[1] == num_negative_edges assert th.all(neg_src < g.number_of_nodes(canonical_etype[0])) + fixed_test_size = 10 + dataloader = GSgnnLinkPredictionTestDataLoader( + lp_data, + target_idx=lp_data.train_idxs, # use train edges as val or test edges + batch_size=batch_size, + num_negative_edges=num_negative_edges,fixed_test_size=fixed_test_size) + num_samples = 0 + for pos_neg_tuple, sample_type in dataloader: + num_samples += 1 + assert isinstance(pos_neg_tuple, dict) + assert len(pos_neg_tuple) == 1 + for _, pos_neg in pos_neg_tuple.items(): + pos_src, _, pos_dst, _ = pos_neg + assert len(pos_src) <= batch_size + + assert num_samples == -(-fixed_test_size // batch_size) * 2 + # after test pass, destroy all process group th.distributed.destroy_process_group()