Skip to content

Commit

Permalink
Bug fix for GLEM LM forward pass with additional node feats (#611)
Browse files Browse the repository at this point in the history
*Issue #, if available:*

*Description of changes:* GLEM LM's forward pass breaks when there are
non-textual node features. This PR fixes it by subsetting the node
features to those of the seed nodes.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
  • Loading branch information
wangz10 authored Nov 2, 2023
1 parent 459692a commit 51b6ea6
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions python/graphstorm/model/node_glem.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,16 +457,21 @@ def predict(self, blocks, node_feats, edge_feats, input_nodes, return_proba):
preds = decoder.predict(emb)
return preds, emb

def _get_seed_nodes(self, input_nodes, blocks):
""" Get seed nodes from input nodes and labels of the seed nodes.
def _get_seed_nodes(self, input_nodes, node_feats, blocks):
""" Get seed nodes and features from input nodes and labels of the seed nodes.
Parameters
----------
input_nodes : {target_ntype: tensor.shape [bs], other_ntype: []}
node_feats : {ntype: tensor}
blocks : list[dgl.Block]
"""
target_ntype = self.target_ntype
n_seed_nodes = blocks[-1].num_dst_nodes()
return {target_ntype: input_nodes[target_ntype][:n_seed_nodes]}
seed_nodes = {target_ntype: input_nodes[target_ntype][:n_seed_nodes]}
seed_feats = {}
if target_ntype in node_feats:
seed_feats = {target_ntype: node_feats[target_ntype][:n_seed_nodes]}
return seed_nodes, seed_feats

def _embed_nodes(self, blocks, node_feats, _, input_nodes=None, do_gnn_encode=True):
""" Embed and encode nodes with LM, optionally followed by GNN encoder for GLEM model
Expand All @@ -480,9 +485,9 @@ def _embed_nodes(self, blocks, node_feats, _, input_nodes=None, do_gnn_encode=Tr
n_seed_nodes = blocks[-1].num_dst_nodes()
return encode_embs[target_ntype][:n_seed_nodes], encode_embs_gnn[target_ntype]
else:
# Get the projected LM embeddings for seed nodes:
seed_nodes = self._get_seed_nodes(input_nodes, blocks)
encode_embs = self.lm.comput_input_embed(seed_nodes, node_feats)
# Get the projected LM embeddings for seed nodes and corresponding node features:
seed_nodes, seed_feats = self._get_seed_nodes(input_nodes, node_feats, blocks)
encode_embs = self.lm.comput_input_embed(seed_nodes, seed_feats)
return encode_embs[target_ntype], None

def _process_labels(self, labels):
Expand Down

0 comments on commit 51b6ea6

Please sign in to comment.