Enable gnn-based prompting for gpeft (#701)
*Issue #, if available:*

*Description of changes:*
Finishes the TODOs left in the original PEFT LLM GNN PR (#673): initialize the GNN encoder, use it to generate graph tokens, and prepend those tokens to the LLM input embeddings (GPEFT-style graph prompting).
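
Concretely, the change projects each target node's GNN representation into the LLM hidden size and prepends it to the token embeddings as a single soft graph token, extending the attention mask to cover it; the hidden state of the last non-padding token is then used for classification. A minimal shape-level sketch of the prompting step (illustrative tensors only, not code from this PR):

import torch

batch_size, seq_len, hidden = 4, 16, 768
graph_tokens = torch.randn(batch_size, hidden)           # GNN output projected to the LLM hidden size
word_embeds = torch.randn(batch_size, seq_len, hidden)   # LLM embeddings of the tokenized node text
attention_mask = torch.ones(batch_size, seq_len)

# Prepend one graph token per sequence and extend the mask so the LLM attends to it.
inputs_embeds = torch.cat([graph_tokens.unsqueeze(1), word_embeds], dim=1)
attention_mask = torch.cat([torch.ones(batch_size, 1), attention_mask], dim=1)
assert inputs_embeds.shape == (batch_size, seq_len + 1, hidden)
assert attention_mask.shape == (batch_size, seq_len + 1)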

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
GentleZhu authored Jan 26, 2024
1 parent e1482e0 commit f50cffd
Showing 2 changed files with 54 additions and 16 deletions.
examples/peft_llm_gnn/llm_gnn_model.py (64 changes: 51 additions & 13 deletions)
@@ -25,8 +25,9 @@

from graphstorm import model as gsmodel
from graphstorm.model.lm_model import TOKEN_IDX, ATT_MASK_IDX
from dgl.nn import GraphConv, HeteroGraphConv

from peft import LoraConfig, get_peft_model
from peft import LoraConfig, get_peft_model, AutoPeftModel
from transformers import (
AutoConfig,
AutoModel,
@@ -115,32 +116,69 @@ def __init__(self, g, node_lm_configs, h_dim, out_dim, num_layers,
self._lm_node_feats[ntype] = feats
self.out = nn.Linear(self.config.hidden_size, out_dim)
self._loss_fn = gsmodel.ClassifyLossFunc(multilabel=False)
#TODO (@qzhuamzn): add initialization for gnn encoding
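# GNN encoder: one HeteroGraphConv (a GraphConv per edge type) per layer,
# used to compute the graph prompt tokens.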
self.gnn = nn.ModuleList()
for _ in range(num_layers):
self.gnn.append(HeteroGraphConv({
_etype: GraphConv(h_dim, h_dim) for _etype in g.etypes
}))
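# Project the GNN output (h_dim) into the LLM hidden size so it can be
# prepended to the word embeddings as a soft prompt token.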
self.projection = nn.Linear(h_dim, self.config.hidden_size)

def encode_graph(self, blocks, h):
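# Run the sampled blocks through the GNN layers (ReLU after each layer) and
# return the projected destination-node representations, one graph token per
# target node.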
for layer, block in zip(self.gnn, blocks):
h = layer(block, h)
h = {k: F.relu(v) for k, v in h.items()}
src_type, dst_type = blocks[0].ntypes
graph_tokens = self.projection(h[dst_type])
return graph_tokens

def forward(self, blocks, node_feats, edge_feats, labels, input_nodes):
# TODO (qzhuamzn): use GNNs to generate graph tokens
h = {}
output_nodes = blocks[-1].dstdata[dgl.NID]

input_ids = self._lm_node_feats[self.target_ntype][TOKEN_IDX][output_nodes].to(self.llm.device)
attention_mask = self._lm_node_feats[self.target_ntype][ATT_MASK_IDX][output_nodes].to(self.llm.device)
# TODO (qzhuamzn): modify input_ids into input_embeds=[graph_tokens, input_embeds] to support GPEFT
model_output = self.llm(input_ids=input_ids, attention_mask=attention_mask)

graph_tokens = self.encode_graph(blocks, node_feats)

input_shape = input_ids.size()
# make sure input_ids are batch_size X seq_len
input_ids = input_ids.view(-1, input_shape[-1])
word_embeddings = self.llm.get_input_embeddings()
# assuming graph_tokens has a shape of [batch_size, embedding_dim] and
# word_embeddings(input_ids) has a shape of [batch_size, seq_len, embedding_dim];
# only one graph token is inserted ahead of the input words
inputs_embeds = torch.cat([graph_tokens.unsqueeze(1), word_embeddings(input_ids)], dim=1)

# enable attention computation on the inserted graph token
attention_mask = torch.cat([torch.ones((input_shape[0],1), device=self.llm.device), attention_mask], dim=1)
model_output = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
# We use the last token in order to do the classification, as other causal models do.
h = self.out(model_output.last_hidden_state[:,-1,:])
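# Pool the hidden state of each sequence's last non-padding token: zero out
# padded positions, then index each row at (attended length - 1).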
masked_hidden_states = model_output.last_hidden_state * attention_mask.unsqueeze(-1)
last_token_indexes = (attention_mask.sum(dim=1, dtype=torch.int64) - 1)
last_token_embeddings = masked_hidden_states[torch.arange(last_token_indexes.size(0)),last_token_indexes,:]
h = self.out(last_token_embeddings)

loss = self._loss_fn(h, labels[self.target_ntype])

return loss


def predict(self, blocks, node_feats, _, input_nodes, return_proba):
# TODO (qzhuamzn): use h as gnn token embeddings
h = {}
output_nodes = blocks[-1].dstdata[dgl.NID]

input_ids = self._lm_node_feats[self.target_ntype][TOKEN_IDX][output_nodes].to(self.llm.device)
attention_mask = self._lm_node_feats[self.target_ntype][ATT_MASK_IDX][output_nodes].to(self.llm.device)
model_output = self.llm(input_ids=input_ids, attention_mask=attention_mask)
logits = self.out(model_output.last_hidden_state[:,-1,:])
graph_tokens = self.encode_graph(blocks, node_feats)

input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
word_embeddings = self.llm.get_input_embeddings()
inputs_embeds = torch.cat([graph_tokens.unsqueeze(1), word_embeddings(input_ids)], dim=1)
attention_mask = torch.cat([torch.ones((input_shape[0],1), device=self.llm.device), attention_mask], dim=1)
model_output = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
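# Same last non-padding-token pooling as in forward().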
masked_hidden_states = model_output.last_hidden_state * attention_mask.unsqueeze(-1)
last_token_indexes = (attention_mask.sum(dim=1, dtype=torch.int64) - 1)
last_token_embeddings = masked_hidden_states[torch.arange(last_token_indexes.size(0)),last_token_indexes,:]
logits = self.out(last_token_embeddings)

if return_proba:
return logits.argmax(dim=-1), torch.softmax(logits, 1)
else:
@@ -154,7 +192,7 @@ def create_optimizer(self):
return torch.optim.Adam(self.parameters(), lr=self.lr)

def restore_model(self, restore_model_path):
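# AutoPeftModel reloads the base LM together with the saved PEFT (LoRA) adapter weights.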
self.llm = AutoModel.from_pretrained(restore_model_path, config=self.config)
self.llm = AutoPeftModel.from_pretrained(restore_model_path)

def save_model(self, model_path):
os.makedirs(model_path, exist_ok=True)
examples/peft_llm_gnn/nc_config_Video_Games.yaml (6 changes: 3 additions & 3 deletions)
@@ -11,13 +11,13 @@ gsf:
verbose: false
gnn:
model_encoder_type: rgcn
fanout: 10, 5
fanout: "5,5"
hidden_size: 768
num_layers: 2
use_mini_batch_infer: true
hyperparam:
batch_size: 4
dropout: 0.1
dropout: 0.0
eval_batch_size: 4
lr: 0.0001
num_epochs: 10
@@ -31,7 +31,7 @@ gsf:
node_classification:
eval_metric:
- accuracy
label_field: label
label_field: pt_lvl3
multilabel: false
node_feat_name:
- item:h
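
For reference (an illustration, not part of this commit): the quoted fanout string specifies one neighbor-sampling fanout per GNN layer, so "5,5" with num_layers: 2 samples 5 neighbors at each of the two layers, roughly equivalent to the following DGL sampler:

import dgl

# "5,5" -> one fanout per GNN layer; the number of entries must match num_layers
fanouts = [int(f) for f in "5,5".split(",")]   # [5, 5]
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)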
