Skip to content

Commit

Permalink
[Bugfix] Fix eval_fanout of validation dataloader of lp task in multi…
Browse files Browse the repository at this point in the history
…-task learning
  • Loading branch information
Xiang Song committed Jun 7, 2024
1 parent 0981cee commit 1f907ca
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions python/graphstorm/run/gsgnn_mt/gsgnn_mt.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,21 +187,23 @@ def create_task_val_dataloader(task, config, train_data):
elif task.task_type in [BUILTIN_TASK_LINK_PREDICTION]:
val_idxs = train_data.get_edge_val_set(task_config.eval_etype, mask=task_config.val_mask)
dataloader_cls = gs.get_builtin_lp_eval_dataloader_class(task_config)
# All tasks share the same GNN model, so the fanout should be the global fanout
fanout = config.eval_fanout if task_config.use_mini_batch_infer else []
if len(val_idxs) > 0:
# TODO(xiangsx): Support construct feat
if task_config.eval_etypes_negative_dstnode is not None:
return dataloader_cls(train_data, val_idxs,
task_config.eval_batch_size,
fixed_edge_dst_negative_field=task_config.eval_etypes_negative_dstnode,
fanout=task_config.eval_fanout,
fanout=fanout,
fixed_test_size=task_config.fixed_test_size,
node_feats=node_feats,
pos_graph_edge_feats=task_config.lp_edge_weight_for_loss)
else:
return dataloader_cls(train_data, val_idxs,
task_config.eval_batch_size,
task_config.num_negative_edges_eval,
fanout=task_config.eval_fanout,
fanout=fanout,
fixed_test_size=task_config.fixed_test_size,
node_feats=node_feats,
pos_graph_edge_feats=task_config.lp_edge_weight_for_loss)
Expand Down
2 changes: 1 addition & 1 deletion training_scripts/gsgnn_mt/ml_nc_ec_er_lp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ gsf:
batch_size: 128 # will overwrite the global batch_size
mask_fields:
- "train_mask_field_lp"
- null # empty means there is no validation mask
- "val_mask_field_l" # empty means there is no validation mask
- null # empty means there is no test mask
task_weight: 1.0
- reconstruct_node_feat:
Expand Down

0 comments on commit 1f907ca

Please sign in to comment.