Skip to content

Commit

Permalink
clean-ups
Browse files Browse the repository at this point in the history
  • Loading branch information
wangz10 committed Nov 27, 2023
1 parent a5a9db4 commit 4a79017
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions python/graphstorm/model/edge_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,11 +724,7 @@ def calc_retrieval_scores(self, emb, pos_pairs, device):
scores = {}
pos_scores = calc_dot_pos_score(pos_src_emb, pos_dst_emb)
neg_dst_emb = emb[vtype][np.arange(emb[vtype].shape[0])].to(device)
# neg_dst_emb should contains train nodes only:
# v_train_mask = g.nodes[vtype].data['train_mask'][np.arange(g.number_of_nodes(vtype))]
# train_nids = np.where(v_train_mask)[0]
# neg_dst_emb = emb[vtype][train_nids].to(device)
neg_scores = th.mm(pos_src_emb, neg_dst_emb.transpose(0, 1)) # [n_pos, n_train]
neg_scores = th.mm(pos_src_emb, neg_dst_emb.transpose(0, 1)) # [n_pos, n_embs]
# gloo with cpu will consume less GPU memory
neg_scores = neg_scores.cpu() \
if is_distributed() and get_backend() == "gloo" \
Expand Down

0 comments on commit 4a79017

Please sign in to comment.