Skip to content

Commit

Permalink
fix embedding bug on link prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato committed Oct 5, 2023
1 parent 96bdaf8 commit ebe0d4b
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions python/graphstorm/run/gsgnn_emb/gsgnn_node_emb.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,12 @@ def main(config_args):
# For example pre-compute all BERT embeddings
model.prepare_input_encoder(train_data)
# TODO(zhengda) we may not want to only use training edges to generate GNN embeddings.
embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout,
task_tracker=tracker)
if config.task_type == BUILTIN_TASK_LINK_PREDICTION:
embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout,
edge_mask="train_mask", task_tracker=tracker)
else:
embeddings = do_full_graph_inference(model, train_data, fanout=config.eval_fanout,
task_tracker=tracker)
save_embeddings(config.save_embed_path, embeddings, gs.get_rank(),
gs.get_world_size(),
device=device,
Expand Down

0 comments on commit ebe0d4b

Please sign in to comment.