From 1d1b21fdf96d1ee76fb6f31699b3afdac9e32b88 Mon Sep 17 00:00:00 2001
From: "xiang song(charlie.song)"
Date: Mon, 29 Jul 2024 18:52:48 -0700
Subject: [PATCH] [BugFix] Fix missing node normalization for link prediction tasks in multi-task learning (#926)

*Issue #, if available:*
In multi-task learning, when a link prediction task is trained with a contrastive loss, the loss may become NaN. This is because GraphStorm does not apply the proper node normalization to the GNN embeddings.

*Description of changes:*
Fix the bug.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

---------

Co-authored-by: Xiang Song
---
 .../ml_nc_lp_norm_with_mask_infer.yaml        |  59 ++++++++
 python/graphstorm/gconstruct/remap_result.py  |  65 ++++++++-
 python/graphstorm/inference/mt_infer.py       |  41 ++++--
 python/graphstorm/model/multitask_gnn.py      | 129 +++++++++++++++++-
 python/graphstorm/run/gsgnn_mt/gsgnn_mt.py    |  37 ++++-
 .../graphstorm/run/gsgnn_mt/mt_infer_gnn.py   |  12 +-
 python/graphstorm/trainer/mt_trainer.py       |   6 +
 .../end2end-tests/graphstorm-mt/mgpu_test.sh  |  47 ++++++-
 .../gconstruct/test_remap_result.py           |  25 +++-
 tests/unit-tests/test_gnn.py                  | 125 ++++++++++++++++-
 tests/unit-tests/util.py                      |   6 +
 training_scripts/gsgnn_mt/ml_nc_lp_norm.yaml  |  66 +++++++++
 12 files changed, 581 insertions(+), 37 deletions(-)
 create mode 100644 inference_scripts/mt_infer/ml_nc_lp_norm_with_mask_infer.yaml
 create mode 100644 training_scripts/gsgnn_mt/ml_nc_lp_norm.yaml

diff --git a/inference_scripts/mt_infer/ml_nc_lp_norm_with_mask_infer.yaml b/inference_scripts/mt_infer/ml_nc_lp_norm_with_mask_infer.yaml
new file mode 100644
index 0000000000..53e2d267dc
--- /dev/null
+++ b/inference_scripts/mt_infer/ml_nc_lp_norm_with_mask_infer.yaml
@@ -0,0 +1,59 @@
+---
+version: 1.0
+gsf:
+  basic:
+    backend: gloo
+    verbose: false
+    save_perf_results_path: null
+    batch_size: 32
+    node_feat_name:
+      - user:feat
+      - movie:title
+  gnn:
+    model_encoder_type: rgcn
+    num_layers: 1
+    hidden_size: 32
+    use_mini_batch_infer: true
+  input:
+    restore_model_path: null
+  output:
+    save_model_path: null
+    save_embed_path: null
+  hyperparam:
+    dropout: 0.
+    lr: 0.001
+    no_validation: false
+  rgcn:
+    num_bases: -1
+    use_self_loop: true
+    use_node_embeddings: false
+  multi_task_learning:
+    - node_classification:
+        target_ntype: "movie"
+        label_field: "label"
+        multilabel: false
+        num_classes: 19
+        batch_size: 16 # will overwrite the global batch_size
+        mask_fields:
+          - "train_mask_c0" # node classification mask 0
+          - "val_mask_c0"
+          - "test_mask_c0"
+        eval_metric:
+          - "accuracy"
+    - link_prediction:
+        lp_loss_func: "contrastive"
+        num_negative_edges: 4
+        num_negative_edges_eval: 100
+        train_negative_sampler: joint
+        eval_etype:
+          - "user,rating,movie"
+        train_etype:
+          - "user,rating,movie"
+        exclude_training_targets: true
+        reverse_edge_types_map:
+          - user,rating,rating-rev,movie
+        batch_size: 128 # will overwrite the global batch_size
+        mask_fields:
+          - "train_mask_field_lp"
+          - null # empty means there is no validation mask
+          - "test_mask_field_lp"
\ No newline at end of file
diff --git a/python/graphstorm/gconstruct/remap_result.py b/python/graphstorm/gconstruct/remap_result.py
index d88219e1d5..e19eebf7d6 100644
--- a/python/graphstorm/gconstruct/remap_result.py
+++ b/python/graphstorm/gconstruct/remap_result.py
@@ -40,7 +40,8 @@
     BUILTIN_TASK_EDGE_CLASSIFICATION,
     BUILTIN_TASK_EDGE_REGRESSION,
     BUILTIN_TASK_NODE_CLASSIFICATION,
-    BUILTIN_TASK_NODE_REGRESSION)
+    BUILTIN_TASK_NODE_REGRESSION,
+    BUILTIN_TASK_LINK_PREDICTION)
 
 GS_OUTPUT_FORMAT_PARQUET = "parquet"
 GS_OUTPUT_FORMAT_CSV = "csv"
@@ -655,16 +656,28 @@ def _parse_gs_config(config):
     node_id_mapping = os.path.join(os.path.dirname(part_config), "raw_id_mappings")
     predict_dir = config.save_prediction_path
     emb_dir = config.save_embed_path
+    task_emb_dirs = []
     pred_ntypes = []
     pred_etypes = []
     if config.multi_tasks is not None:
         node_predict_dirs = []
         edge_predict_dirs = []
-        if predict_dir is None:
-            return node_id_mapping, None, emb_dir, pred_ntypes, pred_etypes
         # multi-task setting
         tasks = config.multi_tasks
+
+        for task in tasks:
+            task_config = task.task_config
+            task_id = task.task_id
+            if task.task_type in [BUILTIN_TASK_LINK_PREDICTION]:
+                if task_config.lp_embed_normalizer is not None:
+                    # There is a link prediction node embedding normalizer.
+                    # We need to handle the normalized embeddings.
+ task_emb_dirs.append(task_id) + + if predict_dir is None: + return node_id_mapping, None, emb_dir, task_emb_dirs, pred_ntypes, pred_etypes + for task in tasks: task_config = task.task_config task_id = task.task_id @@ -681,7 +694,7 @@ def _parse_gs_config(config): edge_predict_dirs.append(pred_path) predict_dir = (node_predict_dirs, edge_predict_dirs) - return node_id_mapping, predict_dir, emb_dir, pred_ntypes, pred_etypes + return node_id_mapping, predict_dir, emb_dir, task_emb_dirs, pred_ntypes, pred_etypes else: task_type = config.task_type if task_type in (BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION): @@ -694,7 +707,7 @@ def _parse_gs_config(config): pred_ntypes = pred_ntypes \ if isinstance(pred_ntypes, list) else [pred_ntypes] - return node_id_mapping, predict_dir, emb_dir, pred_ntypes, pred_etypes + return node_id_mapping, predict_dir, emb_dir, task_emb_dirs, pred_ntypes, pred_etypes def main(args, gs_config_args): """ main function @@ -714,7 +727,7 @@ def main(args, gs_config_args): gs_args, _ = gs_parser.parse_known_args(gs_config_args) config = GSConfig(gs_args) config.verify_arguments(False) - id_mapping_path, predict_dir, node_emb_dir, pred_ntypes, pred_etypes = \ + id_mapping_path, predict_dir, node_emb_dir, task_emb_dirs, pred_ntypes, pred_etypes = \ _parse_gs_config(config) else: # Case 2: remap_result is called alone. @@ -724,6 +737,10 @@ def main(args, gs_config_args): id_mapping_path = args.node_id_mapping predict_dir = args.prediction_dir node_emb_dir = args.node_emb_dir + # We do not handle the case when there are task specific embeddings + # in multi-task learning, if remap_result is called alone. + # Users need to clean up the node_emb_dir themselves. + task_emb_dirs = [] pred_etypes = args.pred_etypes pred_ntypes = args.pred_ntypes if pred_etypes is not None: @@ -773,7 +790,26 @@ def main(args, gs_config_args): else: # There is no shared file system emb_names = os.listdir(node_emb_dir) - emb_names = [e_name for e_name in emb_names if e_name != "emb_info.json"] + # In single task learning, the node embed dir looks like: + # emb_dir/ + # ntype0 + # ntype1 + # ... + # emb_info.json + # + # In multi-task learning, the node embed dir looks like: + # emb_dir/ + # ntype0 + # ntype1 + # ... + # emb_info.json + # task_id0/ + # task_id1/ + # ... + # We need to exclude both emb_info.json and task_id directories, + # when we are collecting node types with node embeddings. 
+ emb_names = [e_name for e_name in emb_names \ + if e_name not in task_emb_dirs + ["emb_info.json"]] emb_ntypes = emb_names else: @@ -962,6 +998,21 @@ def main(args, gs_config_args): output_func) files_to_remove += emb_files_to_remove + for task_emb_dir in task_emb_dirs: + task_emb_dir = os.path.join(node_emb_dir, task_emb_dir) + # We need to do ID remapping for node embeddings + emb_files_to_remove = \ + remap_node_emb(emb_ntypes, + task_emb_dir, + task_emb_dir, + out_chunk_size, + num_proc, + rank, + world_size, + with_shared_fs, + output_func) + files_to_remove += emb_files_to_remove + if len(pred_etypes) > 0: if isinstance(predict_dir, tuple): _, edge_predict_dirs = predict_dir diff --git a/python/graphstorm/inference/mt_infer.py b/python/graphstorm/inference/mt_infer.py index 142c5636a4..05943ccb9b 100644 --- a/python/graphstorm/inference/mt_infer.py +++ b/python/graphstorm/inference/mt_infer.py @@ -105,7 +105,8 @@ def infer(self, data, """ do_eval = self.evaluator is not None sys_tracker.check('start inferencing') - self._model.eval() + model = self._model + model.eval() # All the tasks share the same GNN encoder so the fanouts are same # for different tasks. @@ -133,13 +134,13 @@ def gen_embs(edge_mask=None): # so the node embeddings are updated inplace. if use_mini_batch_infer: embs = do_mini_batch_inference( - self._model, data, batch_size=infer_batch_size, + model, data, batch_size=infer_batch_size, fanout=fanout, edge_mask=edge_mask, task_tracker=self.task_tracker) else: embs = do_full_graph_inference( - self._model, data, + model, data, fanout=fanout, edge_mask=edge_mask, task_tracker=self.task_tracker) @@ -154,17 +155,29 @@ def gen_embs(edge_mask=None): # before conducting prediction results. if save_embed_path is not None: logging.info("Saving node embeddings") + node_norm_methods = model.node_embed_norm_methods + # Save the original embs first save_gsgnn_embeddings(g, save_embed_path, embs, node_id_mapping_file=node_id_mapping_file, save_embed_format=save_embed_format) barrier() + for task_id, norm_method in node_norm_methods.items(): + if norm_method is None: + continue + normed_embs = model.normalize_task_node_embs(task_id, embs, inplace=False) + save_embed_path = os.path.join(save_embed_path, task_id) + save_gsgnn_embeddings(g, + save_embed_path, + normed_embs, + node_id_mapping_file=node_id_mapping_file, + save_embed_format=save_embed_format) sys_tracker.check('save embeddings') # save relation embedding if any for link prediction tasks if get_rank() == 0: - decoders = self._model.task_decoders + decoders = model.task_decoders for task_id, decoder in decoders.items(): if isinstance(decoder, LinkPredictDistMultDecoder): rel_emb_path = os.path.join(save_embed_path, task_id) @@ -189,7 +202,7 @@ def gen_embs(edge_mask=None): # and edge regression tasks. pre_results = \ multi_task_mini_batch_predict( - self._model, + model, emb=embs, dataloaders=predict_test_loader.dataloaders, task_infos=predict_test_loader.task_infos, @@ -213,9 +226,9 @@ def nfrecon_gen_embs(skip_last_self_loop=False, node_embs=embs): if skip_last_self_loop is True: # Turn off the last layer GNN's self-loop # to compute node embeddings. 
- self._model.gnn_encoder.skip_last_selfloop() + model.gnn_encoder.skip_last_selfloop() new_embs = gen_embs() - self._model.gnn_encoder.reset_last_selfloop() + model.gnn_encoder.reset_last_selfloop() return new_embs else: # If skip_last_self_loop is False @@ -231,11 +244,11 @@ def nfrecon_gen_embs(skip_last_self_loop=False, node_embs=embs): # Note(xiangsx): In DistDGl, as we are using the # same dist tensor, the node embeddings # are updated inplace. - nfeat_embs = gen_emb_for_nfeat_reconstruct(self._model, nfrecon_gen_embs) + nfeat_embs = gen_emb_for_nfeat_reconstruct(model, nfrecon_gen_embs) nfeat_recon_results = \ multi_task_mini_batch_predict( - self._model, + model, emb=nfeat_embs, dataloaders=dataloaders, task_infos=task_infos, @@ -258,8 +271,14 @@ def nfrecon_gen_embs(skip_last_self_loop=False, node_embs=embs): # For link prediction, do evaluation task by task. lp_test_embs = gen_embs(edge_mask=task_info.task_config.train_mask) - - decoder = self._model.task_decoders[task_info.task_id] + # normalize the node embedding if needed. + # we can do inplace normalization as embeddings are generated + # per lp task. + lp_test_embs = model.normalize_task_node_embs(task_info.task_id, + lp_test_embs, + inplace=True) + + decoder = model.task_decoders[task_info.task_id] ranking = run_lp_mini_batch_predict(decoder, lp_test_embs, dataloader, device) pre_results[task_info.task_id] = ranking diff --git a/python/graphstorm/model/multitask_gnn.py b/python/graphstorm/model/multitask_gnn.py index 02a679eb70..f5b964e33e 100644 --- a/python/graphstorm/model/multitask_gnn.py +++ b/python/graphstorm/model/multitask_gnn.py @@ -19,6 +19,7 @@ import logging import torch as th from torch import nn +import dgl from ..config import (BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION, @@ -32,7 +33,14 @@ from .node_gnn import run_node_mini_batch_predict from .edge_gnn import run_edge_mini_batch_predict from .lp_gnn import run_lp_mini_batch_predict - +from .utils import LazyDistTensor +from .utils import normalize_node_embs, get_data_range +from ..utils import ( + get_rank, + get_world_size, + barrier, + create_dist_tensor +) class GSgnnMultiTaskModelInterface: """ The interface for GraphStorm multi-task learning. @@ -93,10 +101,108 @@ def __init__(self, alpha_l2norm): self._task_pool = {} self._decoder = nn.ModuleDict() self._loss_fn = nn.ModuleDict() + self._node_embed_norm_methods = {} self._warn_printed = False + def normalize_task_node_embs(self, task_id, embs, inplace=False): + """ Normalize node embeddings when needed. + + normalize_task_node_embs should be called when embs stores embeddings + of every node. + + Parameters + ---------- + task_id: str + Task ID. + embs: dict of Tensors + A dict of node embeddings. + inplace: bool + Whether to do inplace normalization. + + Returns + ------- + new_embs: dict of Tensors + Normalized node embeddings. + """ + if self._node_embed_norm_methods[task_id] is not None: + new_embs = {} + rank = get_rank() + world_size = get_world_size() + for key, emb in embs.items(): + if isinstance(emb, (dgl.distributed.DistTensor, LazyDistTensor)): + # If emb is a distributed tensor, multiple processes are doing + # embdding normalization concurrently. We need to split + # the task. (From full_graph_inference) + start, end = get_data_range(rank, world_size, len(emb)) + new_emb = emb if inplace else \ + create_dist_tensor(emb.shape, + emb.dtype, + name=f"{emb.name}_task_id", + part_policy=emb.part_policy, + persistent=True) + else: + # If emb is just a torch Tensor. 
do normalization directly. + # (From mini_batch_inference) + start, end = 0, len(emb) + new_emb = emb if inplace else th.clone(emb) + idx = start + while idx + 1024 < end: + new_emb[idx:idx+1024] = \ + self.minibatch_normalize_task_node_embs( + task_id, + {key:emb[idx:idx+1024]})[key] + idx += 1024 + new_emb[idx:end] = \ + self.minibatch_normalize_task_node_embs( + task_id, + {key:emb[idx:end]})[key] + barrier() + new_embs[key] = new_emb + return new_embs + else: + # If normalization method is None + # do nothing. + new_embs = embs + return new_embs + + # pylint: disable = arguments-differ + def minibatch_normalize_task_node_embs(self, task_id, embs): + """ Normalize node embeddings when needed for a mini-batch. + + minibatch_normalize_task_node_embs should be called in + forward() and predict(). + + Parameters + ---------- + task_id: str + Task ID. + embs: dict of Tensors + A dict of node embeddings. + + Returns + ------- + embs: dict of Tensors + Normalized node embeddings. + """ + if self._node_embed_norm_methods[task_id] is not None: + return normalize_node_embs(embs, self._node_embed_norm_methods[task_id]) + else: + return embs + + @property + def node_embed_norm_methods(self): + """ Get per task node embedding normalization method + + Returns + ------- + dict of strings: + Normalization methods + """ + return self._node_embed_norm_methods + def add_task(self, task_id, task_type, - decoder, loss_func): + decoder, loss_func, + embed_norm_method=None): """ Add a task into the multi-task pool Parameters @@ -112,6 +218,8 @@ def add_task(self, task_id, task_type, Task decoder. loss_func: func Loss function. + embed_norm_method: str + Node embedding normalization method. """ assert task_id not in self._task_pool, \ f"Task {task_id} already exists" @@ -120,6 +228,7 @@ def add_task(self, task_id, task_type, self._decoder[task_id] = decoder # add loss func in nn module self._loss_fn[task_id] = loss_func + self._node_embed_norm_methods[task_id] = embed_norm_method @property def alpha_l2norm(self): @@ -277,7 +386,7 @@ def _forward(self, task_id, encoder_data, decoder_data): encode_embs = self.compute_embed_step(blocks, node_feats, input_nodes) # Call emb normalization. - encode_embs = self.normalize_node_embs(encode_embs) + encode_embs = self.minibatch_normalize_task_node_embs(task_id, encode_embs) if task_type in [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION]: labels = decoder_data @@ -353,7 +462,7 @@ def predict(self, task_id, mini_batch, return_proba=False): encode_embs = self.compute_embed_step(blocks, node_feats, input_nodes) # Call emb normalization. - encode_embs = self.normalize_node_embs(encode_embs) + encode_embs = self.minibatch_normalize_task_node_embs(task_id, encode_embs) task_type, _ = self.task_pool[task_id] task_decoder = self.decoder[task_id] @@ -415,6 +524,18 @@ def multi_task_mini_batch_predict( res = {} with th.no_grad(): for dataloader, task_info in zip(dataloaders, task_infos): + # normalize the node embedding if needed. + # input emb is shared across different tasks + # so that we can not do inplace normalization. + # + # Note(xiangsx): Currently node embedding normalization + # only supports link prediction tasks. + # model.normalize_task_node_embs does nothing + # for node and edge prediction tasks. + # TODO(xiangsx): Need a more memory efficient design when + # node embedding normalization supports node and edge + # prediction tasks. 
+ emb = model.normalize_task_node_embs(task_info.task_id, emb, inplace=False) if task_info.task_type in \ [BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION, diff --git a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py index 19d63b77b0..304fabbe16 100644 --- a/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py +++ b/python/graphstorm/run/gsgnn_mt/gsgnn_mt.py @@ -348,7 +348,17 @@ def main(config_args): train_data.g, encoder_out_dims, train_task=True) - model.add_task(task.task_id, task.task_type, decoder, loss_func) + # For link prediction, lp_embed_normalizer may be used + # TODO(xiangsx): add embed normalizer for other task types + # in the future. + node_embed_norm_method = task.task_config.lp_embed_normalizer \ + if task.task_type in [BUILTIN_TASK_LINK_PREDICTION] \ + else None + model.add_task(task.task_id, + task.task_type, + decoder, + loss_func, + embed_norm_method=node_embed_norm_method) if not config.no_validation: if val_loader is None: logging.warning("The training data do not have validation set.") @@ -419,7 +429,14 @@ def main(config_args): train_data.g, encoder_out_dims, train_task=True) - model.add_task(task.task_id, task.task_type, decoder, loss_func) + node_embed_norm_method = task.task_config.lp_embed_normalizer \ + if task.task_type in [BUILTIN_TASK_LINK_PREDICTION] \ + else None + model.add_task(task.task_id, + task.task_type, + decoder, + loss_func, + embed_norm_method=node_embed_norm_method) best_model_path = trainer.get_best_model_path() # TODO(zhengda) the model path has to be in a shared filesystem. model.restore_model(best_model_path) @@ -432,6 +449,7 @@ def main(config_args): embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout, task_tracker=tracker) + # Save the original embs first save_full_node_embeddings( train_data.g, config.save_embed_path, @@ -439,6 +457,21 @@ def main(config_args): node_id_mapping_file=config.node_id_mapping_file, save_embed_format=config.save_embed_format) + node_norm_methods = model.node_embed_norm_methods + # save normalized embeddings + for task_id, norm_method in node_norm_methods.items(): + if norm_method is None: + continue + normed_embs = model.normalize_task_node_embs(task_id, embeddings, inplace=False) + save_embed_path = os.path.join(config.save_embed_path, task_id) + save_full_node_embeddings( + train_data.g, + save_embed_path, + normed_embs, + node_id_mapping_file=config.node_id_mapping_file, + save_embed_format=config.save_embed_format) + + def generate_parser(): """ Generate an argument parser """ diff --git a/python/graphstorm/run/gsgnn_mt/mt_infer_gnn.py b/python/graphstorm/run/gsgnn_mt/mt_infer_gnn.py index 718625f3f1..6c4122004b 100644 --- a/python/graphstorm/run/gsgnn_mt/mt_infer_gnn.py +++ b/python/graphstorm/run/gsgnn_mt/mt_infer_gnn.py @@ -218,7 +218,17 @@ def main(config_args): predict_dataloaders.append(data_loader) predict_tasks.append(task) - model.add_task(task.task_id, task.task_type, decoder, loss_func) + # For link prediction, lp_embed_normalizer may be used + # TODO(xiangsx): add embed normalizer for other task types + # in the future. + node_embed_norm_method = task.task_config.lp_embed_normalizer \ + if task.task_type in [BUILTIN_TASK_LINK_PREDICTION] \ + else None + model.add_task(task.task_id, + task.task_type, + decoder, + loss_func, + embed_norm_method=node_embed_norm_method) # Multi-task testing dataloader for node prediction and # edge prediction tasks. 
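The entry points above wire each link prediction task's lp_embed_normalizer into add_task() as embed_norm_method. For the l2_norm method this reduces to row-wise L2 normalization of every node type's embedding tensor (the new unit tests below compare the result against torch.nn.functional.normalize), which keeps the contrastive link prediction scores bounded. The following is a minimal illustrative sketch with made-up tensor shapes, not code from this patch:

    import torch as th
    import torch.nn.functional as F

    # Toy per-node-type embeddings, shaped like the dicts returned by
    # do_full_graph_inference() / do_mini_batch_inference().
    embs = {
        "user": th.rand(100, 32),
        "movie": th.rand(50, 32),
    }

    # "l2_norm" normalization: every embedding row becomes a unit vector,
    # applied per node type (the unit tests check against F.normalize).
    normed_embs = {ntype: F.normalize(emb, p=2, dim=-1) for ntype, emb in embs.items()}

    # With unit-length embeddings, the dot-product scores used by the contrastive
    # link prediction loss stay within [-1, 1]; unnormalized embeddings let the
    # scores grow with the embedding magnitude, which is the failure mode
    # described in the issue (loss becoming NaN).
    scores = normed_embs["user"] @ normed_embs["movie"].T
    assert scores.abs().max().item() <= 1.0 + 1e-5

This is also why the normalized embeddings are saved under a separate per-task directory: different tasks can request different (or no) normalization of the same shared GNN embeddings.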
diff --git a/python/graphstorm/trainer/mt_trainer.py b/python/graphstorm/trainer/mt_trainer.py
index a9e13ba0f8..2f2787baaa 100644
--- a/python/graphstorm/trainer/mt_trainer.py
+++ b/python/graphstorm/trainer/mt_trainer.py
@@ -641,6 +641,12 @@ def gen_embs(edge_mask=None):
                     # For link prediction, do evaluation task
                     # by task.
                     lp_test_embs = gen_embs(edge_mask=task_info.task_config.train_mask)
+                    # normalize the node embedding if needed.
+                    # we can do inplace normalization as embeddings are generated
+                    # per lp task.
+                    lp_test_embs = model.normalize_task_node_embs(task_info.task_id,
+                                                                  lp_test_embs,
+                                                                  inplace=True)
 
                     decoder = model.task_decoders[task_info.task_id]
                     val_scores = run_lp_mini_batch_predict(decoder,
diff --git a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh
index aceb326ac6..67eb3211c8 100644
--- a/tests/end2end-tests/graphstorm-mt/mgpu_test.sh
+++ b/tests/end2end-tests/graphstorm-mt/mgpu_test.sh
@@ -674,6 +674,49 @@ python3 $GS_HOME/tests/end2end-tests/check_infer.py --train-embout /data/gsgnn_m
 error_and_exit $?
 
-rm -fr /data/gsgnn_mt/infer-emb/
-rm -fr /data/gsgnn_mt/prediction/
+rm -fr /data/gsgnn_mt/
 rm -fr /tmp/infer_log.txt
+
+
+echo "**************[Multi-task] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference: full-graph, save model"
+python3 -m graphstorm.run.gs_multi_task_learning --workspace $GS_HOME/training_scripts/gsgnn_mt --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_lp_norm.yaml --save-model-path /data/gsgnn_mt/ --save-model-frequency 1000 --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True --use-mini-batch-infer False --save-embed-path /data/gsgnn_mt/emb/
+
+error_and_exit $?
+
+cnt=$(grep "save_embed_path: /data/gsgnn_mt/emb/" /tmp/train_log.txt | wc -l)
+if test $cnt -lt 1
+then
+    echo "We use the SageMaker task tracker, so we should have save_embed_path"
+    exit -1
+fi
+
+cnt=$(ls -l /data/gsgnn_mt/emb/ | wc -l)
+cnt=$[cnt - 1]
+if test $cnt != 3
+then
+    echo "The number of saved embs $cnt is not equal to 3. Should have two for movie and user, and one for the link-prediction-subtask normalized embedding."
+fi
+
+echo "**************[Multi-task] dataset: Movielens, RGCN layer 1, node feat: fixed HF BERT, BERT nodes: movie, inference with test"
+python3 -m graphstorm.run.gs_multi_task_learning --inference --workspace $GS_HOME/inference_scripts/mt_infer --num-trainers $NUM_INFERs --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_multi_task_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_lp_norm_with_mask_infer.yaml --use-mini-batch-infer false --save-embed-path /data/gsgnn_mt/infer-emb/ --restore-model-path /data/gsgnn_mt/epoch-2 --save-prediction-path /data/gsgnn_mt/prediction/ --logging-file /tmp/infer_log.txt --preserve-input True
+
+error_and_exit $?
+
+cnt=$(ls -l /data/gsgnn_mt/infer-emb/ | wc -l)
+cnt=$[cnt - 2]
+if test $cnt != 4
+then
+    echo "The number of saved embs $cnt is not equal to 4. Should have two for movie and user, and one for the link-prediction-subtask normalized embedding."
+fi
+
+python3 $GS_HOME/tests/end2end-tests/check_infer.py --train-embout /data/gsgnn_mt/emb/ --infer-embout /data/gsgnn_mt/infer-emb/
+
+error_and_exit $?
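+# Also compare the embeddings saved under the link prediction task ID directory,
+# which holds the L2-normalized copies written alongside the original embeddings.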
+ +python3 $GS_HOME/tests/end2end-tests/check_infer.py --train-embout /data/gsgnn_mt/emb/link_prediction-user_rating_movie --infer-embout /data/gsgnn_mt/infer-emb/link_prediction-user_rating_movie + +error_and_exit $? + +rm -fr /data/gsgnn_mt/ +rm -fr /tmp/train_log.txt +rm -fr /tmp/infer_log.txt \ No newline at end of file diff --git a/tests/unit-tests/gconstruct/test_remap_result.py b/tests/unit-tests/gconstruct/test_remap_result.py index 7ac1c7156e..ca48a81586 100644 --- a/tests/unit-tests/gconstruct/test_remap_result.py +++ b/tests/unit-tests/gconstruct/test_remap_result.py @@ -26,6 +26,7 @@ from numpy.testing import assert_equal, assert_almost_equal from graphstorm.config import GSConfig +from graphstorm.config.config import get_mttask_id from graphstorm.config import (BUILTIN_TASK_NODE_CLASSIFICATION, BUILTIN_TASK_NODE_REGRESSION, BUILTIN_TASK_EDGE_CLASSIFICATION, @@ -409,13 +410,14 @@ def test_parse_config(): setattr(config, "_task_type", BUILTIN_TASK_NODE_CLASSIFICATION) setattr(config, "_target_ntype", target_ntype) setattr(config, "_multi_tasks", None) - node_id_mapping, predict_dir, emb_dir, pred_ntypes, pred_etypes = _parse_gs_config(config) + node_id_mapping, predict_dir, emb_dir, task_emb_dirs, pred_ntypes, pred_etypes = _parse_gs_config(config) assert node_id_mapping == os.path.join(tmpdirname, "raw_id_mappings") assert predict_dir == save_prediction_path assert emb_dir == save_embed_path assert len(pred_ntypes) == 1 assert pred_ntypes[0] == target_ntype assert len(pred_etypes) == 0 + assert len(task_emb_dirs) == 0 target_etype = ["n0,r0,n1"] config = GSConfig.__new__(GSConfig) @@ -426,13 +428,14 @@ def test_parse_config(): setattr(config, "_target_etype", target_etype) setattr(config, "_multi_tasks", None) - node_id_mapping, predict_dir, emb_dir, pred_ntypes, pred_etypes = _parse_gs_config(config) + node_id_mapping, predict_dir, emb_dir, task_emb_dirs, pred_ntypes, pred_etypes = _parse_gs_config(config) assert node_id_mapping == os.path.join(tmpdirname, "raw_id_mappings") assert predict_dir == save_prediction_path assert emb_dir == save_embed_path assert len(pred_ntypes) == 0 assert len(pred_etypes) == 1 assert pred_etypes[0] == ["n0", "r0", "n1"] + assert len(task_emb_dirs) == 0 # multi-task config multi_task_config = [ @@ -470,9 +473,10 @@ def test_parse_config(): "link_prediction" : { "num_negative_edges": 4, "batch_size": 128, - "exclude_training_targets": False + "exclude_training_targets": False, + "lp_embed_normalizer": "l2_norm" } - } + }, ] config = GSConfig.__new__(GSConfig) @@ -480,7 +484,7 @@ def test_parse_config(): setattr(config, "_save_prediction_path", save_prediction_path) setattr(config, "_save_embed_path", save_embed_path) config._parse_multi_tasks(multi_task_config) - node_id_mapping, predict_dir, emb_dir, pred_ntypes, pred_etypes = _parse_gs_config(config) + node_id_mapping, predict_dir, emb_dir, task_emb_dirs, pred_ntypes, pred_etypes = _parse_gs_config(config) assert node_id_mapping == os.path.join(tmpdirname, "raw_id_mappings") assert isinstance(predict_dir, tuple) @@ -498,14 +502,20 @@ def test_parse_config(): assert len(pred_etypes) == 2 assert pred_etypes[0] == ['n0', 'r0', 'r1'] assert pred_etypes[1] == ['n0', 'r0', 'r2'] + print(task_emb_dirs) + assert len(task_emb_dirs) == 1 + assert task_emb_dirs[0] == get_mttask_id( + task_type="link_prediction", + etype="ALL_ETYPE") # there is no predict path # it will use emb_path + multi_task_config[4]["link_prediction"].pop("lp_embed_normalizer") config = GSConfig.__new__(GSConfig) setattr(config, 
"_part_config", part_path) setattr(config, "_save_embed_path", save_embed_path) config._parse_multi_tasks(multi_task_config) - node_id_mapping, predict_dir, emb_dir, pred_ntypes, pred_etypes = _parse_gs_config(config) + node_id_mapping, predict_dir, emb_dir, task_emb_dirs, pred_ntypes, pred_etypes = _parse_gs_config(config) assert node_id_mapping == os.path.join(tmpdirname, "raw_id_mappings") assert isinstance(predict_dir, tuple) node_predict_dirs, edge_predict_dirs = predict_dir @@ -515,12 +525,13 @@ def test_parse_config(): assert node_predict_dirs[1] == os.path.join(save_embed_path, config.multi_tasks[1].task_id) assert edge_predict_dirs[0] == os.path.join(save_embed_path, config.multi_tasks[2].task_id) assert edge_predict_dirs[1] == os.path.join(save_embed_path, config.multi_tasks[3].task_id) + assert len(task_emb_dirs) == 0 # there is no predict path and emb path config = GSConfig.__new__(GSConfig) setattr(config, "_part_config", part_path) config._parse_multi_tasks(multi_task_config) - node_id_mapping, predict_dir, emb_dir, pred_ntypes, pred_etypes = _parse_gs_config(config) + node_id_mapping, predict_dir, emb_dir, task_emb_dirs, pred_ntypes, pred_etypes = _parse_gs_config(config) assert predict_dir is None assert emb_dir is None diff --git a/tests/unit-tests/test_gnn.py b/tests/unit-tests/test_gnn.py index eaeda9f48a..4141651ce4 100644 --- a/tests/unit-tests/test_gnn.py +++ b/tests/unit-tests/test_gnn.py @@ -31,6 +31,7 @@ from numpy.testing import assert_almost_equal, assert_equal import dgl +from dgl.distributed import DistTensor from graphstorm.config import GSConfig, TaskInfo from graphstorm.config import BUILTIN_LP_DOT_DECODER @@ -39,7 +40,8 @@ BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION, BUILTIN_TASK_LINK_PREDICTION, - BUILTIN_TASK_RECONSTRUCT_NODE_FEAT) + BUILTIN_TASK_RECONSTRUCT_NODE_FEAT, + GRAPHSTORM_LP_EMB_L2_NORMALIZATION) from graphstorm.model import GSNodeEncoderInputLayer, RelationalGCNEncoder from graphstorm.model import GSgnnNodeModel, GSgnnEdgeModel from graphstorm.model import GSLMNodeEncoderInputLayer, GSPureLMNodeInputLayer @@ -1815,6 +1817,121 @@ class DummyLPPredLoss(nn.Module): def forward(self, pos_score, neg_score): return pos_score["n0"] + neg_score["n0"] +def test_multi_task_norm_node_embs(): + mt_model = GSgnnMultiTaskSharedEncoderModel(0.1) + mt_model.add_task("nc_task", + BUILTIN_TASK_NODE_CLASSIFICATION, + DummyNCDecoder(), + DummyPredLoss(), + "") + mt_model.add_task("nr_task", + BUILTIN_TASK_NODE_REGRESSION, + DummyNRDecoder(), + DummyPredLoss(), + GRAPHSTORM_LP_EMB_L2_NORMALIZATION) + + mt_model.add_task("ec_task", + BUILTIN_TASK_EDGE_CLASSIFICATION, + DummyECDecoder(), + DummyPredLoss(), + "") + + mt_model.add_task("er_task", + BUILTIN_TASK_EDGE_REGRESSION, + DummyERDecoder(), + DummyPredLoss(), + GRAPHSTORM_LP_EMB_L2_NORMALIZATION) + + mt_model.add_task("lp_task", + BUILTIN_TASK_LINK_PREDICTION, + DummyLPDecoder(), + DummyLPPredLoss(), + "") + + mt_model.add_task("lp_task2", + BUILTIN_TASK_LINK_PREDICTION, + DummyLPDecoder(), + DummyLPPredLoss(), + GRAPHSTORM_LP_EMB_L2_NORMALIZATION) + + embs = { + "n0": th.rand((10,16)), + "n1": th.rand((20,16)) + } + norm_embs = { + "n0": F.normalize(embs["n0"]), + "n1": F.normalize(embs["n1"]) + } + + new_embs = mt_model.normalize_task_node_embs("nc_task", embs, inplace=False) + assert_equal(embs["n0"].numpy(), new_embs["n0"].numpy()) + assert_equal(embs["n1"].numpy(), new_embs["n1"].numpy()) + + new_embs = mt_model.normalize_task_node_embs("nr_task", embs, inplace=False) + 
assert_equal(norm_embs["n0"].numpy(), new_embs["n0"].numpy()) + assert_equal(norm_embs["n1"].numpy(), new_embs["n1"].numpy()) + + new_embs = mt_model.normalize_task_node_embs("ec_task", embs, inplace=False) + assert_equal(embs["n0"].numpy(), new_embs["n0"].numpy()) + assert_equal(embs["n1"].numpy(), new_embs["n1"].numpy()) + + new_embs = mt_model.normalize_task_node_embs("er_task", embs, inplace=False) + assert_equal(norm_embs["n0"].numpy(), new_embs["n0"].numpy()) + assert_equal(norm_embs["n1"].numpy(), new_embs["n1"].numpy()) + + inplace_emb = { + "n0": th.clone(embs["n0"]), + "n1": th.clone(embs["n1"]) + } + mt_model.normalize_task_node_embs("lp_task", inplace_emb, inplace=True) + assert_equal(embs["n0"].numpy(), inplace_emb["n0"].numpy()) + assert_equal(embs["n1"].numpy(), inplace_emb["n1"].numpy()) + + mt_model.normalize_task_node_embs("lp_task2", inplace_emb, inplace=True) + assert_equal(norm_embs["n0"].numpy(), inplace_emb["n0"].numpy()) + assert_equal(norm_embs["n1"].numpy(), inplace_emb["n1"].numpy()) + +def test_multi_task_norm_node_embs_dist(): + mt_model = GSgnnMultiTaskSharedEncoderModel(0.1) + mt_model.add_task("lp_task", + BUILTIN_TASK_LINK_PREDICTION, + DummyLPDecoder(), + DummyLPPredLoss(), + "") + + mt_model.add_task("lp_task2", + BUILTIN_TASK_LINK_PREDICTION, + DummyLPDecoder(), + DummyLPPredLoss(), + GRAPHSTORM_LP_EMB_L2_NORMALIZATION) + + with tempfile.TemporaryDirectory() as tmpdirname: + # get the test dummy distributed graph + g, _ = generate_dummy_dist_graph(tmpdirname, size="tiny") + + embs = {} + norm_embs = {} + dist_embs = {} + + for ntype in g.ntypes: + embs[ntype] = th.rand(g.number_of_nodes(ntype), 16) + norm_embs[ntype] = F.normalize(embs[ntype]) + dist_embs[ntype] = DistTensor((g.number_of_nodes(ntype), 16), + dtype=th.float32, name=f'ntype-{ntype}', + part_policy=g.get_node_partition_policy(ntype)) + dist_embs[ntype][th.arange(g.number_of_nodes(ntype))] = embs[ntype][:] + + new_embs = mt_model.normalize_task_node_embs("lp_task", dist_embs, inplace=False) + for ntype in g.ntypes: + assert_equal(embs[ntype].numpy(), new_embs[ntype][th.arange(g.number_of_nodes(ntype))].numpy()) + + new_embs = mt_model.normalize_task_node_embs("lp_task2", dist_embs, inplace=False) + for ntype in g.ntypes: + assert_equal(norm_embs[ntype].numpy(), new_embs[ntype][th.arange(g.number_of_nodes(ntype))].numpy()) + + dgl.distributed.kvstore.close_kvstore() + + def test_multi_task_forward(): mt_model = GSgnnMultiTaskSharedEncoderModel(0.1) @@ -1850,7 +1967,7 @@ def check_forward(mock_normalize_node_embs, mock_compute_emb, mock_input_embed): - def normalize_size_effect_func(embs): + def normalize_size_effect_func(task_id, embs): return embs def compute_side_effect_func(blocks, node_feats, input_nodes): @@ -1981,7 +2098,7 @@ def check_forward(mock_normalize_node_embs, mock_compute_emb, mock_input_embed): - def normalize_size_effect_func(embs): + def normalize_size_effect_func(task_id, embs): return embs def compute_side_effect_func(blocks, node_feats, input_nodes): @@ -2315,6 +2432,8 @@ def check_call_gen_embs(skip_last_self_loop): if __name__ == '__main__': test_node_feat_reconstruct() + test_multi_task_norm_node_embs() + test_multi_task_norm_node_embs_dist() test_multi_task_forward() test_multi_task_predict() test_multi_task_mini_batch_predict() diff --git a/tests/unit-tests/util.py b/tests/unit-tests/util.py index 963f68f23d..39d3672dee 100644 --- a/tests/unit-tests/util.py +++ b/tests/unit-tests/util.py @@ -108,6 +108,12 @@ def __init__(self, encoder_model, decoders, 
has_sparse=False): super(DummyGSgnnMTModel, self).__init__(encoder_model, has_sparse) self._decoders = decoders + @property + def node_embed_norm_methods(self): + return {} + + def normalize_task_node_embs(self, task_id, embs, inplace=False): + return embs def forward(self, task_mini_batches): pass diff --git a/training_scripts/gsgnn_mt/ml_nc_lp_norm.yaml b/training_scripts/gsgnn_mt/ml_nc_lp_norm.yaml new file mode 100644 index 0000000000..261a1c6106 --- /dev/null +++ b/training_scripts/gsgnn_mt/ml_nc_lp_norm.yaml @@ -0,0 +1,66 @@ +--- +version: 1.0 +gsf: + basic: + backend: gloo + verbose: false + save_perf_results_path: null + batch_size: 32 + node_feat_name: + - user:feat + - movie:title + gnn: + model_encoder_type: rgcn + fanout: "4" + num_layers: 1 + hidden_size: 32 + use_mini_batch_infer: true + input: + restore_model_path: null + output: + save_model_path: null + save_embed_path: null + hyperparam: + dropout: 0. + lr: 0.001 + lm_tune_lr: 0.0001 + num_epochs: 3 + wd_l2norm: 0 + no_validation: false + rgcn: + num_bases: -1 + use_self_loop: true + sparse_optimizer_lr: 1e-2 + use_node_embeddings: false + multi_task_learning: + - node_classification: + target_ntype: "movie" + label_field: "label" + multilabel: false + num_classes: 19 + batch_size: 16 # will overwrite the global batch_size + mask_fields: + - "train_mask_c0" # node classification mask 0 + - "val_mask_c0" + - "test_mask_c0" + task_weight: 1.0 + eval_metric: + - "accuracy" + - link_prediction: + lp_loss_func: "contrastive" + num_negative_edges: 4 + num_negative_edges_eval: 100 + train_negative_sampler: joint + eval_etype: + - "user,rating,movie" + train_etype: + - "user,rating,movie" + exclude_training_targets: true + reverse_edge_types_map: + - user,rating,rating-rev,movie + batch_size: 128 # will overwrite the global batch_size + mask_fields: + - "train_mask_field_lp" + - "val_mask_field_lp" + - null # empty means there is no test mask + task_weight: 1.0
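A configuration note on the normalizer itself: a link prediction task block can request the L2 normalization explicitly through the lp_embed_normalizer option, the same key exercised by the updated test_remap_result.py. The snippet below is only an illustrative sketch of such a task block (values other than "l2_norm" are not assumed; the edge type and batch size mirror the file above):

    - link_prediction:
        lp_loss_func: "contrastive"
        lp_embed_normalizer: "l2_norm" # save L2-normalized node embeddings for this task
        num_negative_edges: 4
        train_etype:
          - "user,rating,movie"
        eval_etype:
          - "user,rating,movie"
        batch_size: 128
        task_weight: 1.0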