diff --git a/.github/workflow_scripts/e2e_gb_check.sh b/.github/workflow_scripts/e2e_gb_check.sh index 8ea53a4824..3f7fa80280 100644 --- a/.github/workflow_scripts/e2e_gb_check.sh +++ b/.github/workflow_scripts/e2e_gb_check.sh @@ -8,6 +8,6 @@ GS_HOME=$(pwd) # Install graphstorm from checked out code pip3 install "$GS_HOME" --upgrade -bash ./tests/end2end-tests/setup.sh bash ./tests/end2end-tests/create_data.sh bash ./tests/end2end-tests/graphbolt-gs-integration/graphbolt-graph-construction.sh +bash ./tests/end2end-tests/graphbolt-gs-integration/graphbolt-training-inference.sh diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml index 7b200bc71b..beb3b07826 100644 --- a/.github/workflows/continuous-integration.yml +++ b/.github/workflows/continuous-integration.yml @@ -191,6 +191,7 @@ jobs: uses: aws-actions/configure-aws-credentials@v1 with: role-to-assume: arn:aws:iam::698571788627:role/github-oidc-role + role-duration-seconds: 14400 aws-region: us-east-1 - name: Checkout repository uses: actions/checkout@v3 diff --git a/docs/source/advanced/link-prediction.rst b/docs/source/advanced/link-prediction.rst index 37fb04a5f1..10d7450cb6 100644 --- a/docs/source/advanced/link-prediction.rst +++ b/docs/source/advanced/link-prediction.rst @@ -12,8 +12,8 @@ Optimizing model performance ---------------------------- GraphStorm incorporates three ways of improving model performance of link prediction. Firstly, GraphStorm avoids information leak in model training. -Secondly, to better handle heterogeneous graphs, GraphStorm provides three ways -to compute link prediction scores: dot product, DistMult and RotatE. +Secondly, to better handle heterogeneous graphs, GraphStorm provides four ways +to compute link prediction scores: dot product, DistMult, TransE, and RotatE. Thirdly, GraphStorm provides two options to compute training losses, i.e., cross entropy loss and contrastive loss. The following sub-sections provide more details. @@ -32,7 +32,7 @@ GraphStorm provides supports to avoid theses problems: Computing Link Prediction Scores ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -GraphStorm provides three ways to compute link prediction scores: Dot Product, DistMult and RotatE. +GraphStorm provides four ways to compute link prediction scores: Dot Product, DistMult, TransE, and RotatE. * **Dot Product**: The Dot Product score function is as: @@ -53,7 +53,21 @@ GraphStorm provides three ways to compute link prediction scores: Dot Product, D The ``relation_emb`` values are initialized from a uniform distribution within the range of ``(-gamma/hidden_size, gamma/hidden_size)``, where ``gamma`` and ``hidden_size`` are hyperparameters defined in - :ref:`Model Configurations`。 + :ref:`Model Configurations`. + +* **TransE**: The TransE score function is as: + + .. math:: + score = gamma - \|h+r-t\|^{frac{1}{2}} \text{or} gamma - \|h+r-t\| + + where the ``head_emb`` is the node embedding of the head node, + the ``tail_emb`` is the node embedding of the tail node, + the ``relation_emb`` is the relation embedding of the specific edge type. + The ``relation_emb`` values are initialized from a uniform distribution + within the range of ``(-gamma/(hidden_size/2), gamma/(hidden_size/2))``, + where ``gamma`` and ``hidden_size`` are hyperparameters defined in + :ref:`Model Configurations`. + To learn more information about TransE, please refer to `the DGLKE doc `__. * **RotatE**: The RotatE score function is as: diff --git a/docs/source/api/references/graphstorm.model.rst b/docs/source/api/references/graphstorm.model.rst index 18977eddbd..4b1a049d3c 100644 --- a/docs/source/api/references/graphstorm.model.rst +++ b/docs/source/api/references/graphstorm.model.rst @@ -101,3 +101,7 @@ Decoder Layer LinkPredictContrastiveDistMultDecoder LinkPredictRotatEDecoder LinkPredictContrastiveRotatEDecoder + LinkPredictWeightedRotatEDecoder + LinkPredictTransEDecoder + LinkPredictContrastiveTransEDecoder + LinkPredictWeightedTransEDecoder diff --git a/docs/source/cli/model-training-inference/configuration-run.rst b/docs/source/cli/model-training-inference/configuration-run.rst index facf834c7b..cc96bae17c 100644 --- a/docs/source/cli/model-training-inference/configuration-run.rst +++ b/docs/source/cli/model-training-inference/configuration-run.rst @@ -482,12 +482,12 @@ Link Prediction Task - Yaml: ``num_negative_edges_eval: 1000`` - Argument: ``--num-negative-edges-eval 1000`` - Default value: ``1000`` -- **lp_decoder_type**: Set the decoder type for loss function in Link Prediction tasks. Currently GraphStorm support ``dot_product``, ``distmult`` and ``rotate``. +- **lp_decoder_type**: Set the decoder type for loss function in Link Prediction tasks. Currently GraphStorm support ``dot_product``, ``distmult``, ``rotate``, ``transe_l1``, and ``transe_l2``. - Yaml: ``lp_decoder_type: dot_product`` - Argument: ``--lp-decoder-type dot_product`` - Default value: ``distmult`` -- **gamma**: Set the value of the hyperparameter denoted by the symbol gamma. Gamma is used in the following cases: i/ focal loss for binary classification ii/ DistMult score function for link prediction and iii/ RotatE score function for link prediction. +- **gamma**: Set the value of the hyperparameter denoted by the symbol gamma. Gamma is used in the following cases: i/ focal loss for binary classification ii/ DistMult score function for link prediction, iii/ TransE score function for link prediction, and iv/ RotatE score function for link prediction. - Yaml: ``gamma: 10.0`` - Argument: ``--gamma 10.0`` @@ -586,4 +586,4 @@ GraphStorm provides a set of parameters to control GNN distillation. - Yaml: ``max_seq_len: 1024`` - Argument: ``--max-seq-len 1024`` - - Default value: ``1024`` + - Default value: ``1024`` \ No newline at end of file diff --git a/python/graphstorm/config/__init__.py b/python/graphstorm/config/__init__.py index e79d368cad..6cc5e71eda 100644 --- a/python/graphstorm/config/__init__.py +++ b/python/graphstorm/config/__init__.py @@ -31,7 +31,9 @@ from .config import (BUILTIN_LP_DOT_DECODER, BUILTIN_LP_DISTMULT_DECODER, - BUILTIN_LP_ROTATE_DECODER) + BUILTIN_LP_ROTATE_DECODER, + BUILTIN_LP_TRANSE_L1_DECODER, + BUILTIN_LP_TRANSE_L2_DECODER) from .config import SUPPORTED_LP_DECODER from .config import (GRAPHSTORM_MODEL_EMBED_LAYER, diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index 441bc570a4..7eab1b2ea8 100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -83,7 +83,7 @@ def get_argument_parser(): arugments in GraphStorm launch CLIs. Specifically, it will parses yaml config file first, and then parses arguments to overwrite parameters defined in the yaml file or add new parameters. - + This ``get_argument_parser()`` is also useful when users want to convert customized models to use GraphStorm CLIs. @@ -166,7 +166,7 @@ def get_argument_parser(): # pylint: disable=no-member class GSConfig: """GSgnn configuration class. - + GSConfig contains all GraphStorm model training and inference configurations, which can either be loaded from a yaml file specified in the ``--cf`` argument, or from CLI arguments. """ @@ -1223,9 +1223,9 @@ def edge_feat_name(self): @property def node_feat_name(self): """ User defined node feature name. Default is None. - + It can be in following format: - + - ``feat_name``: global feature name, if a node has node feature, the corresponding feature name is . - ``"ntype0:feat0","ntype1:feat0,feat1",...``: different node types have different @@ -1291,16 +1291,16 @@ def _check_fanout(self, fanout, fot_name): def fanout(self): """ The fanouts of GNN layers. The values of fanouts must be integers larger than 0. The number of fanouts must equal to ``num_layers``. Must provide. - - It accepts two formats: - + + It accepts two formats: + - ``20,10``, which defines the number of neighbors to sample per edge type for each GNN layer with the i_th element being the fanout for the ith GNN layer. - + - "etype2:20@etype3:20@etype1:10,etype2:10@etype3:4@etype1:2", which defines the numbers of neighbors to sample for different edge types for each GNN layers - with the i_th element being the fanout for the i_th GNN layer. + with the i_th element being the fanout for the i_th GNN layer. """ # pylint: disable=no-member if self.model_encoder_type in BUILTIN_GNN_ENCODER: @@ -1440,7 +1440,7 @@ def use_mini_batch_infer(self): @property def gnn_norm(self): - """ Normalization method for GNN layers. Options include ``batch`` or ``layer``. + """ Normalization method for GNN layers. Options include ``batch`` or ``layer``. Default is None. """ # pylint: disable=no-member @@ -1616,7 +1616,7 @@ def dropout(self): @property # pylint: disable=invalid-name def lr(self): - """ Learning rate for dense parameters of input encoders, model encoders, + """ Learning rate for dense parameters of input encoders, model encoders, and decoders. Must provide. """ assert hasattr(self, "_lr"), "Learning rate must be specified" @@ -1825,7 +1825,7 @@ def early_stop_rounds(self): @property def early_stop_strategy(self): - """ The strategy used to decide if stop training early. GraphStorm supports two + """ The strategy used to decide if stop training early. GraphStorm supports two strategies: 1) ``consecutive_increase``, and 2) ``average_increase``. Default is ``average_increase``. """ @@ -1843,7 +1843,7 @@ def early_stop_strategy(self): @property def use_early_stop(self): - """ Whether to use early stopping during training. Default is False. + """ Whether to use early stopping during training. Default is False. """ # pylint: disable=no-member if hasattr(self, "_use_early_stop"): @@ -1954,7 +1954,7 @@ def check_multilabel(multilabel): def multilabel_weights(self): """Used to specify label weight of each class in a multi-label classification task. It is feed into ``th.nn.BCEWithLogitsLoss`` as ``pos_weight``. - + The weights should be in the following format 0.1,0.2,0.3,0.1,0.0, ... Default is None. """ @@ -2182,7 +2182,7 @@ def remove_target_edge_type(self): Default is True. If set to True, Graphstorm will set the fanout of training target edge - type as zero. This is only used with edge classification. + type as zero. This is only used with edge classification. If the edge classification is to predict the existence of an edge between two nodes, GraphStorm should remove the target edge in the message passing to avoid information leak. If it's to predict some attributes associated with @@ -2205,7 +2205,7 @@ def remove_target_edge_type(self): @property def decoder_type(self): - """ The type of edge clasification or regression decoders. Built-in decoders include + """ The type of edge clasification or regression decoders. Built-in decoders include ``DenseBiDecoder`` and ``MLPDecoder``. Default is ``DenseBiDecoder``. """ # pylint: disable=no-member @@ -2265,7 +2265,7 @@ def decoder_edge_feat(self): ### Link Prediction specific ### @property def train_negative_sampler(self): - """ The negative sampler used for link prediction training. + """ The negative sampler used for link prediction training. Built-in samplers include ``uniform``, ``joint``, ``localuniform``, ``all_etype_uniform`` and ``all_etype_joint``. Default is ``uniform``. """ @@ -2276,7 +2276,7 @@ def train_negative_sampler(self): @property def eval_negative_sampler(self): - """ The negative sampler used for link prediction training. + """ The negative sampler used for link prediction training. Built-in samplers include ``uniform``, ``joint``, ``localuniform``, ``all_etype_uniform`` and ``all_etype_joint``. Default is ``joint``. """ @@ -2316,8 +2316,8 @@ def num_negative_edges_eval(self): @property def lp_decoder_type(self): """ The decoder type for loss function in link prediction tasks. - Currently GraphStorm supports ``dot_product``, ``distmult`` and ``rotate``. - Default is ``distmult``. + Currently GraphStorm supports ``dot_product``, ``distmult``, + ``transe`` (``transe_l1`` and ``transe_l2``), and ``rotate``. Default is ``distmult``. """ # pylint: disable=no-member if hasattr(self, "_lp_decoder_type"): @@ -2379,10 +2379,10 @@ def lp_edge_weight_for_loss(self): positive edge loss for link prediction tasks. Default is None. The edge_weight can be in following format: - + - ``weight_name``: global weight name, if an edge has weight, the corresponding weight name is ``weight_name``. - + - ``"src0,rel0,dst0:weight0","src0,rel0,dst0:weight1",...``: different edge types have different edge weights. """ @@ -2450,7 +2450,7 @@ def _get_predefined_negatives_per_etype(self, negatives): @property def train_etypes_negative_dstnode(self): - """ The list of canonical edge types that have hard negative edges + """ The list of canonical edge types that have hard negative edges constructed by corrupting destination nodes during training. For each edge type to use different fields to store the hard negatives, @@ -2458,13 +2458,13 @@ def train_etypes_negative_dstnode(self): .. code:: json - train_etypes_negative_dstnode: + train_etypes_negative_dstnode: - src_type,rel_type0,dst_type:negative_nid_field - src_type,rel_type1,dst_type:negative_nid_field - + or, for all edge types to use the same field to store the hard negatives, the format of the arguement is: - + .. code:: json train_etypes_negative_dstnode: @@ -2482,7 +2482,7 @@ def train_etypes_negative_dstnode(self): @property def num_train_hard_negatives(self): - """ Number of hard negatives to sample for each edge type during training. + """ Number of hard negatives to sample for each edge type during training. Default is None. For each edge type to have a number of hard negatives, @@ -2496,7 +2496,7 @@ def num_train_hard_negatives(self): or, for all edge types to have the same number of hard negatives, the format of the arguement is: - + .. code:: json num_train_hard_negatives: @@ -2533,7 +2533,7 @@ def num_train_hard_negatives(self): @property def eval_etypes_negative_dstnode(self): - """ The list of canonical edge types that have hard negative edges + """ The list of canonical edge types that have hard negative edges constructed by corrupting destination nodes during evaluation. For each edge type to use different fields to store the hard negatives, @@ -2541,13 +2541,13 @@ def eval_etypes_negative_dstnode(self): .. code:: json - eval_etypes_negative_dstnode: + eval_etypes_negative_dstnode: - src_type,rel_type0,dst_type:negative_nid_field - src_type,rel_type1,dst_type:negative_nid_field - + or, for all edge types to use the same field to store the hard negatives, the format of the arguement is: - + .. code:: json eval_etypes_negative_dstnode: @@ -2565,7 +2565,7 @@ def eval_etypes_negative_dstnode(self): @property def train_etype(self): - """ The list of canonical edge types that will be added as training target. + """ The list of canonical edge types that will be added as training target. If not provided, all edge types will be used as training target. A canonical edge type should be formatted as ``src_node_type,relation_type,dst_node_type``. """ @@ -2582,7 +2582,7 @@ def train_etype(self): @property def eval_etype(self): - """ The list of canonical edge types that will be added as evaluation target. + """ The list of canonical edge types that will be added as evaluation target. If not provided, all edge types will be used as evaluation target. A canonical edge type should be formatted as ``src_node_type,relation_type,dst_node_type``. """ @@ -2638,7 +2638,7 @@ def alpha(self): @property def class_loss_func(self): - """ Classification loss function. Builtin loss functions include + """ Classification loss function. Builtin loss functions include ``cross_entropy`` and ``focal``. Default is ``cross_entropy``. """ # pylint: disable=no-member @@ -2652,7 +2652,7 @@ def class_loss_func(self): @property def lp_loss_func(self): - """ Link prediction loss function. Builtin loss functions include + """ Link prediction loss function. Builtin loss functions include ``cross_entropy`` and ``contrastive``. Default is ``cross_entropy``. """ # pylint: disable=no-member diff --git a/python/graphstorm/config/config.py b/python/graphstorm/config/config.py index a31c12be53..a2cfda3bda 100644 --- a/python/graphstorm/config/config.py +++ b/python/graphstorm/config/config.py @@ -81,10 +81,14 @@ BUILTIN_LP_DOT_DECODER = "dot_product" BUILTIN_LP_DISTMULT_DECODER = "distmult" BUILTIN_LP_ROTATE_DECODER = "rotate" +BUILTIN_LP_TRANSE_L1_DECODER = "transe_l1" +BUILTIN_LP_TRANSE_L2_DECODER = "transe_l2" SUPPORTED_LP_DECODER = [BUILTIN_LP_DOT_DECODER, BUILTIN_LP_DISTMULT_DECODER, - BUILTIN_LP_ROTATE_DECODER] + BUILTIN_LP_ROTATE_DECODER, + BUILTIN_LP_TRANSE_L1_DECODER, + BUILTIN_LP_TRANSE_L2_DECODER] ################ Task info data classes ############################ def get_mttask_id(task_type, ntype=None, etype=None, label=None): diff --git a/python/graphstorm/eval/utils.py b/python/graphstorm/eval/utils.py index 40cd983b88..f71797a51f 100644 --- a/python/graphstorm/eval/utils.py +++ b/python/graphstorm/eval/utils.py @@ -247,6 +247,8 @@ def calc_rotate_pos_score(h_emb, t_emb, r_emb, rel_emb_init, gamma, device=None) The initial value used to bound the relation embedding initialization. gamma: float The gamma value used for shifting the optimization target. + device: th.device + Device to run the computation. Return ------ @@ -297,6 +299,8 @@ def calc_rotate_neg_head_score(heads, tails, r_emb, num_chunks, The initial value used to bound the relation embedding initialization. gamma: float The gamma value used for shifting the optimization target. + device: th.device + Device to run the computation. Return ------ @@ -349,6 +353,8 @@ def calc_rotate_neg_tail_score(heads, tails, r_emb, num_chunks, The initial value used to bound the relation embedding initialization. gamma: float The gamma value used for shifting the optimization target. + device: th.device + Device to run the computation. Return ------ @@ -378,6 +384,163 @@ def calc_rotate_neg_tail_score(heads, tails, r_emb, num_chunks, rotate_score = gamma - score.sum(-1) return rotate_score +def calc_transe_pos_score(h_emb, t_emb, r_emb, gamma, norm='l2', device=None): + r""" Calculate TransE Score for positive pairs + + Score function of TransE measures the angular distance between + head and tail elements. The angular distance is defined as: + + .. math:: + + d_r(h, t)= -\|h+r-t\| + + The TransE score function is defined as: + + .. math:: + + gamma - \|h+r-t\|^{frac{1}{2}} \text{or} gamma - \|h+r-t\| + + where gamma is a margin. + + For more details, please refer to + https://papers.nips.cc/paper_files/paper/2013/hash/1cecc7a77928ca8133fa24680a88d2f9-Abstract.html + or https://dglke.dgl.ai/doc/kg.html#transe. + + Parameters + ---------- + h_emb: th.Tensor + Head node embedding. + t_emb: th.Tensor + Tail node embedding. + r_emb: th.Tensor + Relation type embedding. + gamma: float + The gamma value used for shifting the optimization target. + norm: str + L1 or L2 norm on the angular distance. + device: th.device + Device to run the computation. + + Return + ------ + transe_score: th.Tensor + The TransE score. + """ + if device is not None: + r_emb = r_emb.to(device) + h_emb = h_emb.to(device) + t_emb = t_emb.to(device) + + score = (h_emb + r_emb) - t_emb + + if norm == 'l1': + transe_score = gamma - th.norm(score, p=1, dim=-1) + elif norm == 'l2': + transe_score = gamma - th.norm(score, p=2, dim=-1) + else: + raise ValueError("Unknown norm on the angular distance. Only support L1 and L2.") + return transe_score + +def calc_transe_neg_head_score(h_emb, t_emb, r_emb, num_chunks, + chunk_size, neg_sample_size, + gamma, norm='l2', + device=None): + """ Calculate TransE Score for negative pairs when head nodes are negative. + + Parameters + ---------- + h_emb: th.Tensor + Head node embedding. + t_emb: th.Tensor + Tail node embedding. + r_emb: th.Tensor + Relation type embedding. + num_chunks: int + Number of shared negative chunks. + chunk_size: int + Chunk size. + neg_sample_size: int + Number of negative samples for each positive node. + gamma: float + The gamma value used for shifting the optimization target. + norm: str + L1 or L2 norm on the angular distance. + device: th.device + Device to run the computation. + + Return + ------ + transe_score: th.Tensor + The TransE score. + """ + if device is not None: + r_emb = r_emb.to(device) + h_emb = h_emb.to(device) + t_emb = t_emb.to(device) + + hidden_dim = h_emb.shape[1] + h_emb = h_emb.reshape(num_chunks, neg_sample_size, hidden_dim) + t_emb = t_emb - r_emb + t_emb = t_emb.reshape(num_chunks, chunk_size, hidden_dim) + + if norm == 'l1': + transe_score = gamma - th.cdist(t_emb, h_emb, p=1) + elif norm == 'l2': + transe_score = gamma - th.cdist(t_emb, h_emb, p=2) + else: + raise ValueError("Unknown norm on the angular distance. Only support L1 and L2.") + return transe_score + +def calc_transe_neg_tail_score(h_emb, t_emb, r_emb, num_chunks, + chunk_size, neg_sample_size, + gamma, norm='l2', + device=None): + """ Calculate TransE Score for negative pairs when tail nodes are negative. + + Parameters + ---------- + h_emb: th.Tensor + Head node embedding. + t_emb: th.Tensor + Tail node embedding. + r_emb: th.Tensor + Relation type embedding. + num_chunks: int + Number of shared negative chunks. + chunk_size: int + Chunk size. + neg_sample_size: int + Number of negative samples for each positive node. + gamma: float + The gamma value used for shifting the optimization target. + norm: str + L1 or L2 norm on the angular distance. + device: th.device + Device to run the computation. + + Return + ------ + transe_score: th.Tensor + The TransE score. + """ + if device is not None: + r_emb = r_emb.to(device) + h_emb = h_emb.to(device) + t_emb = t_emb.to(device) + + hidden_dim = h_emb.shape[1] + h_emb = h_emb + r_emb + h_emb = h_emb.reshape(num_chunks, chunk_size, hidden_dim) + t_emb = t_emb.reshape(num_chunks, neg_sample_size, hidden_dim) + + if norm == 'l1': + transe_score = gamma - th.cdist(h_emb, t_emb, p=1) + elif norm == 'l2': + transe_score = gamma - th.cdist(h_emb, t_emb, p=2) + else: + raise ValueError("Unknown norm on the angular distance. Only support L1 and L2.") + return transe_score + def calc_ranking(pos_score, neg_score): """ Calculate ranking of positive scores among negative scores diff --git a/python/graphstorm/gsf.py b/python/graphstorm/gsf.py index 45a2afc525..723959847d 100644 --- a/python/graphstorm/gsf.py +++ b/python/graphstorm/gsf.py @@ -39,7 +39,9 @@ BUILTIN_TASK_RECONSTRUCT_NODE_FEAT) from .config import (BUILTIN_LP_DOT_DECODER, BUILTIN_LP_DISTMULT_DECODER, - BUILTIN_LP_ROTATE_DECODER) + BUILTIN_LP_ROTATE_DECODER, + BUILTIN_LP_TRANSE_L1_DECODER, + BUILTIN_LP_TRANSE_L2_DECODER) from .config import (BUILTIN_LP_LOSS_CROSS_ENTROPY, BUILTIN_LP_LOSS_CONTRASTIVELOSS, BUILTIN_CLASS_LOSS_CROSS_ENTROPY, @@ -78,7 +80,10 @@ LinkPredictWeightedDistMultDecoder, LinkPredictRotatEDecoder, LinkPredictContrastiveRotatEDecoder, - LinkPredictWeightedRotatEDecoder) + LinkPredictWeightedRotatEDecoder, + LinkPredictTransEDecoder, + LinkPredictContrastiveTransEDecoder, + LinkPredictWeightedTransEDecoder) from .dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER, BUILTIN_LP_JOINT_NEG_SAMPLER,BUILTIN_LP_INBATCH_JOINT_NEG_SAMPLER, BUILTIN_LP_LOCALUNIFORM_NEG_SAMPLER, @@ -724,6 +729,30 @@ def create_builtin_lp_decoder(g, decoder_input_dim, config, train_task): decoder_input_dim, gamma, config.lp_edge_weight_for_loss) + elif config.lp_decoder_type in [BUILTIN_LP_TRANSE_L1_DECODER, BUILTIN_LP_TRANSE_L2_DECODER]: + if get_rank() == 0: + logging.debug("Using TransE objective for supervision") + + # default gamma for TransE is 12. + gamma = config.gamma if config.gamma is not None else 12. + + score_norm = 'l1' if config.lp_decoder_type == BUILTIN_LP_TRANSE_L1_DECODER else 'l2' + if config.lp_edge_weight_for_loss is None: + decoder = LinkPredictContrastiveTransEDecoder(g.canonical_etypes, + decoder_input_dim, + gamma, + score_norm) \ + if config.lp_loss_func == BUILTIN_LP_LOSS_CONTRASTIVELOSS else \ + LinkPredictTransEDecoder(g.canonical_etypes, + decoder_input_dim, + gamma, + score_norm) + else: + decoder = LinkPredictWeightedTransEDecoder(g.canonical_etypes, + decoder_input_dim, + gamma, + score_norm, + config.lp_edge_weight_for_loss) else: raise Exception(f"Unknown link prediction decoder type {config.lp_decoder_type}") diff --git a/python/graphstorm/model/__init__.py b/python/graphstorm/model/__init__.py index f0c0792d74..eb2d8a603f 100644 --- a/python/graphstorm/model/__init__.py +++ b/python/graphstorm/model/__init__.py @@ -58,7 +58,10 @@ LinkPredictContrastiveDistMultDecoder, LinkPredictRotatEDecoder, LinkPredictContrastiveRotatEDecoder, - LinkPredictWeightedRotatEDecoder) + LinkPredictWeightedRotatEDecoder, + LinkPredictTransEDecoder, + LinkPredictContrastiveTransEDecoder, + LinkPredictWeightedTransEDecoder) from .gnn_encoder_base import GraphConvEncoder diff --git a/python/graphstorm/model/edge_decoder.py b/python/graphstorm/model/edge_decoder.py index 848137b019..0bf46e0687 100644 --- a/python/graphstorm/model/edge_decoder.py +++ b/python/graphstorm/model/edge_decoder.py @@ -28,11 +28,16 @@ BUILTIN_LP_JOINT_NEG_SAMPLER, BUILTIN_LP_FIXED_NEG_SAMPLER) -from ..eval.utils import calc_distmult_pos_score, calc_dot_pos_score, calc_rotate_pos_score +from ..eval.utils import (calc_distmult_pos_score, + calc_dot_pos_score, + calc_rotate_pos_score, + calc_transe_pos_score) from ..eval.utils import (calc_distmult_neg_head_score, calc_distmult_neg_tail_score, calc_rotate_neg_head_score, - calc_rotate_neg_tail_score) + calc_rotate_neg_tail_score, + calc_transe_neg_head_score, + calc_transe_neg_tail_score) # TODO(zhengda) we need to split it into classifier and regression. @@ -1504,6 +1509,399 @@ def forward(self, g, h, e_h): return scores +class LinkPredictTransEDecoder(LinkPredictMultiRelationLearnableDecoder): + r""" Decoder for link prediction using the TransE as the score function. + + Score function of TransE measures the angular distance between + head and tail elements. The angular distance is defined as: + + .. math:: + + d_r(h, t)= -\|h+r-t\| + + The TransE score function is defined as: + + .. math:: + + gamma - \|h+r-t\|^{frac{1}{2}} \text{or} gamma - \|h+r-t\| + + where gamma is a margin. + + For more details, please refer to + https://papers.nips.cc/paper_files/paper/2013/hash/1cecc7a77928ca8133fa24680a88d2f9-Abstract.html + or https://dglke.dgl.ai/doc/kg.html#transe. + + Parameters + ---------- + etypes: list of tuples + The canonical edge types of the graph in the format of + [(src_ntype1, etype1, dst_ntype1), ...] + h_dim: int + The input dimension size. It is the dimension for both source and destination + node embeddings. + gamma: float + The gamma value for model initialization and score function. Default: 12. + norm: str + L1 or L2 norm on the angular distance for TransE. Default: 'l2'. + """ + def __init__(self, + etypes, + h_dim, + gamma=12., + norm='l2'): + self.norm = norm + super(LinkPredictTransEDecoder, self).__init__(etypes, h_dim, gamma) + + def init_w_relation(self): + self._w_relation = nn.Embedding(self.num_rels, self.h_dim) + self.emb_init = self.gamma / self.h_dim + nn.init.uniform_(self._w_relation.weight, -self.emb_init, self.emb_init) + + # pylint: disable=unused-argument + def forward(self, g, h, e_h=None): + """ Link prediction decoder forward function using the TransE + as the score function. + + This computes the edge score on every edge type. + + Parameters + ---------- + g: DGLGraph + The input graph. + h: dict of Tensor + The input node embeddings in the format of {ntype: emb}. + e_h: dict of Tensor + The input edge embeddings in the format of {(src_ntype, etype, dst_ntype): emb}. + Not used, but reserved for future support of edge embeddings. Default: None. + + Returns + ------- + scores: dict of Tensor + The scores for edges of all edge types in the input graph in the format of + {(src_ntype, etype, dst_ntype): score}. + """ + with g.local_scope(): + scores = {} + + for canonical_etype in g.canonical_etypes: + if g.num_edges(canonical_etype) == 0: + continue # the block might contain empty edge types + + i = self.etype2rid[canonical_etype] + self.trained_rels[i] += 1 + rel_embedding = self._w_relation(th.tensor(i).to(self._w_relation.weight.device)) + rel_embedding = rel_embedding.unsqueeze(dim=0) + src_type, _, dest_type = canonical_etype + u, v = g.edges(etype=canonical_etype) + src_emb = h[src_type][u] + + dest_emb = h[dest_type][v] + scores_etype = calc_transe_pos_score(src_emb, + dest_emb, + rel_embedding, + self.gamma, + self.norm) + scores[canonical_etype] = scores_etype + + return scores + + def calc_test_scores(self, emb, pos_neg_tuple, neg_sample_type, device): + """ Compute scores for positive edges and negative edges. + + Parameters + ---------- + emb: dict of Tensor + Node embeddings in the format of {ntype: emb}. + pos_neg_tuple: dict of tuple + Positive and negative edges stored in a dict of tuple in the format of + {("src_ntype1", "etype1", "dst_ntype1" ): (pos_src_idx, neg_src_idx, + pos_dst_idx, neg_dst_idx)}. + + The `pos_src_idx` represents the postive source node indexes in the format + of Torch.Tensor. The `neg_src_idx` represents the negative source node indexes + in the format of Torch.Tensor. The `pos_dst_idx` represents the postive destination + node indexes in the format of Torch.Tensor. The `neg_dst_idx` represents the + negative destination node indexes in the format of Torch.Tensor. + + We define positive and negative edges as: + + * The positive edges: (pos_src_idx, pos_dst_idx) + * The negative edges: (pos_src_idx, neg_dst_idx) and + (neg_src_idx, pos_dst_idx) + + neg_sample_type: str + Describe how negative samples are sampled. There are two options: + + * ``Uniform``: For each positive edge, we sample K negative edges. + * ``Joint``: For one batch of positive edges, we sample K negative edges. + + device: th.device + Device used to compute scores. + + Returns + -------- + scores: dict of tuple + Return a dictionary of edge type's positive scores and negative scores in the format + of {(src_ntype, etype, dst_ntype): (pos_scores, neg_scores)}. + """ + assert isinstance(pos_neg_tuple, dict), \ + "TransE is only applicable to heterogeneous graphs." \ + "Otherwise please use dot product decoder." + scores = {} + for canonical_etype, (pos_src, neg_src, pos_dst, neg_dst) in pos_neg_tuple.items(): + utype, _, vtype = canonical_etype + # pos score + pos_src_emb = emb[utype][pos_src] + pos_dst_emb = emb[vtype][pos_dst] + rid = self.etype2rid[canonical_etype] + rel_embedding = self._w_relation( + th.tensor(rid).to(self._w_relation.weight.device)) + pos_scores = calc_transe_pos_score(pos_src_emb, + pos_dst_emb, + rel_embedding, + self.gamma, + self.norm, + device) + neg_scores = [] + + if neg_src is not None: + neg_src_emb = emb[utype][neg_src.reshape(-1,)] + if neg_sample_type in [BUILTIN_LP_UNIFORM_NEG_SAMPLER, + BUILTIN_LP_FIXED_NEG_SAMPLER]: + # fixed negative sample is similar to uniform negative sample + neg_src_emb = neg_src_emb.reshape(neg_src.shape[0], neg_src.shape[1], -1) + # uniform sampled negative samples + pos_dst_emb = pos_dst_emb.reshape( + pos_dst_emb.shape[0], 1, pos_dst_emb.shape[1]) + rel_embedding = rel_embedding.reshape( + 1, 1, rel_embedding.shape[-1]) + neg_score = calc_transe_pos_score(neg_src_emb, + pos_dst_emb, + rel_embedding, + self.gamma, + self.norm, + device) + elif neg_sample_type == BUILTIN_LP_JOINT_NEG_SAMPLER: + # joint sampled negative samples + assert len(pos_dst_emb.shape) == 2, \ + "For joint negative sampler, in evaluation" \ + "positive src/dst embs should in shape of" \ + "[eval_batch_size, dimension size]" + assert len(neg_src_emb.shape) == 2, \ + "For joint negative sampler, in evaluation" \ + "negative src/dst embs should in shape of " \ + "[number_of_negs, dimension size]" + neg_score = calc_transe_neg_head_score( + neg_src_emb, pos_dst_emb, rel_embedding, + 1, pos_dst_emb.shape[0], neg_src_emb.shape[0], + self.gamma, self.norm, + device) + # shape (batch_size, num_negs) + neg_score = neg_score.reshape(-1, neg_src_emb.shape[0]) + else: + assert False, f"Unknow negative sample type {neg_sample_type}" + assert len(neg_score.shape) == 2 + neg_scores.append(neg_score) + + if neg_dst is not None: + if neg_sample_type in [BUILTIN_LP_UNIFORM_NEG_SAMPLER, + BUILTIN_LP_FIXED_NEG_SAMPLER]: + # fixed negative sample is similar to uniform negative sample + neg_dst_emb = emb[vtype][neg_dst.reshape(-1,)] + neg_dst_emb = neg_dst_emb.reshape(neg_dst.shape[0], neg_dst.shape[1], -1) + # uniform sampled negative samples + pos_src_emb = pos_src_emb.reshape( + pos_src_emb.shape[0], 1, pos_src_emb.shape[1]) + rel_embedding = rel_embedding.reshape( + 1, 1, rel_embedding.shape[-1]) + neg_score = calc_transe_pos_score(pos_src_emb, + neg_dst_emb, + rel_embedding, + self.gamma, + self.norm, + device) + elif neg_sample_type == BUILTIN_LP_JOINT_NEG_SAMPLER: + neg_dst_emb = emb[vtype][neg_dst] + # joint sampled negative samples + assert len(pos_src_emb.shape) == 2, \ + "For joint negative sampler, in evaluation " \ + "positive src/dst embs should in shape of" \ + "[eval_batch_size, dimension size]" + assert len(neg_dst_emb.shape) == 2, \ + "For joint negative sampler, in evaluation" \ + "negative src/dst embs should in shape of " \ + "[number_of_negs, dimension size]" + neg_score = calc_transe_neg_tail_score( + pos_src_emb, neg_dst_emb, rel_embedding, + 1, pos_src_emb.shape[0], neg_dst_emb.shape[0], + self.gamma, self.norm, + device) + # shape (batch_size, num_negs) + neg_score = neg_score.reshape(-1, neg_dst_emb.shape[0]) + else: + assert False, f"Unknow negative sample type {neg_sample_type}" + assert len(neg_score.shape) == 2 + neg_scores.append(neg_score) + neg_scores = th.cat(neg_scores, dim=-1).detach() + # gloo with cpu will consume less GPU memory + neg_scores = neg_scores.cpu() \ + if is_distributed() and get_backend() == "gloo" \ + else neg_scores + + pos_scores = pos_scores.detach() + pos_scores = pos_scores.cpu() \ + if is_distributed() and get_backend() == "gloo" \ + else pos_scores + scores[canonical_etype] = (pos_scores, neg_scores) + + return scores + + @property + def in_dims(self): + """ Return the input dimension size, which is given in class initialization. + """ + return self.h_dim + + @property + def out_dims(self): + """ Return ``1`` for link prediction tasks. + """ + return 1 + +class LinkPredictContrastiveTransEDecoder(LinkPredictTransEDecoder): + """ Decoder for link prediction designed for contrastive loss + using the TransE as the score function. + + Note: + ------ + This class is specifically implemented for contrastive loss. But + it could also be used by other pair-wise loss functions for link + prediction tasks. + + Parameters + ---------- + etypes: list of tuples + The canonical edge types of the graph in the format of + [(src_ntype1, etype1, dst_ntype1), ...] + h_dim: int + The input dimension size. It is the dimension for both source and destination + node embeddings. + gamma: float + The gamma value for model weight initialization. Default: 4. + """ + + # pylint: disable=unused-argument + def forward(self, g, h, e_h=None): + with g.local_scope(): + scores = {} + + for canonical_etype in g.canonical_etypes: + if g.num_edges(canonical_etype) == 0: + continue # the block might contain empty edge types + + i = self.etype2rid[canonical_etype] + self.trained_rels[i] += 1 + rel_embedding = self._w_relation(th.tensor(i).to(self._w_relation.weight.device)) + rel_embedding = rel_embedding.unsqueeze(dim=0) + src_type, _, dest_type = canonical_etype + u, v = g.edges(etype=canonical_etype) + # Sort edges according to source node ids + # The same function is invoked by computing both pos scores + # and neg scores, by sorting edges according to source nids + # the output scores of pos_score and neg_score are compatible. + # + # For example: + # + # pos pairs | neg pairs + # (10, 20) | (10, 3), (10, 1), (10, 0), (10, 22) + # (13, 6) | (13, 3), (13, 1), (13, 0), (13, 22) + # (29, 8) | (29, 3), (29, 1), (29, 0), (29, 22) + # + # TODO: use stable to keep the order of negatives. This may not + # be necessary + u_sort_idx = th.argsort(u, stable=True) + u = u[u_sort_idx] + v = v[u_sort_idx] + src_emb = h[src_type][u] + dest_emb = h[dest_type][v] + scores_etype = calc_transe_pos_score(src_emb, + dest_emb, + rel_embedding, + self.gamma, + self.norm) + scores[canonical_etype] = scores_etype + + return scores + +class LinkPredictWeightedTransEDecoder(LinkPredictTransEDecoder): + """Link prediction decoder with the score function of TransE + with edge weight. + + When computing loss, edge weights are used to adjust the loss. + + Parameters + ---------- + etypes: list of tuples + The canonical edge types of the graph in the format of + [(src_ntype1, etype1, dst_ntype1), ...] + h_dim: int + The input dimension size. It is the dimension for both source and destination + node embeddings. + gamma: float + The gamma value for model weight initialization. Default: 12. + norm: str + L1 or L2 norm on the angular distance for TransE. Default: 'l2'. + edge_weight_fields: dict of str + The edge feature field(s) storing the edge weights. + """ + def __init__(self, etypes, h_dim, gamma=12., norm='l2', edge_weight_fields=None): + self.norm = norm + self._edge_weight_fields = edge_weight_fields + super(LinkPredictWeightedTransEDecoder, self).__init__(etypes, h_dim, gamma) + + # pylint: disable=signature-differs + def forward(self, g, h, e_h): + """Forward function. + + This computes the TransE score on every edge type. + """ + with g.local_scope(): + scores = {} + + for canonical_etype in g.canonical_etypes: + if g.num_edges(canonical_etype) == 0: + continue # the block might contain empty edge types + + i = self.etype2rid[canonical_etype] + self.trained_rels[i] += 1 + rel_embedding = self._w_relation(th.tensor(i).to(self._w_relation.weight.device)) + rel_embedding = rel_embedding.unsqueeze(dim=0) + src_type, _, dest_type = canonical_etype + u, v = g.edges(etype=canonical_etype) + src_emb = h[src_type][u] + + dest_emb = h[dest_type][v] + scores_etype = calc_transe_pos_score(src_emb, + dest_emb, + rel_embedding, + self.gamma, + self.norm) + + if e_h is not None and canonical_etype in e_h.keys(): + weight = e_h[canonical_etype] + assert th.is_tensor(weight), \ + "The edge weight for Link prediction must be a torch tensor." \ + "LinkPredictWeightedTransEDecoder only accepts a 1D edge " \ + "feature as edge weight." + weight = weight.flatten() + else: + # current etype does not have weight + weight = th.ones((g.num_edges(canonical_etype),), + device=scores_etype.device) + scores[canonical_etype] = (scores_etype, + weight) + + return scores class LinkPredictDistMultDecoder(LinkPredictMultiRelationLearnableDecoder): """ Decoder for link prediction using the DistMult as the score function. diff --git a/tests/end2end-tests/graphbolt-gs-integration/graphbolt-graph-construction.sh b/tests/end2end-tests/graphbolt-gs-integration/graphbolt-graph-construction.sh index fd6596f9f0..4e32aed053 100644 --- a/tests/end2end-tests/graphbolt-gs-integration/graphbolt-graph-construction.sh +++ b/tests/end2end-tests/graphbolt-gs-integration/graphbolt-graph-construction.sh @@ -6,7 +6,7 @@ usage() { cat </dev/null && pwd -P) + +usage() { + cat <&2 -e "${1-}" +} + +# Parse command-line arguments +parse_params() { + # Default values for input and output paths + INPUT_PATH="/data/ml-100k/" + OUTPUT_PATH="/tmp/gb-training-e2e-tests" + + while :; do + case "${1-}" in + -h | --help) usage ;; + -x | --verbose) set -x ;; + -i | --ml100k-path) + INPUT_PATH="${2-}" + shift + ;; + -o | --output-path) + OUTPUT_PATH="${2-}" + shift + ;; + -?*) die "Unknown option: $1" ;; + *) break ;; + esac + shift + done + + return 0 +} + +cleanup() { + trap - SIGINT SIGTERM ERR EXIT + # script cleanup here + if [[ -d "${OUTPUT_PATH}" ]]; then + echo "Cleaning up ${OUTPUT_PATH}" + rm -rf "${OUTPUT_PATH}" + fi +} + +fdir_exists() { + # Take two args: first should be f or d, for file or directory + # second is the path to check + + if [ "$1" == "f" ] + then + if [ ! -f "$2" ] + then + msg "$2 must exist" + exit 1 + fi + elif [ "$1" == "d" ] + then + if [ ! -d "$2" ] + then + msg "$2 must exist" + exit 1 + fi + else + msg "First arg to fdir_exists must be f or d" + exit 1 + fi +} + +parse_params "$@" + +GS_HOME=$(pwd) + +mkdir -p "$OUTPUT_PATH" +cp -R "$INPUT_PATH" "$OUTPUT_PATH" + +# Ensure ip_list.txt exists and self-ssh works +rm "$OUTPUT_PATH/ip_list.txt" &> /dev/null || true +echo "127.0.0.1" > "$OUTPUT_PATH/ip_list.txt" +ssh -o PreferredAuthentications=publickey -o StrictHostKeyChecking=no \ + -p 2222 127.0.0.1 /bin/true || service ssh restart + +# Generate 1P LP data +msg "**************GraphBolt Link Prediction data generation **************" +LP_INPUT_1P="${OUTPUT_PATH}/graphbolt-gconstruct-lp-1p" +python3 -m graphstorm.gconstruct.construct_graph \ + --add-reverse-edges \ + --conf-file $GS_HOME/tests/end2end-tests/data_gen/movielens_lp.json \ + --graph-name ml-lp \ + --num-parts 1 \ + --num-processes 1 \ + --output-dir "$LP_INPUT_1P" \ + --part-method random \ + --use-graphbolt "true" + +LP_OUTPUT="$OUTPUT_PATH/gb-lp" +msg "**************GraphBolt Link Prediction training. dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: mini-batch, negative_sampler: joint, exclude_training_targets: false" +python3 -m graphstorm.run.gs_link_prediction \ + --cf $GS_HOME/training_scripts/gsgnn_lp/ml_lp.yaml \ + --eval-frequency 300 \ + --ip-config "$OUTPUT_PATH/ip_list.txt" \ + --num-epochs 1 \ + --num-samplers 0 \ + --num-servers 1 \ + --num-trainers 1 \ + --part-config "$LP_INPUT_1P/ml-lp.json" \ + --save-model-path "$LP_OUTPUT/model" \ + --ssh-port 2222 \ + --use-graphbolt true + +# Ensure model files were saved +fdir_exists f "$LP_OUTPUT/model/epoch-0/model.bin" +fdir_exists f "$LP_OUTPUT/model/epoch-0/optimizers.bin" + +msg " **************GraphBolt Link Prediction embedding generation **************" + +python3 -m graphstorm.run.gs_gen_node_embedding \ + --cf $GS_HOME/training_scripts/gsgnn_lp/ml_lp.yaml \ + --eval-frequency 300 \ + --inference \ + --ip-config "$OUTPUT_PATH/ip_list.txt" \ + --num-epochs 1 \ + --num-samplers 0 \ + --num-servers 1 \ + --num-trainers 1 \ + --part-config "$LP_INPUT_1P/ml-lp.json" \ + --restore-model-path "$LP_OUTPUT/model/epoch-0" \ + --save-embed-path "$LP_OUTPUT/embeddings" \ + --ssh-port 2222 \ + --use-graphbolt true + +# Ensure embeddings were created +fdir_exists d "$LP_OUTPUT/embeddings/movie" +fdir_exists d "$LP_OUTPUT/embeddings/user" + +LP_OUTPUT="$OUTPUT_PATH/gb-lp-inbatch_joint" +msg "**************GraphBolt Link Prediction training. dataset: Movielens, RGCN layer 1, inference: mini-batch, negative_sampler: inbatch_joint, exclude_training_targets: true" +python3 -m graphstorm.run.gs_link_prediction \ + --cf $GS_HOME/training_scripts/gsgnn_lp/ml_lp.yaml \ + --eval-frequency 300 \ + --exclude-training-targets True \ + --ip-config "$OUTPUT_PATH/ip_list.txt" \ + --num-epochs 1 \ + --num-samplers 0 \ + --num-servers 1 \ + --num-trainers 1 \ + --part-config "$LP_INPUT_1P/ml-lp.json" \ + --reverse-edge-types-map user,rating,rating-rev,movie \ + --save-model-path "$LP_OUTPUT/model" \ + --ssh-port 2222 \ + --train-negative-sampler inbatch_joint \ + --use-graphbolt true + +# Ensure model files were saved +fdir_exists f "$LP_OUTPUT/model/epoch-0/model.bin" +fdir_exists f "$LP_OUTPUT/model/epoch-0/optimizers.bin" + +LP_OUTPUT="$OUTPUT_PATH/gb-lp-all_etype_uniform" +msg "**************GraphBolt Link Prediction training. dataset: Movielens, RGCN layer 1, inference: mini-batch, negative_sampler: all_etype_uniform, exclude_training_targets: true" +python3 -m graphstorm.run.gs_link_prediction \ + --cf $GS_HOME/training_scripts/gsgnn_lp/ml_lp.yaml \ + --eval-frequency 300 \ + --exclude-training-targets True \ + --ip-config "$OUTPUT_PATH/ip_list.txt" \ + --num-epochs 1 \ + --num-samplers 0 \ + --num-servers 1 \ + --num-trainers 1 \ + --part-config "$LP_INPUT_1P/ml-lp.json" \ + --reverse-edge-types-map user,rating,rating-rev,movie \ + --save-model-path "$LP_OUTPUT/model" \ + --ssh-port 2222 \ + --train-negative-sampler all_etype_uniform \ + --use-graphbolt true + +# Ensure model file were saved +fdir_exists f "$LP_OUTPUT/model/epoch-0/model.bin" +fdir_exists f "$LP_OUTPUT/model/epoch-0/optimizers.bin" + + +# Generate 1P NC data +msg "************** GraphBolt Node Classification data generation. **************" +NC_INPUT_1P="${OUTPUT_PATH}/graphbolt-gconstruct-nc-1p" +python3 -m graphstorm.gconstruct.construct_graph \ + --add-reverse-edges \ + --conf-file $GS_HOME/tests/end2end-tests/data_gen/movielens.json \ + --graph-name ml-nc \ + --num-parts 1 \ + --num-processes 1 \ + --output-dir "$NC_INPUT_1P" \ + --part-method random \ + --use-graphbolt "true" + + +msg "************** GraphBolt Node Classification training. dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: mini-batch" +NC_OUTPUT="$OUTPUT_PATH/gb-nc" +python3 -m graphstorm.run.gs_node_classification \ + --cf $GS_HOME/training_scripts/gsgnn_np/ml_nc.yaml \ + --eval-frequency 300 \ + --ip-config "$OUTPUT_PATH/ip_list.txt" \ + --num-epochs 1 \ + --num-samplers 0 \ + --num-servers 1 \ + --num-trainers 1 \ + --part-config "$NC_INPUT_1P/ml-nc.json" \ + --save-model-path "$NC_OUTPUT/model" \ + --ssh-port 2222 \ + --use-graphbolt true + +# Ensure model files were saved +fdir_exists f "$NC_OUTPUT/model/epoch-0/model.bin" +fdir_exists f "$NC_OUTPUT/model/epoch-0/optimizers.bin" + +msg "************** GraphBolt Node Classification inference. **************" +python3 -m graphstorm.run.gs_node_classification \ + --cf $GS_HOME/training_scripts/gsgnn_np/ml_nc.yaml \ + --eval-frequency 300 \ + --inference \ + --ip-config "$OUTPUT_PATH/ip_list.txt" \ + --no-validation true \ + --num-epochs 1 \ + --num-samplers 0 \ + --num-servers 1 \ + --num-trainers 1 \ + --part-config "$NC_INPUT_1P/ml-nc.json" \ + --restore-model-path "$NC_OUTPUT/model/epoch-0" \ + --save-embed-path "$NC_OUTPUT/embeddings" \ + --save-prediction-path "$NC_OUTPUT/predictions" \ + --ssh-port 2222 \ + --use-graphbolt true \ + --use-mini-batch-infer false + +# Ensure embeddings and predictions were created +fdir_exists d "$NC_OUTPUT/embeddings" +fdir_exists d "$NC_OUTPUT/predictions" + +msg "********* GraphBolt training and inference tests passed *********" diff --git a/tests/end2end-tests/graphstorm-lp/mgpu_test.sh b/tests/end2end-tests/graphstorm-lp/mgpu_test.sh index cbfeac95b0..2e4a9cf780 100644 --- a/tests/end2end-tests/graphstorm-lp/mgpu_test.sh +++ b/tests/end2end-tests/graphstorm-lp/mgpu_test.sh @@ -851,4 +851,98 @@ python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scrip error_and_exit $? rm -fr /data/gsgnn_lp_ml_rotate/* +echo "**************dataset: Movielens, RGCN layer 2, node feat: fixed HF BERT & sparse embed, BERT nodes: movie, inference: full-graph, negative_sampler: joint, decoder: TransE_L1, exclude_training_targets: true, save model" +python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --use-node-embeddings true --eval-batch-size 1024 --save-model-path /data/gsgnn_lp_ml_transe_l1/ --topk-model-to-save 1 --save-model-frequency 1000 --save-embed-path /data/gsgnn_lp_ml_transe_l1/emb/ --lp-decoder-type transe_l1 --train-etype user,rating,movie movie,rating-rev,user --logging-file /tmp/train_log.txt --preserve-input True + +error_and_exit $? + +cnt=$(ls -l /data/gsgnn_lp_ml_transe_l1/ | grep epoch | wc -l) +if test $cnt != 1 +then + echo "The number of save models $cnt is not equal to the specified topk 1" + exit -1 +fi + +best_epoch_transe_l1=$(grep "successfully save the model to" /tmp/train_log.txt | tail -1 | tr -d '\n' | tail -c 1) +echo "The best model is saved in epoch $best_epoch_transe_l1" + +echo "**************dataset: Movielens, do inference on saved model, decoder: TransE_L1" +python3 -m graphstorm.run.gs_link_prediction --inference --workspace $GS_HOME/inference_scripts/lp_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp_infer.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --use-node-embeddings true --eval-batch-size 1024 --save-embed-path /data/gsgnn_lp_ml_transe_l1/infer-emb/ --restore-model-path /data/gsgnn_lp_ml_transe_l1/epoch-$best_epoch_transe_l1/ --lp-decoder-type transe_l1 --no-validation False --train-etype user,rating,movie movie,rating-rev,user --preserve-input True + +error_and_exit $? + +python3 $GS_HOME/tests/end2end-tests/check_infer.py --train-embout /data/gsgnn_lp_ml_transe_l1/emb/ --infer-embout /data/gsgnn_lp_ml_transe_l1/infer-emb/ --link-prediction + +error_and_exit $? + +cnt=$(ls /data/gsgnn_lp_ml_transe_l1/infer-emb/ | grep rel_emb.pt | wc -l) +if test $cnt -ne 1 +then + echo "TransE_L1 inference outputs edge embedding" + exit -1 +fi + +cnt=$(ls /data/gsgnn_lp_ml_transe_l1/infer-emb/ | grep relation2id_map.json | wc -l) +if test $cnt -ne 1 +then + echo "TransE_L1 inference outputs edge embedding" + exit -1 +fi + +rm /tmp/train_log.txt +rm -fr /data/gsgnn_lp_ml_transe_l1/* + +echo "**************dataset: Movielens, two training edges but only one with edge weight for loss, score func: TransE_L1 ***********" +python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --eval-batch-size 1024 --topk-model-to-save 1 --save-model-frequency 1000 --train-etype user,rating,movie movie,rating-rev,user --lp-edge-weight-for-loss user,rating,movie:rate --lp-decoder-type transe_l1 --save-model-path /data/gsgnn_lp_ml_transe_l1/ + +error_and_exit $? +rm -fr /data/gsgnn_lp_ml_transe_l1/* + +echo "**************dataset: Movielens, RGCN layer 2, node feat: fixed HF BERT & sparse embed, BERT nodes: movie, inference: full-graph, negative_sampler: joint, decoder: TransE_L2, exclude_training_targets: true, save model" +python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --use-node-embeddings true --eval-batch-size 1024 --save-model-path /data/gsgnn_lp_ml_transe_l2/ --topk-model-to-save 1 --save-model-frequency 1000 --save-embed-path /data/gsgnn_lp_ml_transe_l2/emb/ --lp-decoder-type transe_l2 --train-etype user,rating,movie movie,rating-rev,user --logging-file /tmp/train_log.txt --preserve-input True + +error_and_exit $? + +cnt=$(ls -l /data/gsgnn_lp_ml_transe_l2/ | grep epoch | wc -l) +if test $cnt != 1 +then + echo "The number of save models $cnt is not equal to the specified topk 1" + exit -1 +fi + +best_epoch_transe_l2=$(grep "successfully save the model to" /tmp/train_log.txt | tail -1 | tr -d '\n' | tail -c 1) +echo "The best model is saved in epoch $best_epoch_transe_l2" + +echo "**************dataset: Movielens, do inference on saved model, decoder: TransE_L2" +python3 -m graphstorm.run.gs_link_prediction --inference --workspace $GS_HOME/inference_scripts/lp_infer --num-trainers $NUM_INFO_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp_infer.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --use-node-embeddings true --eval-batch-size 1024 --save-embed-path /data/gsgnn_lp_ml_transe_l2/infer-emb/ --restore-model-path /data/gsgnn_lp_ml_transe_l2/epoch-$best_epoch_transe_l2/ --lp-decoder-type transe_l2 --no-validation False --train-etype user,rating,movie movie,rating-rev,user --preserve-input True + +error_and_exit $? + +python3 $GS_HOME/tests/end2end-tests/check_infer.py --train-embout /data/gsgnn_lp_ml_transe_l2/emb/ --infer-embout /data/gsgnn_lp_ml_transe_l2/infer-emb/ --link-prediction + +error_and_exit $? + +cnt=$(ls /data/gsgnn_lp_ml_transe_l2/infer-emb/ | grep rel_emb.pt | wc -l) +if test $cnt -ne 1 +then + echo "TransE_L2 inference outputs edge embedding" + exit -1 +fi + +cnt=$(ls /data/gsgnn_lp_ml_transe_l2/infer-emb/ | grep relation2id_map.json | wc -l) +if test $cnt -ne 1 +then + echo "TransE_L2 inference outputs edge embedding" + exit -1 +fi + +rm /tmp/train_log.txt +rm -fr /data/gsgnn_lp_ml_transe_l2/* + +echo "**************dataset: Movielens, two training edges but only one with edge weight for loss, score func: TransE_L2 ***********" +python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --eval-batch-size 1024 --topk-model-to-save 1 --save-model-frequency 1000 --train-etype user,rating,movie movie,rating-rev,user --lp-edge-weight-for-loss user,rating,movie:rate --lp-decoder-type transe_l2 --save-model-path /data/gsgnn_lp_ml_transe_l2/ + +error_and_exit $? +rm -fr /data/gsgnn_lp_ml_transe_l2/* + rm -fr /tmp/* diff --git a/tests/end2end-tests/graphstorm-lp/test.sh b/tests/end2end-tests/graphstorm-lp/test.sh index e74978c4b1..bef48dd72d 100644 --- a/tests/end2end-tests/graphstorm-lp/test.sh +++ b/tests/end2end-tests/graphstorm-lp/test.sh @@ -156,9 +156,18 @@ python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scrip error_and_exit $? - echo "**************dataset: Movielens, RGCN layer 2, node feat: fixed HF BERT & sparse embed, BERT nodes: movie, inference: full-graph, negative_sampler: joint, decoder: RotatE, exclude_training_targets: false, contrastive loss" -python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --lp-decoder-type rotate --train-etype user,rating,movie movie,rating-rev,user --num-epochs 1 --eval-frequency 300 --contrastive-loss-temperature 0.1 --lp-loss-func contrastive --logging-level debug +python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --lp-decoder-type rotate --train-etype user,rating,movie movie,rating-rev,user --num-epochs 1 --eval-frequency 300 --contrastive-loss-temperature 0.1 --lp-loss-func contrastive + +error_and_exit $? + +echo "**************dataset: Movielens, RGCN layer 2, node feat: fixed HF BERT & sparse embed, BERT nodes: movie, inference: full-graph, negative_sampler: joint, decoder: TransE_l2, exclude_training_targets: false, contrastive loss" +python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --lp-decoder-type transe_l2 --train-etype user,rating,movie movie,rating-rev,user --num-epochs 1 --eval-frequency 300 --contrastive-loss-temperature 1 --lp-loss-func contrastive + +error_and_exit $? + +echo "**************dataset: Movielens, RGCN layer 2, node feat: fixed HF BERT & sparse embed, BERT nodes: movie, inference: full-graph, negative_sampler: joint, decoder: TransE_l1, exclude_training_targets: false, contrastive loss" +python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --lp-decoder-type transe_l1 --train-etype user,rating,movie movie,rating-rev,user --num-epochs 1 --eval-frequency 300 --contrastive-loss-temperature 1 --lp-loss-func contrastive error_and_exit $? @@ -167,8 +176,18 @@ python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scrip error_and_exit $? -echo "**************dataset: Movielens, RGCN layer 2, node feat: fixed HF BERT & sparse embed, BERT nodes: movie, inference: full-graph, negative_sampler: joint, decoder: RotatE, exclude_training_targets: false, adversarial cross entropy loss" -python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --lp-decoder-type rotate --train-etype user,rating,movie movie,rating-rev,user --num-epochs 1 --eval-frequency 300 --lp-loss-func cross_entropy --adversarial-temperature 0.1 --lp-embed-normalizer l2_norm --logging-level debug +echo "**************dataset: Movielens, RGCN layer 2, node feat: fixed HF BERT & sparse embed, BERT nodes: movie, inference: full-graph, negative_sampler: joint, decoder: RotatE, exclude_training_targets: false, lp-embed-normalizer: l2_norm, adversarial cross entropy loss" +python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --lp-decoder-type rotate --train-etype user,rating,movie movie,rating-rev,user --num-epochs 1 --eval-frequency 300 --lp-loss-func cross_entropy --adversarial-temperature 0.1 --lp-embed-normalizer l2_norm + +error_and_exit $? + +echo "**************dataset: Movielens, RGCN layer 2, node feat: fixed HF BERT & sparse embed, BERT nodes: movie, inference: full-graph, negative_sampler: joint, decoder: TransE_L2, exclude_training_targets: false, adversarial cross entropy loss" +python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --lp-decoder-type transe_l2 --train-etype user,rating,movie movie,rating-rev,user --num-epochs 1 --eval-frequency 300 --lp-loss-func cross_entropy --adversarial-temperature 0.1 + +error_and_exit $? + +echo "**************dataset: Movielens, RGCN layer 2, node feat: fixed HF BERT & sparse embed, BERT nodes: movie, inference: full-graph, negative_sampler: joint, decoder: TransE_L1, exclude_training_targets: false, lp-embed-normalizer: l2_norm, adversarial cross entropy loss" +python3 -m graphstorm.run.gs_link_prediction --workspace $GS_HOME/training_scripts/gsgnn_lp --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lp_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_lp.yaml --fanout '10,15' --num-layers 2 --use-mini-batch-infer false --lp-decoder-type transe_l1 --train-etype user,rating,movie movie,rating-rev,user --num-epochs 1 --eval-frequency 300 --lp-loss-func cross_entropy --adversarial-temperature 0.1 --lp-embed-normalizer l2_norm error_and_exit $? diff --git a/tests/unit-tests/test_config.py b/tests/unit-tests/test_config.py index 171becc7aa..6a282a02bc 100644 --- a/tests/unit-tests/test_config.py +++ b/tests/unit-tests/test_config.py @@ -40,8 +40,11 @@ from graphstorm.dataloading import BUILTIN_LP_UNIFORM_NEG_SAMPLER from graphstorm.dataloading import BUILTIN_LP_JOINT_NEG_SAMPLER from graphstorm.config.config import GRAPHSTORM_SAGEMAKER_TASK_TRACKER -from graphstorm.config import BUILTIN_LP_DOT_DECODER -from graphstorm.config import BUILTIN_LP_DISTMULT_DECODER +from graphstorm.config import (BUILTIN_LP_DOT_DECODER, + BUILTIN_LP_DISTMULT_DECODER, + BUILTIN_LP_ROTATE_DECODER, + BUILTIN_LP_TRANSE_L1_DECODER, + BUILTIN_LP_TRANSE_L2_DECODER) from graphstorm.config.config import LINK_PREDICTION_MAJOR_EVAL_ETYPE_ALL def check_failure(config, field): @@ -1001,7 +1004,7 @@ def create_lp_config(tmp_path, file_name): "exclude_training_targets": "error", "reverse_edge_types_map": "query,exactmatch,rev-exactmatch,asin", "lp_loss_func": "unknown", - "lp_decoder_type": "transe", + "lp_decoder_type": "transr", "lp_edge_weight_for_loss": ["query,click,asin:weight1"], "model_select_etype": "fail" } @@ -1049,6 +1052,30 @@ def create_lp_config(tmp_path, file_name): with open(os.path.join(tmp_path, file_name+"_adv_temp_fail.yaml"), "w") as f: yaml.dump(yaml_object, f) + yaml_object["gsf"]["link_prediction"] = { + "lp_decoder_type": "rotate", + } + with open(os.path.join(tmp_path, file_name+"_rotate.yaml"), "w") as f: + yaml.dump(yaml_object, f) + + yaml_object["gsf"]["link_prediction"] = { + "lp_decoder_type": "transe_l1", + } + with open(os.path.join(tmp_path, file_name + "_transe_l1.yaml"), "w") as f: + yaml.dump(yaml_object, f) + + yaml_object["gsf"]["link_prediction"] = { + "lp_decoder_type": "transe_l2", + } + with open(os.path.join(tmp_path, file_name + "_transe_l2.yaml"), "w") as f: + yaml.dump(yaml_object, f) + + yaml_object["gsf"]["link_prediction"] = { + "lp_decoder_type": "transe_l3", + } + with open(os.path.join(tmp_path, file_name + "_fail_transe.yaml"), "w") as f: + yaml.dump(yaml_object, f) + def test_lp_info(): with tempfile.TemporaryDirectory() as tmpdirname: create_lp_config(Path(tmpdirname), 'lp_test') @@ -1154,6 +1181,30 @@ def test_lp_info(): assert config.lp_loss_func == BUILTIN_LP_LOSS_CONTRASTIVELOSS check_failure(config, "adversarial_temperature") + args = Namespace( + yaml_config_file=os.path.join(Path(tmpdirname), 'lp_test_rotate.yaml'), + local_rank=0) + config = GSConfig(args) + assert config.lp_decoder_type == BUILTIN_LP_ROTATE_DECODER + + args = Namespace( + yaml_config_file=os.path.join(Path(tmpdirname), 'lp_test_transe_l1.yaml'), + local_rank=0) + config = GSConfig(args) + assert config.lp_decoder_type == BUILTIN_LP_TRANSE_L1_DECODER + + args = Namespace( + yaml_config_file=os.path.join(Path(tmpdirname), 'lp_test_transe_l2.yaml'), + local_rank=0) + config = GSConfig(args) + assert config.lp_decoder_type == BUILTIN_LP_TRANSE_L2_DECODER + + args = Namespace( + yaml_config_file=os.path.join(Path(tmpdirname), 'lp_test_fail_transe.yaml'), + local_rank=0) + config = GSConfig(args) + check_failure(config, "lp_decoder_type") + def create_gnn_config(tmp_path, file_name): yaml_object = create_dummpy_config_obj() yaml_object["gsf"]["link_prediction"] = {} diff --git a/tests/unit-tests/test_decoder.py b/tests/unit-tests/test_decoder.py index cf91d7d5f0..6ffe914eb7 100644 --- a/tests/unit-tests/test_decoder.py +++ b/tests/unit-tests/test_decoder.py @@ -27,11 +27,14 @@ LinkPredictContrastiveDotDecoder, LinkPredictContrastiveDistMultDecoder, LinkPredictRotatEDecoder, - LinkPredictContrastiveRotatEDecoder) + LinkPredictContrastiveRotatEDecoder, + LinkPredictTransEDecoder, + LinkPredictContrastiveTransEDecoder) from graphstorm.dataloading import (BUILTIN_LP_UNIFORM_NEG_SAMPLER, BUILTIN_LP_JOINT_NEG_SAMPLER) from graphstorm.eval.utils import (calc_distmult_pos_score, - calc_rotate_pos_score) + calc_rotate_pos_score, + calc_transe_pos_score) from graphstorm.eval.utils import calc_dot_pos_score from graphstorm.eval.utils import calc_ranking @@ -60,6 +63,9 @@ def _check_ranking(score, pos_score, neg_scores, etype, num_neg, batch_size): p_score = score[etype][0].cpu() n_score = score[etype][1].cpu() + assert_almost_equal(p_score, pos_score.cpu().numpy(), decimal=5) + assert_almost_equal(n_score, neg_scores.cpu().numpy(), decimal=5) + test_ranking = calc_ranking(p_score, n_score) ranking = calc_ranking(pos_score.cpu(), neg_scores.cpu()) @@ -234,6 +240,174 @@ def gen_edge_pairs(): neg_scores = th.stack(neg_scores) _check_ranking(score, pos_score, neg_scores, etypes[0], num_neg*2, pos_src.shape[0]) +def check_calc_test_scores_transe_uniform_neg(decoder, etypes, h_dim, num_pos, num_neg, device): + neg_sample_type = BUILTIN_LP_UNIFORM_NEG_SAMPLER + emb = { + 'a': th.rand((128, h_dim)), + 'b': th.rand((128, h_dim)), + } + + def gen_edge_pairs(): + pos_src = th.randint(100, (num_pos,)) + pos_dst = th.randint(100, (num_pos,)) + neg_src = th.randint(128, (num_pos, num_neg)) + neg_dst = th.randint(128, (num_pos, num_neg)) + return (pos_src, neg_src, pos_dst, neg_dst) + + with th.no_grad(): + pos_neg_tuple = { + etypes[0]: gen_edge_pairs(), + etypes[1]: gen_edge_pairs(), + } + pos_src, neg_src, pos_dst, neg_dst = pos_neg_tuple[etypes[0]] + pos_neg_tuple[etypes[0]] = (pos_src, None, pos_dst, neg_dst) + pos_src, neg_src, pos_dst, neg_dst = pos_neg_tuple[etypes[1]] + pos_neg_tuple[etypes[1]] = (pos_src, neg_src, pos_dst, None) + + score = decoder.calc_test_scores(emb, pos_neg_tuple, neg_sample_type, device) + pos_src, _, pos_dst, neg_dst = pos_neg_tuple[etypes[0]] + pos_src_emb = emb[etypes[0][0]][pos_src] + pos_dst_emb = emb[etypes[0][2]][pos_dst] + rel_emb = decoder.get_relemb(etypes[0]) + pos_score = calc_transe_pos_score(pos_src_emb, pos_dst_emb, rel_emb, + decoder.gamma, decoder.norm) + neg_scores = [] + for i in range(pos_src.shape[0]): + pse = pos_src_emb[i] + neg_dst_emb = emb[etypes[0][2]][neg_dst[i]] + ns = calc_transe_pos_score(pse, neg_dst_emb, rel_emb, + decoder.gamma, decoder.norm) + neg_scores.append(ns) + neg_scores = th.stack(neg_scores) + _check_scores(score, pos_score, neg_scores, etypes[0], num_neg, pos_src.shape[0]) + + pos_src, neg_src, pos_dst, _ = pos_neg_tuple[etypes[1]] + pos_src_emb = emb[etypes[1][0]][pos_src] + pos_dst_emb = emb[etypes[1][2]][pos_dst] + rel_emb = decoder.get_relemb(etypes[1]) + pos_score = calc_transe_pos_score(pos_src_emb, pos_dst_emb, rel_emb, + decoder.gamma, decoder.norm) + neg_scores = [] + for i in range(pos_dst.shape[0]): + neg_src_emb = emb[etypes[1][0]][neg_src[i]] + pde = pos_dst_emb[i] + ns = calc_transe_pos_score(neg_src_emb, pde, rel_emb, + decoder.gamma, decoder.norm) + neg_scores.append(ns) + neg_scores = th.stack(neg_scores) + _check_scores(score, pos_score, neg_scores, etypes[1], num_neg, pos_src.shape[0]) + + pos_neg_tuple = { + etypes[0]: gen_edge_pairs(), + } + score = decoder.calc_test_scores(emb, pos_neg_tuple, neg_sample_type, device) + pos_src, neg_src, pos_dst, neg_dst = pos_neg_tuple[etypes[0]] + pos_src_emb = emb[etypes[0][0]][pos_src] + pos_dst_emb = emb[etypes[0][2]][pos_dst] + rel_emb = decoder.get_relemb(etypes[0]) + pos_score = calc_transe_pos_score(pos_src_emb, pos_dst_emb, rel_emb, + decoder.gamma, decoder.norm) + neg_scores = [] + for i in range(pos_src.shape[0]): + pse = pos_src_emb[i] + pde = pos_dst_emb[i] + neg_src_emb = emb[etypes[0][0]][neg_src[i]] + neg_dst_emb = emb[etypes[0][2]][neg_dst[i]] + # (num_neg, dim) * (dim) * (dim) + ns_0 = calc_transe_pos_score(neg_src_emb, pde, rel_emb, + decoder.gamma, decoder.norm) + # (dim) * (dim) * (num_neg, dim) + ns_1 = calc_transe_pos_score(pse, neg_dst_emb, rel_emb, + decoder.gamma, decoder.norm) + neg_scores.append(th.cat((ns_0, ns_1), dim=-1)) + neg_scores = th.stack(neg_scores) + _check_scores(score, pos_score, neg_scores, etypes[0], num_neg*2, pos_src.shape[0]) + +def check_calc_test_scores_transe_joint_neg(decoder, etypes, h_dim, num_pos, num_neg, device): + neg_sample_type = BUILTIN_LP_JOINT_NEG_SAMPLER + emb = { + 'a': th.rand((128, h_dim)), + 'b': th.rand((128, h_dim)), + } + + def gen_edge_pairs(): + pos_src = th.ones((num_pos,), dtype=int) + pos_dst = th.randint(100, (num_pos,)) + neg_src = th.randint(128, (num_neg,)) + neg_dst = th.randint(128, (num_neg,)) + neg_src[neg_src==1] = 2 + return (pos_src, neg_src, pos_dst, neg_dst) + + with th.no_grad(): + pos_neg_tuple = { + etypes[0]: gen_edge_pairs(), + etypes[1]: gen_edge_pairs(), + } + pos_src, neg_src, pos_dst, neg_dst = pos_neg_tuple[etypes[0]] + pos_neg_tuple[etypes[0]] = (pos_src, None, pos_dst, neg_dst) + pos_src, neg_src, pos_dst, neg_dst = pos_neg_tuple[etypes[1]] + pos_neg_tuple[etypes[1]] = (pos_src, neg_src, pos_dst, None) + + score = decoder.calc_test_scores(emb, pos_neg_tuple, neg_sample_type, device) + pos_src, _, pos_dst, neg_dst = pos_neg_tuple[etypes[0]] + pos_src_emb = emb[etypes[0][0]][pos_src] + pos_dst_emb = emb[etypes[0][2]][pos_dst] + rel_emb = decoder.get_relemb(etypes[0]) + pos_score = calc_transe_pos_score(pos_src_emb, pos_dst_emb, rel_emb, + decoder.gamma, decoder.norm) + neg_scores = [] + for i in range(pos_src.shape[0]): + pse = pos_src_emb[i] + neg_dst_emb = emb[etypes[0][2]][neg_dst] + # (dim) * (dim) * (num_neg, dim) + ns = calc_transe_pos_score(pse, neg_dst_emb, rel_emb, + decoder.gamma, decoder.norm) + neg_scores.append(ns) + neg_scores = th.stack(neg_scores) + _check_ranking(score, pos_score, neg_scores, etypes[0], num_neg, pos_src.shape[0]) + + pos_src, neg_src, pos_dst, _ = pos_neg_tuple[etypes[1]] + pos_src_emb = emb[etypes[1][0]][pos_src] + pos_dst_emb = emb[etypes[1][2]][pos_dst] + rel_emb = decoder.get_relemb(etypes[1]) + pos_score = calc_transe_pos_score(pos_src_emb, pos_dst_emb, rel_emb, + decoder.gamma, decoder.norm) + neg_scores = [] + for i in range(pos_dst.shape[0]): + neg_src_emb = emb[etypes[1][0]][neg_src] + pde = pos_dst_emb[i] + # (num_neg, dim) * (dim) * (dim) + ns = calc_transe_pos_score(neg_src_emb, pde, rel_emb, + decoder.gamma, decoder.norm) + neg_scores.append(ns) + neg_scores = th.stack(neg_scores) + _check_ranking(score, pos_score, neg_scores, etypes[1], num_neg, pos_src.shape[0]) + + pos_neg_tuple = { + etypes[0]: gen_edge_pairs(), + } + score = decoder.calc_test_scores(emb, pos_neg_tuple, neg_sample_type, device) + pos_src, neg_src, pos_dst, neg_dst = pos_neg_tuple[etypes[0]] + pos_src_emb = emb[etypes[0][0]][pos_src] + pos_dst_emb = emb[etypes[0][2]][pos_dst] + rel_emb = decoder.get_relemb(etypes[0]) + pos_score = calc_transe_pos_score(pos_src_emb, pos_dst_emb, rel_emb, + decoder.gamma, decoder.norm) + neg_scores = [] + for i in range(pos_src.shape[0]): + pse = pos_src_emb[i] + pde = pos_dst_emb[i] + neg_src_emb = emb[etypes[0][0]][neg_src] + neg_dst_emb = emb[etypes[0][2]][neg_dst] + + ns_0 = calc_transe_pos_score(neg_src_emb, pde, rel_emb, + decoder.gamma, decoder.norm) + ns_1 = calc_transe_pos_score(pse, neg_dst_emb, rel_emb, + decoder.gamma, decoder.norm) + neg_scores.append(th.cat((ns_0, ns_1), dim=-1)) + neg_scores = th.stack(neg_scores) + _check_ranking(score, pos_score, neg_scores, etypes[0], num_neg*2, pos_src.shape[0]) + def check_calc_test_scores_uniform_neg(decoder, etypes, h_dim, num_pos, num_neg, device): neg_sample_type = BUILTIN_LP_UNIFORM_NEG_SAMPLER @@ -538,6 +712,21 @@ def test_LinkPredictRotatEDecoder(h_dim, num_pos, num_neg, device): check_calc_test_scores_rotate_uniform_neg(decoder, etypes, h_dim, num_pos, num_neg, device) check_calc_test_scores_rotate_joint_neg(decoder, etypes, h_dim, num_pos, num_neg, device) +@pytest.mark.parametrize("h_dim", [16, 64]) +@pytest.mark.parametrize("num_pos", [8, 32]) +@pytest.mark.parametrize("num_neg", [1, 32]) +@pytest.mark.parametrize("device",["cpu","cuda:0"]) +def test_LinkPredictTransEDecoder(h_dim, num_pos, num_neg, device): + th.manual_seed(0) + etypes = [('a', 'r1', 'b'), ('a', 'r2', 'b')] + decoder = LinkPredictTransEDecoder(etypes, h_dim, gamma=12.) + # mimic that decoder has been trained. + decoder.trained_rels[0] = 1 + decoder.trained_rels[1] = 1 + + check_calc_test_scores_transe_uniform_neg(decoder, etypes, h_dim, num_pos, num_neg, device) + check_calc_test_scores_transe_joint_neg(decoder, etypes, h_dim, num_pos, num_neg, device) + @pytest.mark.parametrize("h_dim", [16, 64]) @pytest.mark.parametrize("num_pos", [8, 32]) @pytest.mark.parametrize("num_neg", [1, 32]) @@ -670,6 +859,48 @@ def comput_score(src_emb, dst_emb): check_forward(decoder, etype, h_dim, num_pos, num_neg, comput_score, device) +@pytest.mark.parametrize("h_dim", [16, 64]) +@pytest.mark.parametrize("num_pos", [8, 32]) +@pytest.mark.parametrize("num_neg", [1, 32]) +@pytest.mark.parametrize("device",["cpu", "cuda:0"]) +def test_LinkPredictContrastiveTransEDecoder_L1norm(h_dim, num_pos, num_neg, device): + th.manual_seed(1) + etype = ('a', 'r1', 'b') + gamma = 4. + norm = 'l1' + decoder = LinkPredictContrastiveTransEDecoder([etype], h_dim, gamma=gamma) + decoder.trained_rels[0] = 1 # trick the decoder + decoder.norm = norm + decoder = decoder.to(device) + rel_emb = decoder.get_relemb(etype).to(device) + + def comput_score(src_emb, dst_emb): + score = (src_emb + rel_emb) - dst_emb + return gamma - th.norm(score, p=1, dim=-1) + + check_forward(decoder, etype, h_dim, num_pos, num_neg, comput_score, device) + +@pytest.mark.parametrize("h_dim", [16, 64]) +@pytest.mark.parametrize("num_pos", [8, 32]) +@pytest.mark.parametrize("num_neg", [1, 32]) +@pytest.mark.parametrize("device",["cpu", "cuda:0"]) +def test_LinkPredictContrastiveTransEDecoder_L2norm(h_dim, num_pos, num_neg, device): + th.manual_seed(1) + etype = ('a', 'r1', 'b') + gamma = 4. + norm = 'l2' + decoder = LinkPredictContrastiveTransEDecoder([etype], h_dim, gamma=gamma) + decoder.trained_rels[0] = 1 # trick the decoder + decoder.norm = norm + decoder = decoder.to(device) + rel_emb = decoder.get_relemb(etype).to(device) + + def comput_score(src_emb, dst_emb): + score = (src_emb + rel_emb) - dst_emb + return gamma - th.norm(score, p=2, dim=-1) + + check_forward(decoder, etype, h_dim, num_pos, num_neg, comput_score, device) + @pytest.mark.parametrize("h_dim", [16, 64]) @pytest.mark.parametrize("feat_dim", [8, 32]) @pytest.mark.parametrize("out_dim", [2, 32]) @@ -768,6 +999,14 @@ def test_EntityRegression(in_dim, out_dim): test_LinkPredictContrastiveRotatEDecoder(32, 8, 16, "cpu") test_LinkPredictContrastiveRotatEDecoder(16, 32, 32, "cuda:0") + test_LinkPredictTransEDecoder(16, 8, 1, "cpu") + test_LinkPredictTransEDecoder(16, 32, 32, "cuda:0") + + test_LinkPredictContrastiveTransEDecoder_L1norm(32, 8, 16, "cpu") + test_LinkPredictContrastiveTransEDecoder_L2norm(32, 8, 16, "cpu") + test_LinkPredictContrastiveTransEDecoder_L1norm(16, 32, 32, "cuda:0") + test_LinkPredictContrastiveTransEDecoder_L2norm(16, 32, 32, "cuda:0") + test_EntityRegression(8, 1) test_EntityRegression(8, 8)