diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index a5d58c88ce..17e8e92e96 100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -1874,9 +1874,129 @@ def lp_edge_weight_for_loss(self): return None + def _get_predefined_negatives_per_etype(self, negatives): + if len(negatives) == 1 and \ + ":" not in negatives[0]: + # global feat_name + return negatives[0] + + # per edge type feature + negative_dict = {} + for negative in negatives: + negative_info = negative.split(":") + assert len(negative_info) == 2, \ + "negative dstnode information must be provided in format of " \ + f"src,relation,dst:feature_name, but get {negative}" + + etype = tuple(negative_info[0].split(",")) + assert len(etype) == 3, \ + f"Edge type must in format of (src,relation,dst), but get {etype}" + assert etype not in negative_dict, \ + f"You already specify the fixed negative of {etype} " \ + f"as {negative_dict[etype]}" + + negative_dict[etype] = negative_info[1] + return negative_dict + + @property + def train_etypes_negative_dstnode(self): + """ The list of canonical etypes that have hard negative edges + constructed by corrupting destination nodes. + + The format of the arguement should be: + train_etypes_negative_dstnode: + - src_type,rel_type0,dst_type:negative_nid_field + - src_type,rel_type1,dst_type:negative_nid_field + Each edge type can have different fields storing the hard negatives. + + or + train_etypes_negative_dstnode: + - negative_nid_field + All the edge types use the same filed storing the hard negatives. + """ + # pylint: disable=no-member + if hasattr(self, "_train_etypes_negative_dstnode"): + assert self.task_type == BUILTIN_TASK_LINK_PREDICTION, \ + "Hard negative only works with link prediction" + hard_negatives = self._train_etypes_negative_dstnode + return self._get_predefined_negatives_per_etype(hard_negatives) + + # By default fixed negative is not used + return None + + @property + def num_train_hard_negatives(self): + """ Number of hard negatives per edge type + + The format of the arguement should be: + num_train_hard_negatives: + - src_type,rel_type0,dst_type:num_negatives + - src_type,rel_type1,dst_type:num_negatives + Each edge type can have different number of hard negatives. + + or + num_train_hard_negatives: + - num_negatives + All the edge types use the same number of hard negatives. + """ + # pylint: disable=no-member + if hasattr(self, "_num_train_hard_negatives"): + assert self.task_type == BUILTIN_TASK_LINK_PREDICTION, \ + "Hard negative only works with link prediction" + num_negatives = self._num_train_hard_negatives + if len(num_negatives) == 1 and \ + ":" not in num_negatives[0]: + # global feat_name + return int(num_negatives[0]) + + # per edge type feature + num_hard_negative_dict = {} + for num_negative in num_negatives: + negative_info = num_negative.split(":") + assert len(negative_info) == 2, \ + "Number of train hard negative information must be provided in format of " \ + f"src,relation,dst:10, but get {num_negative}" + etype = tuple(negative_info[0].split(",")) + assert len(etype) == 3, \ + f"Edge type must in format of (src,relation,dst), but get {etype}" + assert etype not in num_hard_negative_dict, \ + f"You already specify the fixed negative of {etype} " \ + f"as {num_hard_negative_dict[etype]}" + + num_hard_negative_dict[etype] = int(negative_info[1]) + return num_hard_negative_dict + + return None + + @property + def eval_etypes_negative_dstnode(self): + """ The list of canonical etypes that have predefined negative edges + constructed by corrupting destination nodes. + + The format of the arguement should be: + eval_etypes_negative_dstnode: + - src_type,rel_type0,dst_type:negative_nid_field + - src_type,rel_type1,dst_type:negative_nid_field + Each edge type can have different fields storing the fixed negatives. + + or + eval_etypes_negative_dstnode: + - negative_nid_field + All the edge types use the same filed storing the fixed negatives. + """ + # pylint: disable=no-member + if hasattr(self, "_eval_etypes_negative_dstnode"): + assert self.task_type == BUILTIN_TASK_LINK_PREDICTION, \ + "Fixed negative only works with link prediction" + fixed_negatives = self._eval_etypes_negative_dstnode + return self._get_predefined_negatives_per_etype(fixed_negatives) + + # By default fixed negative is not used + return None + @property def train_etype(self): - """ The list of canonical etype that will be added as + """ The list of canonical etypes that will be added as training target with the target e type(s) If not provided, all edge types will be used as training target. @@ -2480,6 +2600,36 @@ def _add_link_prediction_args(parser): "metrics of each edge type to select the best model" "2) '--model-select-etype query,adds,item': Use the evaluation " "metric of the edge type (query,adds,item) to select the best model") + group.add_argument("--train-etypes-negative-dstnode", nargs='+', + type=str, default=argparse.SUPPRESS, + help="Edge feature field name for user defined negative destination ndoes " + "for training. The negative nodes are used to construct hard negative edges " + "by corrupting positive edges' destination nodes." + "It can be in following format: " + "1) '--train-etypes-negative-dstnode negative_nid_field', " + "if all edge types use the same negative destination node filed." + "2) '--train-etypes-negative-dstnode query,adds,asin:neg0 query,clicks,asin:neg1 ...'" + "Different edge types have different negative destination node fields." + ) + group.add_argument("--eval-etypes-negative-dstnode", nargs='+', + type=str, default=argparse.SUPPRESS, + help="Edge feature field name for user defined negative destination ndoes " + "for evaluation. The negative nodes are used to construct negative edges " + "by corrupting test edges' destination nodes." + "It can be in following format: " + "1) '--eval-etypes-negative-dstnode negative_nid_field', " + "if all edge types use the same negative destination node filed." + "2) '--eval-etypes-negative-dstnode query,adds,asin:neg0 query,clicks,asin:neg1 ...'" + "Different edge types have different negative destination node fields." + ) + group.add_argument("--num-train-hard-negatives", nargs='+', + type=str, default=argparse.SUPPRESS, + help="Number of hard negatives for each edge type during training." + "It can be in following format: " + "1) '--num-train-hard-negatives 10', " + "if all edge types use the same number of hard negatives." + "2) '--num-train-hard-negatives query,adds,asin:5 query,clicks,asin:10 ...'" + "Different edge types have different number of hard negatives.") return parser diff --git a/python/graphstorm/dataloading/__init__.py b/python/graphstorm/dataloading/__init__.py index 1aed57bb12..20abce548d 100644 --- a/python/graphstorm/dataloading/__init__.py +++ b/python/graphstorm/dataloading/__init__.py @@ -24,8 +24,9 @@ from .dataloading import GSgnnAllEtypeLinkPredictionDataLoader from .dataloading import GSgnnEdgeDataLoader from .dataloading import GSgnnNodeDataLoader, GSgnnNodeSemiSupDataLoader -from .dataloading import GSgnnLinkPredictionTestDataLoader -from .dataloading import GSgnnLinkPredictionJointTestDataLoader +from .dataloading import (GSgnnLinkPredictionTestDataLoader, + GSgnnLinkPredictionJointTestDataLoader, + GSgnnLinkPredictionPredefinedTestDataLoader) from .dataloading import (FastGSgnnLinkPredictionDataLoader, FastGSgnnLPLocalJointNegDataLoader, FastGSgnnLPJointNegDataLoader, @@ -43,7 +44,8 @@ BUILTIN_LP_JOINT_NEG_SAMPLER, BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER, BUILTIN_LP_LOCALUNIFORM_NEG_SAMPLER, - BUILTIN_LP_LOCALJOINT_NEG_SAMPLER) + BUILTIN_LP_LOCALJOINT_NEG_SAMPLER, + BUILTIN_LP_FIXED_NEG_SAMPLER) from .dataloading import BUILTIN_LP_ALL_ETYPE_UNIFORM_NEG_SAMPLER from .dataloading import BUILTIN_LP_ALL_ETYPE_JOINT_NEG_SAMPLER from .dataloading import (BUILTIN_FAST_LP_UNIFORM_NEG_SAMPLER, diff --git a/python/graphstorm/dataloading/dataloading.py b/python/graphstorm/dataloading/dataloading.py index f367219dde..fa22e5e9f4 100644 --- a/python/graphstorm/dataloading/dataloading.py +++ b/python/graphstorm/dataloading/dataloading.py @@ -33,7 +33,9 @@ JointLocalUniform, InbatchJointUniform, FastMultiLayerNeighborSampler, - DistributedFileSampler) + DistributedFileSampler, + GSHardEdgeDstNegativeSampler, + GSFixedEdgeDstNegativeSampler) from .utils import trim_data, modify_fanout_for_target_etype from .dataset import GSDistillData @@ -368,6 +370,7 @@ def fanout(self): BUILTIN_FAST_LP_JOINT_NEG_SAMPLER = 'fast_joint' BUILTIN_FAST_LP_LOCALUNIFORM_NEG_SAMPLER = 'fast_localuniform' BUILTIN_FAST_LP_LOCALJOINT_NEG_SAMPLER = 'fast_localjoint' +BUILTIN_LP_FIXED_NEG_SAMPLER = 'fixed' class GSgnnLinkPredictionDataLoaderBase(): """ The base class of link prediction dataloader. @@ -485,6 +488,10 @@ class GSgnnLinkPredictionDataLoader(GSgnnLinkPredictionDataLoaderBase): The node types that requires to construct node features. construct_feat_fanout : int The fanout required to construct node features. + edge_dst_negative_field: str or dict of str + The feature field(s) that store the hard negative edges for each edge type. + num_hard_negs: int or dict of int + The number of hard negatives per positive edge for each edge type Examples ------------ @@ -507,7 +514,9 @@ class GSgnnLinkPredictionDataLoader(GSgnnLinkPredictionDataLoaderBase): def __init__(self, dataset, target_idx, fanout, batch_size, num_negative_edges, device='cpu', train_task=True, reverse_edge_types_map=None, exclude_training_targets=False, edge_mask_for_gnn_embeddings='train_mask', - construct_feat_ntype=None, construct_feat_fanout=5): + construct_feat_ntype=None, construct_feat_fanout=5, + edge_dst_negative_field=None, + num_hard_negs=None): super().__init__(dataset, target_idx, fanout) self._device = device for etype in target_idx: @@ -521,7 +530,9 @@ def __init__(self, dataset, target_idx, fanout, batch_size, num_negative_edges, reverse_edge_types_map=reverse_edge_types_map, edge_mask_for_gnn_embeddings=edge_mask_for_gnn_embeddings, construct_feat_ntype=construct_feat_ntype, - construct_feat_fanout=construct_feat_fanout) + construct_feat_fanout=construct_feat_fanout, + edge_dst_negative_field=edge_dst_negative_field, + num_hard_negs=num_hard_negs) def _prepare_negative_sampler(self, num_negative_edges): # the default negative sampler is uniform sampler @@ -532,7 +543,8 @@ def _prepare_dataloader(self, dataset, target_idxs, fanout, num_negative_edges, batch_size, device, train_task=True, exclude_training_targets=False, reverse_edge_types_map=None, edge_mask_for_gnn_embeddings=None, construct_feat_ntype=None, - construct_feat_fanout=5): + construct_feat_fanout=5, edge_dst_negative_field=None, + num_hard_negs=None): g = dataset.g if construct_feat_ntype is None: construct_feat_ntype = [] @@ -553,6 +565,11 @@ def _prepare_dataloader(self, dataset, target_idxs, fanout, sampler = MultiLayerNeighborSamplerForReconstruct(sampler, dataset, construct_feat_ntype, construct_feat_fanout) negative_sampler = self._prepare_negative_sampler(num_negative_edges) + if edge_dst_negative_field is not None: + negative_sampler = GSHardEdgeDstNegativeSampler(num_negative_edges, + edge_dst_negative_field, + negative_sampler, + num_hard_negs) # edge loader if train_task: @@ -639,7 +656,8 @@ def _prepare_dataloader(self, dataset, target_idxs, fanout, num_negative_edges, batch_size, device, train_task=True, exclude_training_targets=False, reverse_edge_types_map=None, edge_mask_for_gnn_embeddings=None, construct_feat_ntype=None, - construct_feat_fanout=5): + construct_feat_fanout=5, edge_dst_negative_field=None, + num_hard_negs=None): g = dataset.g if construct_feat_ntype is None: construct_feat_ntype = [] @@ -660,6 +678,11 @@ def _prepare_dataloader(self, dataset, target_idxs, fanout, sampler = MultiLayerNeighborSamplerForReconstruct(sampler, dataset, construct_feat_ntype, construct_feat_fanout) negative_sampler = self._prepare_negative_sampler(num_negative_edges) + if edge_dst_negative_field is not None: + negative_sampler = GSHardEdgeDstNegativeSampler(num_negative_edges, + edge_dst_negative_field, + negative_sampler, + num_hard_negs) # edge loader if train_task: @@ -881,7 +904,9 @@ def _prepare_dataloader(self, dataset, target_idxs, fanout, num_negative_edges, reverse_edge_types_map=None, edge_mask_for_gnn_embeddings=None, construct_feat_ntype=None, - construct_feat_fanout=5): + construct_feat_fanout=5, + edge_dst_negative_field=None, + num_hard_negs=None): g = dataset.g if construct_feat_ntype is None: construct_feat_ntype = [] @@ -898,6 +923,12 @@ def _prepare_dataloader(self, dataset, target_idxs, fanout, num_negative_edges, dataset, construct_feat_ntype, construct_feat_fanout) negative_sampler = self._prepare_negative_sampler(num_negative_edges) + if edge_dst_negative_field is not None: + negative_sampler = GSHardEdgeDstNegativeSampler(num_negative_edges, + edge_dst_negative_field, + negative_sampler, + num_hard_negs) + # edge loader if train_task: if isinstance(target_idxs, dict): @@ -973,8 +1004,8 @@ class GSgnnLinkPredictionTestDataLoader(): can save validation and test time. Default: None. """ - def __init__(self, dataset, target_idx, batch_size, num_negative_edges, fanout=None, - fixed_test_size=None): + def __init__(self, dataset, target_idx, batch_size, num_negative_edges, + fanout=None, fixed_test_size=None): self._data = dataset self._fanout = fanout for etype in target_idx: @@ -1059,6 +1090,55 @@ def _prepare_negative_sampler(self, num_negative_edges): self._neg_sample_type = BUILTIN_LP_JOINT_NEG_SAMPLER return negative_sampler +class GSgnnLinkPredictionPredefinedTestDataLoader(GSgnnLinkPredictionTestDataLoader): + """ Link prediction minibatch dataloader for validation and test + with predefined negatives. + + Parameters + ----------- + dataset: GSgnnEdgeData + The GraphStorm edge dataset + target_idx : dict of Tensors + The target edges for prediction + batch_size: int + Batch size + fanout: int + Evaluation fanout for computing node embedding + fixed_test_size: int + Fixed number of test data used in evaluation. + If it is none, use the whole testset. + When test is huge, using fixed_test_size + can save validation and test time. + Default: None. + fixed_edge_dst_negative_field: str or list of str + The feature field(s) that store the fixed negative set for each edge. + """ + def __init__(self, dataset, target_idx, batch_size, fixed_edge_dst_negative_field, + fanout=None, fixed_test_size=None): + self._fixed_edge_dst_negative_field = fixed_edge_dst_negative_field + super().__init__(dataset, target_idx, batch_size, + num_negative_edges=0, # num_negative_edges is not used + fanout=fanout, fixed_test_size=fixed_test_size) + + def _prepare_negative_sampler(self, _): + negative_sampler = GSFixedEdgeDstNegativeSampler(self._fixed_edge_dst_negative_field) + self._neg_sample_type = BUILTIN_LP_FIXED_NEG_SAMPLER + return negative_sampler + + def _next_data(self, etype): + """ Get postive edges for the next iteration for a specific edge type + """ + g = self._data.g + current_pos = self._current_pos[etype] + end_of_etype = current_pos + self._batch_size >= self._fixed_test_size[etype] + + pos_eids = self._target_idx[etype][current_pos:self._fixed_test_size[etype]] \ + if end_of_etype \ + else self._target_idx[etype][current_pos:current_pos+self._batch_size] + pos_neg_tuple = self._negative_sampler.gen_etype_neg_pairs(g, etype, pos_eids) + self._current_pos[etype] += self._batch_size + return pos_neg_tuple, end_of_etype + ################ Minibatch DataLoader (Node classification) ####################### class GSgnnNodeDataLoaderBase(): diff --git a/python/graphstorm/dataloading/sampler.py b/python/graphstorm/dataloading/sampler.py index 53b9d46794..a25170ce00 100644 --- a/python/graphstorm/dataloading/sampler.py +++ b/python/graphstorm/dataloading/sampler.py @@ -24,10 +24,13 @@ from dgl import backend as F from dgl import EID, NID from dgl.distributed import node_split -from dgl.dataloading.negative_sampler import Uniform +from dgl.dataloading.negative_sampler import (Uniform, + _BaseNegativeSampler) from dgl.dataloading import NeighborSampler from dgl.transforms import to_block +from ..utils import is_wholegraph + class LocalUniform(Uniform): """Negative sampler that randomly chooses negative destination nodes for each source node according to a uniform distribution. @@ -70,6 +73,180 @@ def _generate(self, g, eids, canonical_etype): dst = F.randint(shape, dtype, ctx, 0, self._local_neg_nids[vtype].shape[0]) return src, self._local_neg_nids[vtype][dst] +class GSHardEdgeDstNegativeSampler(_BaseNegativeSampler): + """ GraphStorm negative sampler that chooses negative destination nodes + from a fixed set to create negative edges. + + Parameters + ---------- + k: int + Number of negatives to sample. + dst_negative_field: str or dict of str + The field storing the hard negatives. + negative_sampler: sampler + The negative sampler to generate negatives + if there is not enough hard negatives. + num_hard_negs: int or dict of int + Number of hard negatives. + """ + def __init__(self, k, dst_negative_field, negative_sampler, num_hard_negs=None): + assert is_wholegraph() is False, \ + "Hard negative is not supported for WholeGraph." + self._dst_negative_field = dst_negative_field + self._k = k + self._negative_sampler = negative_sampler + self._num_hard_negs = num_hard_negs + + def _generate(self, g, eids, canonical_etype): + """ _generate() is called by DGL BaseNegativeSampler to generate negative pairs. + + See https://github.com/dmlc/dgl/blob/1.1.x/python/dgl/dataloading/negative_sampler.py#L7 + For more detials + """ + if isinstance(self._dst_negative_field, str): + dst_negative_field = self._dst_negative_field + elif canonical_etype in self._dst_negative_field: + dst_negative_field = self._dst_negative_field[canonical_etype] + else: + dst_negative_field = None + + if isinstance(self._num_hard_negs, int): + required_num_hard_neg = self._num_hard_negs + elif canonical_etype in self._num_hard_negs: + required_num_hard_neg = self._num_hard_negs[canonical_etype] + else: + required_num_hard_neg = 0 + + if dst_negative_field is None or required_num_hard_neg == 0: + # no hard negative, fallback to random negative + return self._negative_sampler._generate(g, eids, canonical_etype) + + hard_negatives = g.edges[canonical_etype].data[dst_negative_field][eids] + # It is possible that different edges may have different number of + # pre-defined negatives. For pre-defined negatives, the corresponding + # value in `hard_negatives` will be integers representing the node ids. + # For others, they will be -1s meaning there are missing fixed negatives. + if th.sum(hard_negatives == -1) == 0: + # Fast track, there is no -1 in hard_negatives + max_num_hard_neg = hard_negatives.shape[1] + neg_idx = th.randperm(max_num_hard_neg) + # shuffle the hard negatives + hard_negatives = hard_negatives[:,neg_idx] + + if required_num_hard_neg >= self._k and max_num_hard_neg >= self._k: + # All negative should be hard negative and + # there are enough hard negatives. + hard_negatives = hard_negatives[:,:self._k] + src, _ = g.find_edges(eids, etype=canonical_etype) + src = F.repeat(src, self._k, 0) + return src, hard_negatives.reshape((-1,)) + else: + if required_num_hard_neg < max_num_hard_neg: + # Only need required_num_hard_neg hard negatives. + hard_negatives = hard_negatives[:,:required_num_hard_neg] + num_hard_neg = required_num_hard_neg + else: + # There is not enough hard negative to fill required_num_hard_neg + num_hard_neg = max_num_hard_neg + + # There is not enough negatives + src, neg = self._negative_sampler._generate(g, eids, canonical_etype) + # replace random negatives with fixed negatives + neg = neg.reshape(-1, self._k) + neg[:,:num_hard_neg] = hard_negatives[:,:num_hard_neg] + return src, neg.reshape((-1,)) + else: + # slow track, we need to handle cases when there are -1s + hard_negatives, _ = th.sort(hard_negatives, dim=1, descending=True) + + src, neg = self._negative_sampler._generate(g, eids, canonical_etype) + for i in range(len(eids)): + hard_negative = hard_negatives[i] + # ignore -1s + hard_negative = hard_negative[hard_negative > -1] + max_num_hard_neg = hard_negative.shape[0] + hard_negative = hard_negative[th.randperm(max_num_hard_neg)] + + if required_num_hard_neg < max_num_hard_neg: + # Only need required_num_hard_neg hard negatives. + hard_negative = hard_negative[:required_num_hard_neg] + num_hard_neg = required_num_hard_neg + else: + num_hard_neg = max_num_hard_neg + + # replace random negatives with fixed negatives + neg[i*self._k:i*self._k + (num_hard_neg \ + if num_hard_neg < self._k else self._k)] = \ + hard_negative[:num_hard_neg if num_hard_neg < self._k else self._k] + return src, neg + +class GSFixedEdgeDstNegativeSampler(object): + """ GraphStorm negative sampler that uses fixed negative destination nodes + to create negative edges. + + GSFixedEdgeDstNegativeSampler only works with test dataloader. + + Parameters + ---------- + dst_negative_field: str or dict of str + The field storing the hard negatives. + """ + def __init__(self, dst_negative_field): + assert is_wholegraph() is False, \ + "Hard negative is not supported for WholeGraph." + self._dst_negative_field = dst_negative_field + + def gen_etype_neg_pairs(self, g, etype, pos_eids): + """ Returns negative examples associated with positive examples. + It only return dst negatives. + + This function is called by GSgnnLinkPredictionTestDataLoader._next_data() + to generate testing edges. + + Parameters + ---------- + g : DGLGraph + The graph. + pos_eids : (Tensor, Tensor) or dict[etype, (Tensor, Tensor)] + The positive edge ids. + + Returns + ------- + dict[etype, tuple(Tensor, Tensor Tensor, Tensor) + The returned [positive source, negative source, + postive destination, negatve destination] + tuples as pos-neg examples. + """ + def _gen_neg_pair(eids, canonical_etype): + src, pos_dst = g.find_edges(eids, etype=canonical_etype) + + if isinstance(self._dst_negative_field, str): + dst_negative_field = self._dst_negative_field + elif canonical_etype in self._dst_negative_field: + dst_negative_field = self._dst_negative_field[canonical_etype] + else: + raise RuntimeError(f"{etype} does not have pre-defined negatives") + + fixed_negatives = g.edges[canonical_etype].data[dst_negative_field][eids] + + # Users may use HardEdgeDstNegativeTransform + # to prepare the fixed negatives. + assert th.sum(fixed_negatives == -1) == 0, \ + "When using fixed negative destination nodes to construct testing edges," \ + "it is required that for each positive edge there are enough negative " \ + f"destination nodes. Please check the {dst_negative_field} feature " \ + f"of edge type {canonical_etype}" + + num_fixed_neg = fixed_negatives.shape[1] + logging.debug("The number of fixed negative is %d", num_fixed_neg) + return (src, None, pos_dst, fixed_negatives) + + assert etype in g.canonical_etypes, \ + f"Edge type {etype} does not exist in graph. Expecting an edge type in " \ + f"{g.canonical_etypes}, but get {etype}" + + return {etype: _gen_neg_pair(pos_eids, etype)} + class GlobalUniform(Uniform): """Negative sampler that randomly chooses negative destination nodes for each source node according to a uniform distribution. diff --git a/python/graphstorm/gconstruct/transform.py b/python/graphstorm/gconstruct/transform.py index 300922f027..0a2b82e622 100644 --- a/python/graphstorm/gconstruct/transform.py +++ b/python/graphstorm/gconstruct/transform.py @@ -1075,7 +1075,7 @@ class HardEdgeDstNegativeTransform(HardEdgeNegativeTransform): """ def set_target_etype(self, etype): - self._target_etype = etype + self._target_etype = tuple(etype) # target node type is destination node type. self._target_ntype = etype[2] diff --git a/python/graphstorm/model/edge_decoder.py b/python/graphstorm/model/edge_decoder.py index 427d15336c..b35b5f43c5 100644 --- a/python/graphstorm/model/edge_decoder.py +++ b/python/graphstorm/model/edge_decoder.py @@ -24,7 +24,8 @@ from .ngnn_mlp import NGNNMLP from .gs_layer import GSLayer, GSLayerNoParam from ..dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER, - BUILTIN_LP_JOINT_NEG_SAMPLER) + BUILTIN_LP_JOINT_NEG_SAMPLER, + BUILTIN_LP_FIXED_NEG_SAMPLER) from ..eval.utils import calc_distmult_pos_score, calc_dot_pos_score from ..eval.utils import calc_distmult_neg_head_score, calc_distmult_neg_tail_score @@ -628,7 +629,9 @@ def calc_test_scores(self, emb, pos_neg_tuple, neg_sample_type, device): neg_scores = [] if neg_src is not None: neg_src_emb = emb[utype][neg_src.reshape(-1,)].to(device) - if neg_sample_type == BUILTIN_LP_UNIFORM_NEG_SAMPLER: + if neg_sample_type in [BUILTIN_LP_UNIFORM_NEG_SAMPLER, + BUILTIN_LP_FIXED_NEG_SAMPLER]: + # fixed negative sample is similar to uniform negative sample neg_src_emb = neg_src_emb.reshape( neg_src.shape[0], neg_src.shape[1], -1) pos_dst_emb = pos_dst_emb.reshape( @@ -654,7 +657,9 @@ def calc_test_scores(self, emb, pos_neg_tuple, neg_sample_type, device): neg_scores.append(neg_score) if neg_dst is not None: - if neg_sample_type == BUILTIN_LP_UNIFORM_NEG_SAMPLER: + if neg_sample_type in [BUILTIN_LP_UNIFORM_NEG_SAMPLER, \ + BUILTIN_LP_FIXED_NEG_SAMPLER]: + # fixed negative sample is similar to uniform negative sample neg_dst_emb = emb[vtype][neg_dst.reshape(-1,)].to(device) neg_dst_emb = neg_dst_emb.reshape( neg_dst.shape[0], neg_dst.shape[1], -1) @@ -881,7 +886,9 @@ def calc_test_scores(self, emb, pos_neg_tuple, neg_sample_type, device): if neg_src is not None: neg_src_emb = emb[utype][neg_src.reshape(-1,)] - if neg_sample_type == BUILTIN_LP_UNIFORM_NEG_SAMPLER: + if neg_sample_type in [BUILTIN_LP_UNIFORM_NEG_SAMPLER, + BUILTIN_LP_FIXED_NEG_SAMPLER]: + # fixed negative sample is similar to uniform negative sample neg_src_emb = neg_src_emb.reshape(neg_src.shape[0], neg_src.shape[1], -1) # uniform sampled negative samples pos_dst_emb = pos_dst_emb.reshape( @@ -912,7 +919,9 @@ def calc_test_scores(self, emb, pos_neg_tuple, neg_sample_type, device): neg_scores.append(neg_score) if neg_dst is not None: - if neg_sample_type == BUILTIN_LP_UNIFORM_NEG_SAMPLER: + if neg_sample_type in [BUILTIN_LP_UNIFORM_NEG_SAMPLER, + BUILTIN_LP_FIXED_NEG_SAMPLER]: + # fixed negative sample is similar to uniform negative sample neg_dst_emb = emb[vtype][neg_dst.reshape(-1,)] neg_dst_emb = neg_dst_emb.reshape(neg_dst.shape[0], neg_dst.shape[1], -1) # uniform sampled negative samples diff --git a/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py b/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py index 143fd58e06..de18c5cc60 100644 --- a/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py +++ b/python/graphstorm/run/gsgnn_lp/gsgnn_lm_lp.py @@ -30,8 +30,9 @@ GSgnnLPInBatchJointNegDataLoader) from graphstorm.dataloading import GSgnnAllEtypeLPJointNegDataLoader from graphstorm.dataloading import GSgnnAllEtypeLinkPredictionDataLoader -from graphstorm.dataloading import GSgnnLinkPredictionTestDataLoader -from graphstorm.dataloading import GSgnnLinkPredictionJointTestDataLoader +from graphstorm.dataloading import (GSgnnLinkPredictionTestDataLoader, + GSgnnLinkPredictionJointTestDataLoader, + GSgnnLinkPredictionPredefinedTestDataLoader) from graphstorm.dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER, BUILTIN_LP_JOINT_NEG_SAMPLER, BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER, @@ -129,10 +130,14 @@ def main(config_args): raise ValueError('Unknown negative sampler') dataloader = dataloader_cls(train_data, train_data.train_idxs, [], config.batch_size, config.num_negative_edges, device, - train_task=True) + train_task=True, + edge_dst_negative_field=config.train_etypes_negative_dstnode, + num_hard_negs=config.num_train_hard_negatives) # TODO(zhengda) let's use full-graph inference for now. - if config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: + if config.eval_etypes_negative_dstnode is not None: + test_dataloader_cls = GSgnnLinkPredictionPredefinedTestDataLoader + elif config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: test_dataloader_cls = GSgnnLinkPredictionTestDataLoader elif config.eval_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER: test_dataloader_cls = GSgnnLinkPredictionJointTestDataLoader @@ -143,11 +148,19 @@ def main(config_args): val_dataloader = None test_dataloader = None if len(train_data.val_idxs) > 0: - val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs, - config.eval_batch_size, config.num_negative_edges_eval) + if config.eval_etypes_negative_dstnode is not None: + val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs, + config.eval_batch_size, config.eval_etypes_negative_dstnode) + else: + val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs, + config.eval_batch_size, config.num_negative_edges_eval) if len(train_data.test_idxs) > 0: - test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs, - config.eval_batch_size, config.num_negative_edges_eval) + if config.eval_etypes_negative_dstnode is not None: + test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs, + config.eval_batch_size, config.eval_etypes_negative_dstnode) + else: + test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs, + config.eval_batch_size, config.num_negative_edges_eval) # Preparing input layer for training or inference. # The input layer can pre-compute node features in the preparing step if needed. diff --git a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py index 92a40dc737..1a843e18fa 100644 --- a/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py +++ b/python/graphstorm/run/gsgnn_lp/gsgnn_lp.py @@ -30,8 +30,9 @@ GSgnnLPInBatchJointNegDataLoader) from graphstorm.dataloading import GSgnnAllEtypeLPJointNegDataLoader from graphstorm.dataloading import GSgnnAllEtypeLinkPredictionDataLoader -from graphstorm.dataloading import GSgnnLinkPredictionTestDataLoader -from graphstorm.dataloading import GSgnnLinkPredictionJointTestDataLoader +from graphstorm.dataloading import (GSgnnLinkPredictionTestDataLoader, + GSgnnLinkPredictionJointTestDataLoader, + GSgnnLinkPredictionPredefinedTestDataLoader) from graphstorm.dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER, BUILTIN_LP_JOINT_NEG_SAMPLER, BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER, @@ -152,10 +153,14 @@ def main(config_args): reverse_edge_types_map=config.reverse_edge_types_map, exclude_training_targets=config.exclude_training_targets, construct_feat_ntype=config.construct_feat_ntype, - construct_feat_fanout=config.construct_feat_fanout) + construct_feat_fanout=config.construct_feat_fanout, + edge_dst_negative_field=config.train_etypes_negative_dstnode, + num_hard_negs=config.num_train_hard_negatives) # TODO(zhengda) let's use full-graph inference for now. - if config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: + if config.eval_etypes_negative_dstnode is not None: + test_dataloader_cls = GSgnnLinkPredictionPredefinedTestDataLoader + elif config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: test_dataloader_cls = GSgnnLinkPredictionTestDataLoader elif config.eval_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER: test_dataloader_cls = GSgnnLinkPredictionJointTestDataLoader @@ -166,13 +171,27 @@ def main(config_args): val_dataloader = None test_dataloader = None if len(train_data.val_idxs) > 0: - val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs, - config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout, - fixed_test_size=config.fixed_test_size) + if config.eval_etypes_negative_dstnode is not None: + val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs, + config.eval_batch_size, + fixed_edge_dst_negative_field=config.eval_etypes_negative_dstnode, + fanout=config.eval_fanout, + fixed_test_size=config.fixed_test_size) + else: + val_dataloader = test_dataloader_cls(train_data, train_data.val_idxs, + config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout, + fixed_test_size=config.fixed_test_size) if len(train_data.test_idxs) > 0: - test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs, - config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout, - fixed_test_size=config.fixed_test_size) + if config.eval_etypes_negative_dstnode is not None: + test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs, + config.eval_batch_size, + fixed_edge_dst_negative_field=config.eval_etypes_negative_dstnode, + fanout=config.eval_fanout, + fixed_test_size=config.fixed_test_size) + else: + test_dataloader = test_dataloader_cls(train_data, train_data.test_idxs, + config.eval_batch_size, config.num_negative_edges_eval, config.eval_fanout, + fixed_test_size=config.fixed_test_size) # Preparing input layer for training or inference. # The input layer can pre-compute node features in the preparing step if needed. diff --git a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py index 50a2e97acc..9ed19212c7 100644 --- a/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py +++ b/python/graphstorm/run/gsgnn_lp/lp_infer_gnn.py @@ -22,8 +22,9 @@ from graphstorm.inference import GSgnnLinkPredictionInferrer from graphstorm.eval import GSgnnMrrLPEvaluator from graphstorm.dataloading import GSgnnEdgeInferData -from graphstorm.dataloading import GSgnnLinkPredictionTestDataLoader -from graphstorm.dataloading import GSgnnLinkPredictionJointTestDataLoader +from graphstorm.dataloading import (GSgnnLinkPredictionTestDataLoader, + GSgnnLinkPredictionJointTestDataLoader, + GSgnnLinkPredictionPredefinedTestDataLoader) from graphstorm.dataloading import BUILTIN_LP_UNIFORM_NEG_SAMPLER from graphstorm.dataloading import BUILTIN_LP_JOINT_NEG_SAMPLER from graphstorm.utils import setup_device, get_lm_ntypes @@ -58,19 +59,27 @@ def main(config_args): tracker = gs.create_builtin_task_tracker(config) infer.setup_task_tracker(tracker) # We only support full-graph inference for now. - if config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: - test_dataloader_cls = GSgnnLinkPredictionTestDataLoader - elif config.eval_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER: - test_dataloader_cls = GSgnnLinkPredictionJointTestDataLoader + if config.eval_etypes_negative_dstnode is not None: + # The negatives used in evaluation is fixed. + dataloader = GSgnnLinkPredictionPredefinedTestDataLoader( + infer_data, infer_data.test_idxs, + batch_size=config.eval_batch_size, + fixed_edge_dst_negative_field=config.eval_etypes_negative_dstnode, + fanout=config.eval_fanout) else: - raise ValueError('Unknown test negative sampler.' - 'Supported test negative samplers include ' - f'[{BUILTIN_LP_UNIFORM_NEG_SAMPLER}, {BUILTIN_LP_JOINT_NEG_SAMPLER}]') + if config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: + test_dataloader_cls = GSgnnLinkPredictionTestDataLoader + elif config.eval_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER: + test_dataloader_cls = GSgnnLinkPredictionJointTestDataLoader + else: + raise ValueError('Unknown test negative sampler.' + 'Supported test negative samplers include ' + f'[{BUILTIN_LP_UNIFORM_NEG_SAMPLER}, {BUILTIN_LP_JOINT_NEG_SAMPLER}]') - dataloader = test_dataloader_cls(infer_data, infer_data.test_idxs, - batch_size=config.eval_batch_size, - num_negative_edges=config.num_negative_edges_eval, - fanout=config.eval_fanout) + dataloader = test_dataloader_cls(infer_data, infer_data.test_idxs, + batch_size=config.eval_batch_size, + num_negative_edges=config.num_negative_edges_eval, + fanout=config.eval_fanout) infer.infer(infer_data, dataloader, save_embed_path=config.save_embed_path, edge_mask_for_gnn_embeddings=None if config.no_validation else \ diff --git a/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py b/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py index e196d3fd83..bdf1becc7a 100644 --- a/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py +++ b/python/graphstorm/run/gsgnn_lp/lp_infer_lm.py @@ -23,8 +23,9 @@ from graphstorm.inference import GSgnnLinkPredictionInferrer from graphstorm.eval import GSgnnMrrLPEvaluator from graphstorm.dataloading import GSgnnEdgeInferData -from graphstorm.dataloading import GSgnnLinkPredictionTestDataLoader -from graphstorm.dataloading import GSgnnLinkPredictionJointTestDataLoader +from graphstorm.dataloading import (GSgnnLinkPredictionTestDataLoader, + GSgnnLinkPredictionJointTestDataLoader, + GSgnnLinkPredictionPredefinedTestDataLoader) from graphstorm.dataloading import BUILTIN_LP_UNIFORM_NEG_SAMPLER from graphstorm.dataloading import BUILTIN_LP_JOINT_NEG_SAMPLER from graphstorm.utils import setup_device @@ -58,18 +59,25 @@ def main(config_args): tracker = gs.create_builtin_task_tracker(config) infer.setup_task_tracker(tracker) # We only support full-graph inference for now. - if config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: - test_dataloader_cls = GSgnnLinkPredictionTestDataLoader - elif config.eval_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER: - test_dataloader_cls = GSgnnLinkPredictionJointTestDataLoader + if config.eval_etypes_negative_dstnode is not None: + # The negatives used in evaluation is fixed. + dataloader = GSgnnLinkPredictionPredefinedTestDataLoader( + infer_data, infer_data.test_idxs, + batch_size=config.eval_batch_size, + fixed_edge_dst_negative_field=config.eval_etypes_negative_dstnode) else: - raise ValueError('Unknown test negative sampler.' - 'Supported test negative samplers include ' - f'[{BUILTIN_LP_UNIFORM_NEG_SAMPLER}, {BUILTIN_LP_JOINT_NEG_SAMPLER}]') + if config.eval_negative_sampler == BUILTIN_LP_UNIFORM_NEG_SAMPLER: + test_dataloader_cls = GSgnnLinkPredictionTestDataLoader + elif config.eval_negative_sampler == BUILTIN_LP_JOINT_NEG_SAMPLER: + test_dataloader_cls = GSgnnLinkPredictionJointTestDataLoader + else: + raise ValueError('Unknown test negative sampler.' + 'Supported test negative samplers include ' + f'[{BUILTIN_LP_UNIFORM_NEG_SAMPLER}, {BUILTIN_LP_JOINT_NEG_SAMPLER}]') - dataloader = test_dataloader_cls(infer_data, infer_data.test_idxs, - batch_size=config.eval_batch_size, - num_negative_edges=config.num_negative_edges_eval) + dataloader = test_dataloader_cls(infer_data, infer_data.test_idxs, + batch_size=config.eval_batch_size, + num_negative_edges=config.num_negative_edges_eval) # Preparing input layer for training or inference. # The input layer can pre-compute node features in the preparing step if needed. # For example pre-compute all BERT embeddings diff --git a/tests/end2end-tests/create_data.sh b/tests/end2end-tests/create_data.sh index 830fe46ba8..d1a6a65eb5 100644 --- a/tests/end2end-tests/create_data.sh +++ b/tests/end2end-tests/create_data.sh @@ -41,6 +41,15 @@ python3 -m graphstorm.gconstruct.construct_graph \ --graph-name movie-lens-100k \ --add-reverse-edges +# movielens link prediction - hard negative and fixed negative for inference +rm -Rf /data/movielen_100k_lp_train_val_hard_neg_1p_4t +python3 -m graphstorm.gconstruct.construct_graph \ + --conf-file $GS_HOME/tests/end2end-tests/data_gen/movielens_lp_hard.json \ + --num-processes 1 \ + --output-dir movielen_100k_lp_train_val_hard_neg_1p_4t \ + --graph-name movie-lens-100k \ + --add-reverse-edges + # movielens link prediction removing test mask rm -Rf /data/movielen_100k_lp_train_no_test_1p_4t cp -R /data/movielen_100k_lp_train_val_1p_4t /data/movielen_100k_lp_train_no_test_1p_4t diff --git a/tests/end2end-tests/data_gen/movielens_lp_hard.json b/tests/end2end-tests/data_gen/movielens_lp_hard.json new file mode 100644 index 0000000000..9a0804c6c2 --- /dev/null +++ b/tests/end2end-tests/data_gen/movielens_lp_hard.json @@ -0,0 +1,78 @@ +{ + "version": "gconstruct-v0.1", + "nodes": [ + { + "node_id_col": "id", + "node_type": "user", + "format": {"name": "hdf5"}, + "files": "/data/ml-100k/user.hdf5", + "features": [ + { + "feature_col": "feat" + } + ] + }, + { + "node_id_col": "id", + "node_type": "movie", + "format": {"name": "parquet"}, + "files": "/data/ml-100k/movie.parquet", + "features": [ + { + "feature_col": "title", + "transform": { + "name": "bert_hf", + "bert_model": "bert-base-uncased", + "max_seq_length": 16 + } + } + ] + } + ], + "edges": [ + { + "source_id_col": "src_id", + "dest_id_col": "dst_id", + "relation": ["user", "rating", "movie"], + "format": {"name": "parquet"}, + "files": "/data/ml-100k/edges.parquet", + "labels": [ + { + "task_type": "link_prediction", + "split_pct": [0.1, 0.1, 0.1] + } + ], + "features":[ + { + "feature_col": "rate", + "feature_name": "rate" + } + ] + }, + { + "relation": ["user", "rating", "movie"], + "format": {"name": "parquet"}, + "files": "/data/ml-100k/hard_neg.parquet", + "features": [ + { + "feature_col": "hard_0", + "feature_name": "hard_0", + "transform": {"name": "edge_dst_hard_negative"} + }, + { + "feature_col": "hard_1", + "feature_name": "hard_1", + "transform": { + "name": "edge_dst_hard_negative", + "separator": "," + } + }, + { + "feature_col": "fixed_eval", + "feature_name": "fixed_eval", + "transform": {"name": "edge_dst_hard_negative"} + } + ] + } + ] +} \ No newline at end of file diff --git a/tests/end2end-tests/data_gen/process_movielens.py b/tests/end2end-tests/data_gen/process_movielens.py index 90fdcd1702..e17934d45f 100644 --- a/tests/end2end-tests/data_gen/process_movielens.py +++ b/tests/end2end-tests/data_gen/process_movielens.py @@ -90,6 +90,27 @@ def write_data_parquet(data, data_file): edge_data = {'src_id': edges[0], 'dst_id': edges[1], 'rate': edges[2]} write_data_parquet(edge_data, '/data/ml-100k/edges.parquet') +# generate hard negatives +num_movies = len(ids) +neg_movie_idx = np.random.randint(0, num_movies, (edges.shape[0], 5)) +neg_movie_0 = ids[neg_movie_idx] +neg_movie_1 = [] +for idx, neg_movie in enumerate(neg_movie_0): + if idx < 10: + neg_movie_1.append(list(neg_movie.astype(str))[0]) + else: + neg_movie_1.append(",".join(list(neg_movie.astype(str)))) +neg_movie_1 = np.array(neg_movie_1) +neg_movie_idx = np.random.randint(0, num_movies, (edges.shape[0], 10)) +neg_movie_2 = ids[neg_movie_idx] + +neg_edge_data = { + "hard_0": neg_movie_0, + "hard_1": neg_movie_1, + "fixed_eval": neg_movie_2 +} +write_data_parquet(neg_edge_data, '/data/ml-100k/hard_neg.parquet') + # generate synthetic user data with label user_labels = np.random.randint(11, size=feat.shape[0]) user_data = {'id': user['id'].values, 'feat': feat, 'occupation': user['occupation'], 'label': user_labels} diff --git a/tests/end2end-tests/graphstorm-lp/mgpu_test.sh b/tests/end2end-tests/graphstorm-lp/mgpu_test.sh index 4ffd6bb261..dee7ea2038 100644 --- a/tests/end2end-tests/graphstorm-lp/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-lp/mgpu_test.sh @@ -525,4 +525,28 @@ python3 -m graphstorm.run.launch --workspace $GS_HOME/training_scripts/gsgnn_lp error_and_exit $? +echo "**************dataset: Movielens, RGCN layer 2, node feat: fixed HF BERT, BERT nodes: movie, inference: full-graph, negative_sampler: joint, exclude_training_targets: true, save model, enough hard neg" +python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_hard_neg_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --eval-batch-size 1024 --exclude-training-targets True --reverse-edge-types-map user,rating,rating-rev,movie --save-model-path /data/gsgnn_lp_ml_hard_dot/ --save-model-frequency 1000 --train-etypes-negative-dstnode hard_0 --num-train-hard-negatives 4 --num-negative-edges 10 --target-etype user,rating,movie + +error_and_exit $? + +echo "**************dataset: Movielens, do inference on saved model, decoder: dot with fixed negative" +python3 -m graphstorm.run.gs_link_prediction --inference --workspace $GS_HOME/inference_scripts/lp_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_hard_neg_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp_infer.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --eval-batch-size 1024 --restore-model-path /data/gsgnn_lp_ml_hard_dot/epoch-2/ --eval-etypes-negative-dstnode fixed_eval --eval-etype user,rating,movie + +error_and_exit $? + +rm -fr /data/gsgnn_lp_ml_hard_dot/* + +echo "**************dataset: Movielens, RGCN layer 2, node feat: fixed HF BERT, BERT nodes: movie, inference: full-graph, negative_sampler: joint, exclude_training_targets: true, save model, hard neg + random neg" +python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_hard_neg_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --eval-batch-size 1024 --exclude-training-targets True --reverse-edge-types-map user,rating,rating-rev,movie --save-model-path /data/gsgnn_lp_ml_hard_dot/ --save-model-frequency 1000 --train-etypes-negative-dstnode user,rating,movie:hard_1 --num-train-hard-negatives 5 --num-negative-edges 10 --target-etype user,rating,movie + +error_and_exit $? + +echo "**************dataset: Movielens, do inference on saved model, decoder: dot with fixed negative" +python3 -m graphstorm.run.gs_link_prediction --inference --workspace $GS_HOME/inference_scripts/lp_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_hard_neg_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp_infer.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --eval-batch-size 1024 --restore-model-path /data/gsgnn_lp_ml_hard_dot/epoch-2/ --eval-etypes-negative-dstnode user,rating,movie:fixed_eval --eval-etype user,rating,movie + +error_and_exit $? + +rm -fr /data/gsgnn_lp_ml_hard_dot/* + rm -fr /tmp/* diff --git a/tests/unit-tests/test_dataloading.py b/tests/unit-tests/test_dataloading.py index 18613e2314..ee8dc50fad 100644 --- a/tests/unit-tests/test_dataloading.py +++ b/tests/unit-tests/test_dataloading.py @@ -49,14 +49,19 @@ FastGSgnnLPJointNegDataLoader, FastGSgnnLPLocalUniformNegDataLoader, FastGSgnnLPLocalJointNegDataLoader) -from graphstorm.dataloading import GSgnnLinkPredictionTestDataLoader -from graphstorm.dataloading import GSgnnLinkPredictionJointTestDataLoader +from graphstorm.dataloading import (GSgnnLinkPredictionTestDataLoader, + GSgnnLinkPredictionJointTestDataLoader, + GSgnnLinkPredictionPredefinedTestDataLoader) from graphstorm.dataloading import DistillDataloaderGenerator, DistillDataManager from graphstorm.dataloading import DistributedFileSampler -from graphstorm.dataloading import BUILTIN_LP_UNIFORM_NEG_SAMPLER -from graphstorm.dataloading import BUILTIN_LP_JOINT_NEG_SAMPLER +from graphstorm.dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER, + BUILTIN_LP_JOINT_NEG_SAMPLER, + BUILTIN_LP_FIXED_NEG_SAMPLER) from graphstorm.dataloading.sampler import InbatchJointUniform +from graphstorm.dataloading.sampler import GlobalUniform +from graphstorm.dataloading.sampler import (GSHardEdgeDstNegativeSampler, + GSFixedEdgeDstNegativeSampler) from graphstorm.dataloading.dataset import (prepare_batch_input, prepare_batch_edge_input) @@ -661,6 +666,59 @@ def test_GSgnnLinkPredictionTestDataLoader(batch_size, num_negative_edges): # after test pass, destroy all process group th.distributed.destroy_process_group() +@pytest.mark.parametrize("batch_size", [1, 10, 128]) +def test_GSgnnLinkPredictionPredefinedTestDataLoader(batch_size): + th.distributed.init_process_group(backend='gloo', + init_method='tcp://127.0.0.1:23456', + rank=0, + world_size=1) + test_etypes = [("n0", "r1", "n1"), ("n0", "r0", "n1")] + with tempfile.TemporaryDirectory() as tmpdirname: + # get the test dummy distributed graph + _, part_config = generate_dummy_dist_graph(graph_name='dummy', dirname=tmpdirname) + lp_data = GSgnnEdgeInferData(graph_name='dummy', part_config=part_config, + eval_etypes=test_etypes) + g = lp_data.g + g.edges[("n0", "r1", "n1")].data["neg"] = th.randint(g.num_nodes("n1"), + (g.num_edges(("n0", "r1", "n1")), 10)) + g.edges[("n0", "r0", "n1")].data["neg"] = th.randint(g.num_nodes("n1"), + (g.num_edges(("n0", "r0", "n1")), 10)) + + dataloader = GSgnnLinkPredictionPredefinedTestDataLoader( + lp_data, + target_idx=lp_data.infer_idxs, # use train edges as val or test edges + batch_size=batch_size, + fixed_edge_dst_negative_field="neg") + + total_edges = {etype: len(lp_data.infer_idxs[etype]) for etype in test_etypes} + num_pos_edges = {etype: 0 for etype in test_etypes} + for pos_neg_tuple, sample_type in dataloader: + assert sample_type == BUILTIN_LP_FIXED_NEG_SAMPLER + assert isinstance(pos_neg_tuple, dict) + assert len(pos_neg_tuple) == 1 + for canonical_etype, pos_neg in pos_neg_tuple.items(): + assert len(pos_neg) == 4 + pos_src, _, pos_dst, neg_dst = pos_neg + assert pos_src.shape == pos_dst.shape + assert pos_src.shape[0] == batch_size \ + if num_pos_edges[canonical_etype] + batch_size < total_edges[canonical_etype] \ + else total_edges[canonical_etype] - num_pos_edges[canonical_etype] + eid = lp_data.infer_idxs[canonical_etype][num_pos_edges[canonical_etype]: \ + num_pos_edges[canonical_etype]+batch_size] \ + if num_pos_edges[canonical_etype]+batch_size < total_edges[canonical_etype] \ + else lp_data.infer_idxs[canonical_etype] \ + [num_pos_edges[canonical_etype]:] + src, dst = g.find_edges(eid, etype=canonical_etype) + assert_equal(pos_src.numpy(), src.numpy()) + assert_equal(pos_dst.numpy(), dst.numpy()) + assert len(neg_dst.shape) == 2 + assert neg_dst.shape[1] == 10 + assert_equal(neg_dst.numpy(), g.edges[canonical_etype].data["neg"][eid].numpy()) + + num_pos_edges[canonical_etype] += batch_size + # after test pass, destroy all process group + th.distributed.destroy_process_group() + # initialize the torch distributed environment @pytest.mark.parametrize("batch_size", [1, 10, 128]) @pytest.mark.parametrize("num_negative_edges", [1, 16, 128]) @@ -1330,6 +1388,441 @@ def test_lp_dataloader_len(batch_size): device='cuda:0', train_task=True) assert len(dataloader) == len(list(dataloader)) +def _create_hard_neg_graph(num_nodes, num_negs): + etype0 = ("n0", "r0", "n1") + etype1 = ("n0", "r0", "n2") + etype2 = ("n0", "r0", "n3") + src = th.arange(num_nodes) + dst = th.arange(num_nodes) + # each edge has 4 pre-defined hard negatives + hard0 = th.randint(num_nodes, (num_nodes, 4)) + # each edge has 10 pre-defined hard negatives + hard1 = th.randint(num_nodes, (num_nodes, num_negs)) + # each edge has 20 pre-defined hard negatives + hard2 = th.randint(num_nodes, (num_nodes, num_negs*2)) # more hard negatives than num neg + g = dgl.heterograph({ + etype0: (src, dst), + etype1: (src, dst), + etype2: (src, dst), + }) + g.edges[etype0].data["hard_negative"] = hard0 + g.edges[etype1].data["hard_negative"] = hard1 + g.edges[etype2].data["hard_negative"] = hard2 + + return etype0, etype1, etype2, hard0, hard1, hard2, src, dst, g + +def test_hard_edge_dst_negative_sample_generate_complex_case(): + # test GSHardEdgeDstNegativeSampler._generate with slow track when not all the pos edges have enough hard negatives defined + num_nodes = 1000 + # test GSHardEdgeDstNegativeSampler._generate when all some pos edges do not have enough hard negatives defined + num_negs = 10 + etype0, etype1, etype2, hard0, hard1, hard2, src, _, g = _create_hard_neg_graph(num_nodes, num_negs) + + # not enough predefined hard negatives + # for hard0[0] and hard0[1] + hard0[0] = th.randperm(num_nodes)[:4] + hard0[0][-1] = -1 + hard0[0][-2] = -1 + hard0[1][-1] = -1 + + # not enough predefined hard negatives + # for hard0[0] and hard0[1] + hard1[0] = th.randperm(num_nodes)[:num_negs] + hard1[1] = th.randperm(num_nodes)[:num_negs] + hard1[0][-1] = -1 + hard1[0][-2] = -1 + hard1[1][-1] = -1 + + # not enough predefined hard negatives + # for hard0[0] and hard0[1] + hard2[0] = th.randperm(num_nodes)[:num_negs*2] + hard2[1] = th.randperm(num_nodes)[:num_negs*2] + hard2[0][-1] = -1 + hard2[0][-2] = -1 + hard2[1][-1] = -1 + + num_edges = 10 + eids = th.arange(num_edges) + def test_missing_hard_negs(neg_dst, num_hard_neg, hard_neg_data): + # hardx[0][-1] and hardx[0][-2] is -1, + # which means hardx[0] does not enough predefined negatives + # Random sample will be applied to -1s. + hard_neg_dst = neg_dst[0][:num_hard_neg] + hard_neg_rand_0 = hard_neg_dst[-1] + hard_neg_rand_1 = hard_neg_dst[-2] + hard_neg_dst = set(hard_neg_dst[:-2].tolist()) + rand_neg_dst = neg_dst[0][num_hard_neg:] + rand_neg_dst = set(rand_neg_dst.tolist()) + hard_neg_set = set(hard_neg_data[0].tolist()) + assert hard_neg_dst.issubset(hard_neg_set) + assert len(rand_neg_dst) == 0 or \ + rand_neg_dst.issubset(hard_neg_set) is False + + rand_0_check = hard_neg_rand_0 not in hard_neg_set + rand_1_check = hard_neg_rand_1 not in hard_neg_set + + # hardx[1][-1] is -1, + # which means hardx[0] does not enough predefined negatives + # Random sample will be applied to -1s. + hard_neg_dst = neg_dst[1][:num_hard_neg] + hard_neg_rand_2 = hard_neg_dst[-1] + hard_neg_dst = set(hard_neg_dst[:-1].tolist()) + rand_neg_dst = neg_dst[1][num_hard_neg:] + rand_neg_dst = set(rand_neg_dst.tolist()) + hard_neg_set = set(hard_neg_data[1].tolist()) + assert hard_neg_dst.issubset(hard_neg_set) + assert len(rand_neg_dst) == 0 or \ + rand_neg_dst.issubset(hard_neg_set) is False + + rand_2_check = hard_neg_rand_2 not in hard_neg_set + # The chance is very to to have rand_0_check, + # rand_1_check and rand_2_check be true at the same time + # The change is (4/1000)^3 + assert rand_0_check or rand_1_check or rand_2_check + + def check_less_hard_negs(hard_neg_sampler, target_etype, hard_neg_data, + num_hard_neg, check_missing_hard_neg): + neg_src, neg_dst = hard_neg_sampler._generate(g, eids, target_etype) + assert len(neg_src) == num_edges * num_negs + assert len(neg_dst) == num_edges * num_negs + assert_equal(th.repeat_interleave(src[:10], num_negs, 0).numpy(), neg_src.numpy()) + neg_dst = neg_dst.reshape(num_edges, num_negs) + + if check_missing_hard_neg: + test_missing_hard_negs(neg_dst, num_hard_neg, hard_neg_data) + + start = 2 if check_missing_hard_neg else 0 + for i in range(start, num_edges): + hard_neg_dst = neg_dst[i][:num_hard_neg] + hard_neg_dst = set(hard_neg_dst.tolist()) + rand_neg_dst = neg_dst[i][num_hard_neg:] + rand_neg_dst = set(rand_neg_dst.tolist()) + hard_neg_set = set(hard_neg_data[i].tolist()) + assert hard_neg_dst.issubset(hard_neg_set) + assert rand_neg_dst.issubset(hard_neg_set) is False + + # case 1: + # 1. hard_negative field is string + # 2. num_hard_neg is int + # 3. num_negs > number of hard negatives required (2) + # 4. num_negs > total number of hard negatives + # provided (hard0 has 4 negatives for each node) + # 5. Each edge has enough hard negative even though some edges do not have enough (10) predefined negatives + # 6. slow track (-1 exists in hard neg feature) + # + # expected behavior: + # 1. Only 2 hard negatives are returned + # 2. Others will be random negatives + sampler = GlobalUniform(num_negs) + hard_sampler = GSHardEdgeDstNegativeSampler(num_negs, "hard_negative", sampler, num_hard_negs=2) + check_less_hard_negs(hard_sampler, etype0, hard0, 2, check_missing_hard_neg=False) + + # Case 2: + # 1. hard_negative field is string + # 2. num_hard_neg is int + # 3. num_negs > number of hard negatives required (2) + # 4. num_negs == total number of hard negatives + # provided (hard1 has 10 negatives for each node) + # 5. Each edge has enough hard negative even though some edges do not have enough (8) predefined negatives + # 6. slow track (-1 exists in hard neg feature) + # + # expected behavior: + # 1. Only 2 hard negatives are returned + # 2. Others will be random negatives + check_less_hard_negs(hard_sampler, etype1, hard1, 2, check_missing_hard_neg=False) + + # Case 3: + # 1. hard_negative field is string + # 2. num_hard_neg is int + # 3. num_negs > number of hard negatives required (8) + # 4. number of hard negatives required (8) > number of hard negatives + # provided (hard0 has only 4 negatives for each node) + # 5.slow track (-1 exists in hard neg feature) + # + # expected behavior: + # 1. Only 4 hard negatives are returned + # 2. Others will be random negatives + # 3. eid 0 will have 2 more random negatives + # and eid 1 will have 1 more random negatives + hard_sampler = GSHardEdgeDstNegativeSampler(num_negs, "hard_negative", sampler, num_hard_negs=8) + check_less_hard_negs(hard_sampler, etype0, hard0, 4, check_missing_hard_neg=True) + + # Case 4: + # 1. hard_negative field is string + # 2. num_hard_neg is int + # 3. num_negs == number of hard negatives required (10) + # 4. number of hard negatives required (8) == number of hard negatives + # provided (hard1 has 10 negatives for each node) + # 5.slow track (-1 exists in hard neg feature) + # + # expected behavior: + # 1. Equal negatives + def check_enough_hard_negs(hard_neg_sampler, target_etype, hard_neg_data): + neg_src, neg_dst = hard_neg_sampler._generate(g, eids, target_etype) + assert len(neg_src) == num_edges * num_negs + assert len(neg_dst) == num_edges * num_negs + assert_equal(th.repeat_interleave(src[:10], num_negs, 0).numpy(), neg_src.numpy()) + neg_dst = neg_dst.reshape(num_edges, num_negs) + + test_missing_hard_negs(neg_dst, num_negs, hard_neg_data) + + for i in range(2, num_edges): + hard_neg_dst = set(neg_dst[i].tolist()) + hard_neg_set = set(hard_neg_data[i].tolist()) + assert hard_neg_dst == hard_neg_set + + hard_sampler = GSHardEdgeDstNegativeSampler(num_negs, "hard_negative", sampler, num_hard_negs=num_negs) + check_enough_hard_negs(hard_sampler, etype1, hard1) + + # Case 5: + # 1. hard_negative field is string + # 2. num_hard_neg is int + # 3. num_negs == number of hard negatives required (10) + # 4. number of hard negatives required (8) < number of hard negatives + # provided (hard2 has 20 negatives for each node) + # 5.slow track (-1 exists in hard neg feature) + # + # expected behavior: + # 1. hard negatives will be a subset of hard2 + def check_more_hard_negs(hard_neg_sampler, target_etype, hard_neg_data): + neg_src, neg_dst = hard_neg_sampler._generate(g, eids, target_etype) + assert len(neg_src) == num_edges * num_negs + assert len(neg_dst) == num_edges * num_negs + assert_equal(th.repeat_interleave(src[:10], num_negs, 0).numpy(), neg_src.numpy()) + neg_dst = neg_dst.reshape(num_edges, num_negs) + for i in range(num_edges): + hard_neg_dst = set(neg_dst[i].tolist()) + hard_neg_set = set(hard_neg_data[i].tolist()) + assert hard_neg_dst.issubset(hard_neg_set) + check_more_hard_negs(hard_sampler, etype2, hard2) + +def test_hard_edge_dst_negative_sample_generate(): + # test GSHardEdgeDstNegativeSampler._generate with fast track when all pos edges have enough hard negatives defined + num_nodes = 100 + num_negs = 10 + etype0, etype1, etype2, hard0, hard1, hard2, src, _, g = _create_hard_neg_graph(num_nodes, num_negs) + + num_edges = 10 + eids = th.arange(num_edges) + def check_less_hard_negs(hard_neg_sampler, target_etype, hard_neg_data, num_hard_neg): + neg_src, neg_dst = hard_neg_sampler._generate(g, eids, target_etype) + assert len(neg_src) == num_edges * num_negs + assert len(neg_dst) == num_edges * num_negs + assert_equal(th.repeat_interleave(src[:10], num_negs, 0).numpy(), neg_src.numpy()) + neg_dst = neg_dst.reshape(num_edges, num_negs) + for i in range(num_edges): + hard_neg_dst = neg_dst[i][:num_hard_neg] + hard_neg_dst = set(hard_neg_dst.tolist()) + rand_neg_dst = neg_dst[i][num_hard_neg:] + rand_neg_dst = set(rand_neg_dst.tolist()) + hard_neg_set = set(hard_neg_data[i].tolist()) + assert hard_neg_dst.issubset(hard_neg_set) + assert rand_neg_dst.issubset(hard_neg_set) is False + + # case 1: + # 1. hard_negative field is string + # 2. num_hard_neg is int + # 3. num_negs > number of hard negatives required (2) + # 4. num_negs > total number of hard negatives + # provided (hard0 has 4 negatives for each node) + # 5. fast track + # + # expected behavior: + # 1. Only 2 hard negatives are returned + # 2. Others will be random negatives + sampler = GlobalUniform(num_negs) + hard_sampler = GSHardEdgeDstNegativeSampler(num_negs, "hard_negative", sampler, num_hard_negs=2) + check_less_hard_negs(hard_sampler, etype0, hard0, 2) + + # Case 2: + # 1. hard_negative field is string + # 2. num_hard_neg is int + # 3. num_negs > number of hard negatives required (2) + # 4. num_negs == total number of hard negatives + # provided (hard1 has 10 negatives for each node) + # 5. fast track + # + # expected behavior: + # 1. Only 2 hard negatives are returned + # 2. Others will be random negatives + check_less_hard_negs(hard_sampler, etype1, hard1, 2) + + # Case 3: + # 1. hard_negative field is string + # 2. num_hard_neg is int + # 3. num_negs > number of hard negatives required (8) + # 4. number of hard negatives required (8) > number of hard negatives + # provided (hard0 has only 4 negatives for each node) + # 5.fast track + # + # expected behavior: + # 1. Only 4 hard negatives are returned + # 2. Others will be random negatives + hard_sampler = GSHardEdgeDstNegativeSampler(num_negs, "hard_negative", sampler, num_hard_negs=8) + check_less_hard_negs(hard_sampler, etype0, hard0, 4) + + # Case 4: + # 1. hard_negative field is string + # 2. num_hard_neg is int + # 3. num_negs == number of hard negatives required (10) + # 4. number of hard negatives required (10) == number of hard negatives + # provided (hard1 has 10 negatives for each node) + # 5.fast track + # + # expected behavior: + # 1. Equal negatives + def check_enough_hard_negs(hard_neg_sampler, target_etype, hard_neg_data): + neg_src, neg_dst = hard_neg_sampler._generate(g, eids, target_etype) + assert len(neg_src) == num_edges * num_negs + assert len(neg_dst) == num_edges * num_negs + assert_equal(th.repeat_interleave(src[:10], num_negs, 0).numpy(), neg_src.numpy()) + neg_dst = neg_dst.reshape(num_edges, num_negs) + for i in range(num_edges): + hard_neg_dst = set(neg_dst[i].tolist()) + hard_neg_set = set(hard_neg_data[i].tolist()) + assert hard_neg_dst == hard_neg_set + hard_sampler = GSHardEdgeDstNegativeSampler(num_negs, "hard_negative", sampler, num_hard_negs=num_negs) + check_enough_hard_negs(hard_sampler, etype1, hard1) + + # Case 5: + # 1. hard_negative field is string + # 2. num_hard_neg is int + # 3. num_negs == number of hard negatives required (10) + # 4. number of hard negatives required (8) < number of hard negatives + # provided (hard2 has 20 negatives for each node) + # 5.fast track + # + # expected behavior: + # 1. hard negatives will be a subset of hard2 + def check_more_hard_negs(hard_neg_sampler, target_etype, hard_neg_data): + neg_src, neg_dst = hard_neg_sampler._generate(g, eids, target_etype) + assert len(neg_src) == num_edges * num_negs + assert len(neg_dst) == num_edges * num_negs + assert_equal(th.repeat_interleave(src[:10], num_negs, 0).numpy(), neg_src.numpy()) + neg_dst = neg_dst.reshape(num_edges, num_negs) + for i in range(num_edges): + hard_neg_dst = set(neg_dst[i].tolist()) + hard_neg_set = set(hard_neg_data[i].tolist()) + assert hard_neg_dst.issubset(hard_neg_set) + check_more_hard_negs(hard_sampler, etype2, hard2) + + # Case 6: + # hard_negative field is dict + # num_hard_neg is dict + # 3. num_negs > number of hard negatives required (2) + # 4. num_negs > total number of hard negatives + # provided (hard0 has 4 negatives for each node) + # 5. fast track + # + # expected behavior: + # 1. Only 2 hard negatives are returned + # 2. Others will be random negatives + hard_sampler = GSHardEdgeDstNegativeSampler( + num_negs, + {etype0: "hard_negative", + etype1: "hard_negative", + etype2: "hard_negative"}, + sampler, + {etype0: 2, + etype1: 2, + etype2: 10}) + check_less_hard_negs(hard_sampler, etype0, hard0, 2) + + # Case 7: + # hard_negative field is dict + # num_hard_neg is dict + # 3. num_negs > number of hard negatives required (2) + # 4. num_negs == total number of hard negatives + # provided (hard1 has 10 negatives for each node) + # 5. fast track + # + # expected behavior: + # 1. Only 2 hard negatives are returned + # 2. Others will be random negatives + check_less_hard_negs(hard_sampler, etype1, hard1, 2) + + # Case 8: + # 1. hard_negative field is string + # 2. num_hard_neg is int + # 3. num_negs == number of hard negatives required (10) + # 4. number of hard negatives required (8) < number of hard negatives + # provided (hard2 has 20 negatives for each node) + # 5.fast track + # + # expected behavior: + # 1. hard negatives will be a subset of hard2 + check_more_hard_negs(hard_sampler, etype2, hard2) + + def check_none_hard_negs(hard_neg_sampler, target_etype, hard_neg_data): + neg_src, neg_dst = hard_neg_sampler._generate(g, eids, target_etype) + assert len(neg_src) == num_edges * num_negs + assert len(neg_dst) == num_edges * num_negs + assert_equal(th.repeat_interleave(src[:10], num_negs, 0).numpy(), neg_src.numpy()) + neg_dst = neg_dst.reshape(num_edges, num_negs) + for i in range(num_edges): + hard_neg_dst = set(neg_dst[i].tolist()) + hard_neg_set = set(hard_neg_data[i].tolist()) + assert hard_neg_dst.issubset(hard_neg_set) is False + # Case 9: + # dst_negative_field is not provided + hard_sampler = GSHardEdgeDstNegativeSampler( + num_negs, {}, sampler, 2) + check_none_hard_negs(hard_sampler, etype2, hard2) + + # Case 10: + # num_hard_negs is not provided + hard_sampler = GSHardEdgeDstNegativeSampler( + num_negs, "hard_negative", sampler, {}) + check_none_hard_negs(hard_sampler, etype2, hard2) + +def test_edge_fixed_dst_negative_sample_gen_neg_pairs(): + # test GSHardEdgeDstNegativeSampler.gen_neg_pairs with fast track when all edges have enough predefined negatives + num_nodes = 1000 + # test GSHardEdgeDstNegativeSampler._generate when all some pos edges do not have enough hard negatives defined + num_negs = 10 + etype0, etype1, etype2, hard0, hard1, hard2, src, dst, g = _create_hard_neg_graph(num_nodes, num_negs) + + num_edges = 10 + + def check_fixed_negs(pos_neg_tuple, etype, hard_neg_data): + neg_src, _, pos_dst, neg_dst = pos_neg_tuple[etype] + + assert len(neg_src) == num_edges + assert len(pos_dst) == num_edges + assert neg_dst.shape[0] == num_edges + assert_equal(src[:10].numpy(), neg_src.numpy()) + assert_equal(dst[:10].numpy(), pos_dst.numpy()) + + assert_equal(hard_neg_data[:10].numpy(), neg_dst.numpy()) + + hard_sampler = GSFixedEdgeDstNegativeSampler("hard_negative") + pos_neg_tuple = hard_sampler.gen_etype_neg_pairs(g, etype0, th.arange(10)) + check_fixed_negs(pos_neg_tuple, etype0, hard0) + pos_neg_tuple = hard_sampler.gen_etype_neg_pairs(g, etype1, th.arange(10)) + check_fixed_negs(pos_neg_tuple, etype1, hard1) + pos_neg_tuple = hard_sampler.gen_etype_neg_pairs(g, etype2, th.arange(10)) + check_fixed_negs(pos_neg_tuple, etype2, hard2) + + hard_sampler = GSFixedEdgeDstNegativeSampler({etype0: "hard_negative", + etype1: "hard_negative", + etype2: "hard_negative"}) + pos_neg_tuple = hard_sampler.gen_etype_neg_pairs(g, etype0, th.arange(10)) + check_fixed_negs(pos_neg_tuple, etype0, hard0) + pos_neg_tuple = hard_sampler.gen_etype_neg_pairs(g, etype1, th.arange(10)) + check_fixed_negs(pos_neg_tuple, etype1, hard1) + pos_neg_tuple = hard_sampler.gen_etype_neg_pairs(g, etype2, th.arange(10)) + check_fixed_negs(pos_neg_tuple, etype2, hard2) + + # each positive edge should have enough fixed negatives + hard0[0][-1] = -1 + fail = False + try: + pos_neg_tuple = hard_sampler.gen_neg_pairs(g, etype0, th.arange(10)) + except: + fail = True + assert fail + + @pytest.mark.parametrize("num_pos", [2, 10]) @pytest.mark.parametrize("num_neg", [5, 20]) def test_inbatch_joint_neg_sampler(num_pos, num_neg): @@ -1356,6 +1849,11 @@ def test_inbatch_joint_neg_sampler(num_pos, num_neg): if __name__ == '__main__': + test_GSgnnLinkPredictionPredefinedTestDataLoader(1) + test_GSgnnLinkPredictionPredefinedTestDataLoader(10) + test_edge_fixed_dst_negative_sample_gen_neg_pairs() + test_hard_edge_dst_negative_sample_generate_complex_case() + test_hard_edge_dst_negative_sample_generate() test_inbatch_joint_neg_sampler(10, 20) test_np_dataloader_len(11)