[BugFix] Fix missing node normalization for link prediction tasks in multi-task learning (#926)

*Issue #, if available:*
In multi-task learning, when there is a link prediction training task
with contrastive loss, the loss may become NaN. This is because
GraphStorm does not apply the proper node normalization to the GNN
embeddings.
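
As a rough illustration of the failure mode (not GraphStorm code; the naive InfoNCE-style loss and embedding scale below are assumptions for demonstration), dot-product scores over unnormalized embeddings can overflow `exp()`, while L2-normalized embeddings keep the scores bounded in [-1, 1]:

```python
import torch
import torch.nn.functional as F

def naive_contrastive_loss(src, pos, neg):
    # InfoNCE-style loss written with an explicit exp/sum purely for
    # illustration; production implementations use log-sum-exp tricks.
    pos_score = (src * pos).sum(dim=-1)        # [N]
    neg_score = src @ neg.t()                  # [N, K]
    num = torch.exp(pos_score)
    den = num + torch.exp(neg_score).sum(dim=-1)
    return (-torch.log(num / den)).mean()

# Hypothetical GNN embeddings whose scale has drifted upward during training.
src, pos, neg = (torch.randn(8, 32) * 50 for _ in range(3))

print(naive_contrastive_loss(src, pos, neg))   # exp() overflow/underflow -> inf or NaN
print(naive_contrastive_loss(F.normalize(src, dim=-1),
                             F.normalize(pos, dim=-1),
                             F.normalize(neg, dim=-1)))  # bounded scores, finite loss
```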

*Description of changes:*
Fix the bug: apply the per-task node embedding normalization to the GNN
embeddings for link prediction tasks in multi-task learning, save the
normalized embeddings under per-task directories, and handle those
task-specific embedding directories during ID remapping in `remap_result.py`.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Xiang Song <[email protected]>
classicsong and Xiang Song authored Jul 30, 2024
1 parent f4d5785 commit 1d1b21f
Showing 12 changed files with 581 additions and 37 deletions.
59 changes: 59 additions & 0 deletions inference_scripts/mt_infer/ml_nc_lp_norm_with_mask_infer.yaml
@@ -0,0 +1,59 @@
---
version: 1.0
gsf:
  basic:
    backend: gloo
    verbose: false
    save_perf_results_path: null
    batch_size: 32
    node_feat_name:
      - user:feat
      - movie:title
  gnn:
    model_encoder_type: rgcn
    num_layers: 1
    hidden_size: 32
    use_mini_batch_infer: true
  input:
    restore_model_path: null
  output:
    save_model_path: null
    save_embed_path: null
  hyperparam:
    dropout: 0.
    lr: 0.001
    no_validation: false
  rgcn:
    num_bases: -1
    use_self_loop: true
    use_node_embeddings: false
  multi_task_learning:
    - node_classification:
        target_ntype: "movie"
        label_field: "label"
        multilabel: false
        num_classes: 19
        batch_size: 16 # will overwrite the global batch_size
        mask_fields:
          - "train_mask_c0" # node classification mask 0
          - "val_mask_c0"
          - "test_mask_c0"
        eval_metric:
          - "accuracy"
    - link_prediction:
        lp_loss_func: "contrastive"
        num_negative_edges: 4
        num_negative_edges_eval: 100
        train_negative_sampler: joint
        eval_etype:
          - "user,rating,movie"
        train_etype:
          - "user,rating,movie"
        exclude_training_targets: true
        reverse_edge_types_map:
          - user,rating,rating-rev,movie
        batch_size: 128 # will overwrite the global batch_size
        mask_fields:
          - "train_mask_field_lp"
          - null # empty means there is no validation mask
          - "test_mask_field_lp"
65 changes: 58 additions & 7 deletions python/graphstorm/gconstruct/remap_result.py
@@ -40,7 +40,8 @@
BUILTIN_TASK_EDGE_CLASSIFICATION,
BUILTIN_TASK_EDGE_REGRESSION,
BUILTIN_TASK_NODE_CLASSIFICATION,
BUILTIN_TASK_NODE_REGRESSION)
BUILTIN_TASK_NODE_REGRESSION,
BUILTIN_TASK_LINK_PREDICTION)

GS_OUTPUT_FORMAT_PARQUET = "parquet"
GS_OUTPUT_FORMAT_CSV = "csv"
@@ -655,16 +656,28 @@ def _parse_gs_config(config):
node_id_mapping = os.path.join(os.path.dirname(part_config), "raw_id_mappings")
predict_dir = config.save_prediction_path
emb_dir = config.save_embed_path
task_emb_dirs = []

pred_ntypes = []
pred_etypes = []
if config.multi_tasks is not None:
node_predict_dirs = []
edge_predict_dirs = []
if predict_dir is None:
return node_id_mapping, None, emb_dir, pred_ntypes, pred_etypes
# multi-task setting
tasks = config.multi_tasks

for task in tasks:
task_config = task.task_config
task_id = task.task_id
if task.task_type in [BUILTIN_TASK_LINK_PREDICTION]:
if task_config.lp_embed_normalizer is not None:
# There is a link prediction node embedding normalizer.
# Need to handle the normalized embeddings.
task_emb_dirs.append(task_id)

if predict_dir is None:
return node_id_mapping, None, emb_dir, task_emb_dirs, pred_ntypes, pred_etypes

for task in tasks:
task_config = task.task_config
task_id = task.task_id
@@ -681,7 +694,7 @@ def _parse_gs_config(config):
edge_predict_dirs.append(pred_path)

predict_dir = (node_predict_dirs, edge_predict_dirs)
return node_id_mapping, predict_dir, emb_dir, pred_ntypes, pred_etypes
return node_id_mapping, predict_dir, emb_dir, task_emb_dirs, pred_ntypes, pred_etypes
else:
task_type = config.task_type
if task_type in (BUILTIN_TASK_EDGE_CLASSIFICATION, BUILTIN_TASK_EDGE_REGRESSION):
@@ -694,7 +707,7 @@
pred_ntypes = pred_ntypes \
if isinstance(pred_ntypes, list) else [pred_ntypes]

return node_id_mapping, predict_dir, emb_dir, pred_ntypes, pred_etypes
return node_id_mapping, predict_dir, emb_dir, task_emb_dirs, pred_ntypes, pred_etypes

def main(args, gs_config_args):
""" main function
@@ -714,7 +727,7 @@ def main(args, gs_config_args):
gs_args, _ = gs_parser.parse_known_args(gs_config_args)
config = GSConfig(gs_args)
config.verify_arguments(False)
id_mapping_path, predict_dir, node_emb_dir, pred_ntypes, pred_etypes = \
id_mapping_path, predict_dir, node_emb_dir, task_emb_dirs, pred_ntypes, pred_etypes = \
_parse_gs_config(config)
else:
# Case 2: remap_result is called alone.
@@ -724,6 +737,10 @@ def main(args, gs_config_args):
id_mapping_path = args.node_id_mapping
predict_dir = args.prediction_dir
node_emb_dir = args.node_emb_dir
# We do not handle the case where there are task-specific embeddings
# in multi-task learning if remap_result is called alone.
# Users need to clean up the node_emb_dir themselves.
task_emb_dirs = []
pred_etypes = args.pred_etypes
pred_ntypes = args.pred_ntypes
if pred_etypes is not None:
@@ -773,7 +790,26 @@ def main(args, gs_config_args):

else: # There is no shared file system
emb_names = os.listdir(node_emb_dir)
emb_names = [e_name for e_name in emb_names if e_name != "emb_info.json"]
# In single task learning, the node embed dir looks like:
# emb_dir/
# ntype0
# ntype1
# ...
# emb_info.json
#
# In multi-task learning, the node embed dir looks like:
# emb_dir/
# ntype0
# ntype1
# ...
# emb_info.json
# task_id0/
# task_id1/
# ...
# We need to exclude both emb_info.json and the task_id directories
# when collecting node types with node embeddings.
emb_names = [e_name for e_name in emb_names \
if e_name not in task_emb_dirs + ["emb_info.json"]]

emb_ntypes = emb_names
else:
@@ -962,6 +998,21 @@ def main(args, gs_config_args):
output_func)
files_to_remove += emb_files_to_remove

for task_emb_dir in task_emb_dirs:
task_emb_dir = os.path.join(node_emb_dir, task_emb_dir)
# We need to do ID remapping for node embeddings
emb_files_to_remove = \
remap_node_emb(emb_ntypes,
task_emb_dir,
task_emb_dir,
out_chunk_size,
num_proc,
rank,
world_size,
with_shared_fs,
output_func)
files_to_remove += emb_files_to_remove

if len(pred_etypes) > 0:
if isinstance(predict_dir, tuple):
_, edge_predict_dirs = predict_dir
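
As a minimal sketch of the directory handling added above (the listing and the task ID string are hypothetical), the node-type embedding folders are recovered by excluding `emb_info.json` and the per-task subdirectories, and each task subdirectory then gets its own remapping pass:

```python
# Hypothetical contents of the multi-task node embedding directory.
emb_dir_entries = ["user", "movie", "emb_info.json",
                   "link_prediction-user_rating_movie"]   # task_id subdirectory
task_emb_dirs = ["link_prediction-user_rating_movie"]

# Node types whose embeddings live directly under emb_dir.
emb_ntypes = [name for name in emb_dir_entries
              if name not in task_emb_dirs + ["emb_info.json"]]
print(emb_ntypes)  # ['user', 'movie']

# Each entry in task_emb_dirs holds normalized embeddings for the same node
# types and is remapped separately (the remap_node_emb loop in the diff above).
```
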
41 changes: 30 additions & 11 deletions python/graphstorm/inference/mt_infer.py
@@ -105,7 +105,8 @@ def infer(self, data,
"""
do_eval = self.evaluator is not None
sys_tracker.check('start inferencing')
self._model.eval()
model = self._model
model.eval()

# All the tasks share the same GNN encoder so the fanouts are same
# for different tasks.
@@ -133,13 +134,13 @@ def gen_embs(edge_mask=None):
# so the node embeddings are updated inplace.
if use_mini_batch_infer:
embs = do_mini_batch_inference(
self._model, data, batch_size=infer_batch_size,
model, data, batch_size=infer_batch_size,
fanout=fanout,
edge_mask=edge_mask,
task_tracker=self.task_tracker)
else:
embs = do_full_graph_inference(
self._model, data,
model, data,
fanout=fanout,
edge_mask=edge_mask,
task_tracker=self.task_tracker)
@@ -154,17 +155,29 @@ def gen_embs(edge_mask=None):
# before computing prediction results.
if save_embed_path is not None:
logging.info("Saving node embeddings")
node_norm_methods = model.node_embed_norm_methods
# Save the original embs first
save_gsgnn_embeddings(g,
save_embed_path,
embs,
node_id_mapping_file=node_id_mapping_file,
save_embed_format=save_embed_format)
barrier()
for task_id, norm_method in node_norm_methods.items():
if norm_method is None:
continue
normed_embs = model.normalize_task_node_embs(task_id, embs, inplace=False)
save_embed_path = os.path.join(save_embed_path, task_id)
save_gsgnn_embeddings(g,
save_embed_path,
normed_embs,
node_id_mapping_file=node_id_mapping_file,
save_embed_format=save_embed_format)
sys_tracker.check('save embeddings')

# save relation embedding if any for link prediction tasks
if get_rank() == 0:
decoders = self._model.task_decoders
decoders = model.task_decoders
for task_id, decoder in decoders.items():
if isinstance(decoder, LinkPredictDistMultDecoder):
rel_emb_path = os.path.join(save_embed_path, task_id)
@@ -189,7 +202,7 @@ def gen_embs(edge_mask=None):
# and edge regression tasks.
pre_results = \
multi_task_mini_batch_predict(
self._model,
model,
emb=embs,
dataloaders=predict_test_loader.dataloaders,
task_infos=predict_test_loader.task_infos,
@@ -213,9 +226,9 @@ def nfrecon_gen_embs(skip_last_self_loop=False, node_embs=embs):
if skip_last_self_loop is True:
# Turn off the last layer GNN's self-loop
# to compute node embeddings.
self._model.gnn_encoder.skip_last_selfloop()
model.gnn_encoder.skip_last_selfloop()
new_embs = gen_embs()
self._model.gnn_encoder.reset_last_selfloop()
model.gnn_encoder.reset_last_selfloop()
return new_embs
else:
# If skip_last_self_loop is False
@@ -231,11 +244,11 @@ def nfrecon_gen_embs(skip_last_self_loop=False, node_embs=embs):
# Note(xiangsx): In DistDGl, as we are using the
# same dist tensor, the node embeddings
# are updated inplace.
nfeat_embs = gen_emb_for_nfeat_reconstruct(self._model, nfrecon_gen_embs)
nfeat_embs = gen_emb_for_nfeat_reconstruct(model, nfrecon_gen_embs)

nfeat_recon_results = \
multi_task_mini_batch_predict(
self._model,
model,
emb=nfeat_embs,
dataloaders=dataloaders,
task_infos=task_infos,
@@ -258,8 +271,14 @@ def nfrecon_gen_embs(skip_last_self_loop=False, node_embs=embs):

# For link prediction, do evaluation task by task.
lp_test_embs = gen_embs(edge_mask=task_info.task_config.train_mask)

decoder = self._model.task_decoders[task_info.task_id]
# Normalize the node embeddings if needed.
# We can do in-place normalization as embeddings are generated
# per LP task.
lp_test_embs = model.normalize_task_node_embs(task_info.task_id,
lp_test_embs,
inplace=True)

decoder = model.task_decoders[task_info.task_id]
ranking = run_lp_mini_batch_predict(decoder, lp_test_embs, dataloader, device)
pre_results[task_info.task_id] = ranking

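
For intuition only, the per-task normalization applied above typically amounts to L2-normalizing each node type's embedding matrix. A minimal sketch, assuming in-memory tensors rather than GraphStorm's distributed tensors and using "l2norm" as a placeholder method name:

```python
import torch
import torch.nn.functional as F

def normalize_node_embs(embs, norm_method="l2norm"):
    """embs: dict mapping node type -> [num_nodes, hidden] tensor."""
    if norm_method is None:
        return embs                      # no normalizer configured for this task
    if norm_method == "l2norm":
        return {ntype: F.normalize(emb, p=2.0, dim=-1)
                for ntype, emb in embs.items()}
    raise ValueError(f"Unknown normalization method: {norm_method}")

embs = {"user": torch.randn(10, 32), "movie": torch.randn(20, 32)}
normed = normalize_node_embs(embs, "l2norm")
print(normed["user"].norm(dim=-1))       # each row has unit L2 norm
```
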
