From e248705872019b02a78b6ec577350cb43359e629 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Thu, 28 Dec 2023 16:14:03 +0800 Subject: [PATCH] Revert "fix inference nodes." This reverts commit 0375aceb1a3bafed6943d72e8c9cd0f7661dccc8. --- python/graphstorm/dataloading/dataset.py | 27 ++++++++---------------- 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/python/graphstorm/dataloading/dataset.py b/python/graphstorm/dataloading/dataset.py index 472e6c3897..227ceff8bc 100644 --- a/python/graphstorm/dataloading/dataset.py +++ b/python/graphstorm/dataloading/dataset.py @@ -1133,28 +1133,19 @@ def prepare_data(self, g): # If there are test data globally, we should add them to the dict. if test_idx is not None and dist_sum(len(test_idx)) > 0: test_idxs[ntype] = test_idx - elif test_idx is None and get_rank() == 0: + infer_idxs[ntype] = test_idx + elif test_idx is None: logging.warning("%s does not contains test data, skip testing %s", ntype, ntype) - - if 'infer_mask' in g.nodes[ntype].data: - infer_idx = dgl.distributed.node_split(g.nodes[ntype].data['infer_mask'], - pb, ntype=ntype, force_even=True, - node_trainer_ids=node_trainer_ids) - # If there are inference data globally, we should add them to the dict. - if infer_idx is not None and dist_sum(len(infer_idx)) > 0: - infer_idxs[ntype] = infer_idx - elif infer_idx is None and get_rank() == 0: - logging.warning("%s does not contains inference data.", ntype) else: - # If 'infer_mask' is not specified, we will do inference on the entire edge set. - if get_rank() == 0: - logging.debug("%s doesn't have infer_mask. We run inference on all nodes.", - ntype) + # Inference only + # we will do inference on the entire edge set + logging.info("%s does not contains test_mask, skip testing %s. " + \ + "We will do inference on the entire node set.", ntype, ntype) infer_idx = dgl.distributed.node_split( - th.full((g.num_nodes(ntype),), True, dtype=th.bool), - pb, ntype=ntype, force_even=True, - node_trainer_ids=node_trainer_ids) + th.full((g.num_nodes(ntype),), True, dtype=th.bool), + pb, ntype=ntype, force_even=True, + node_trainer_ids=node_trainer_ids) infer_idxs[ntype] = infer_idx self._test_idxs = test_idxs self._infer_idxs = infer_idxs