diff --git a/python/graphstorm/model/node_glem.py b/python/graphstorm/model/node_glem.py index 8dcdaf0d12..0148ddf0df 100644 --- a/python/graphstorm/model/node_glem.py +++ b/python/graphstorm/model/node_glem.py @@ -457,16 +457,21 @@ def predict(self, blocks, node_feats, edge_feats, input_nodes, return_proba): preds = decoder.predict(emb) return preds, emb - def _get_seed_nodes(self, input_nodes, blocks): - """ Get seed nodes from input nodes and labels of the seed nodes. + def _get_seed_nodes(self, input_nodes, node_feats, blocks): + """ Get seed nodes and features from input nodes and labels of the seed nodes. Parameters ---------- input_nodes : {target_ntype: tensor.shape [bs], other_ntype: []} + node_feats : {ntype: tensor} blocks : list[dgl.Block] """ target_ntype = self.target_ntype n_seed_nodes = blocks[-1].num_dst_nodes() - return {target_ntype: input_nodes[target_ntype][:n_seed_nodes]} + seed_nodes = {target_ntype: input_nodes[target_ntype][:n_seed_nodes]} + seed_feats = {} + if target_ntype in node_feats: + seed_feats = {target_ntype: node_feats[target_ntype][:n_seed_nodes]} + return seed_nodes, seed_feats def _embed_nodes(self, blocks, node_feats, _, input_nodes=None, do_gnn_encode=True): """ Embed and encode nodes with LM, optionally followed by GNN encoder for GLEM model @@ -480,9 +485,9 @@ def _embed_nodes(self, blocks, node_feats, _, input_nodes=None, do_gnn_encode=Tr n_seed_nodes = blocks[-1].num_dst_nodes() return encode_embs[target_ntype][:n_seed_nodes], encode_embs_gnn[target_ntype] else: - # Get the projected LM embeddings for seed nodes: - seed_nodes = self._get_seed_nodes(input_nodes, blocks) - encode_embs = self.lm.comput_input_embed(seed_nodes, node_feats) + # Get the projected LM embeddings for seed nodes and corresponding node features: + seed_nodes, seed_feats = self._get_seed_nodes(input_nodes, node_feats, blocks) + encode_embs = self.lm.comput_input_embed(seed_nodes, seed_feats) return encode_embs[target_ntype], None def _process_labels(self, labels):