Enable gnn-based prompting for gpeft #701

Merged · 6 commits · Jan 26, 2024
59 changes: 46 additions & 13 deletions examples/peft_llm_gnn/llm_gnn_model.py
@@ -25,8 +25,9 @@
 from graphstorm import model as gsmodel
 from graphstorm.model.lm_model import TOKEN_IDX, ATT_MASK_IDX
+from dgl.nn import GraphConv, HeteroGraphConv
 
-from peft import LoraConfig, get_peft_model
+from peft import LoraConfig, get_peft_model, AutoPeftModel
 from transformers import (
     AutoConfig,
     AutoModel,
@@ -115,32 +116,64 @@ def __init__(self, g, node_lm_configs, h_dim, out_dim, num_layers,
             self._lm_node_feats[ntype] = feats
         self.out = nn.Linear(self.config.hidden_size, out_dim)
         self._loss_fn = gsmodel.ClassifyLossFunc(multilabel=False)
-        #TODO (@qzhuamzn): add initialization for gnn encoding
+        self.gnn = nn.ModuleList()
+        for _ in range(num_layers):
+            self.gnn.append(HeteroGraphConv({
+                _etype: GraphConv(h_dim, h_dim) for _etype in g.etypes
+            }))
+        self.projection = nn.Linear(h_dim, self.config.hidden_size)
+
+    def encode_graph(self, blocks, h):
+        for layer, block in zip(self.gnn, blocks):
+            h = layer(block, h)
+            h = {k: F.relu(v) for k, v in h.items()}
+        src_type, dst_type = blocks[0].ntypes
+        graph_tokens = self.projection(h[dst_type])
+        return graph_tokens
 
     def forward(self, blocks, node_feats, edge_feats, labels, input_nodes):
-        # TODO (qzhuamzn): use GNNs to generate graph tokens
-        h = {}
         output_nodes = blocks[-1].dstdata[dgl.NID]
 
         input_ids = self._lm_node_feats[self.target_ntype][TOKEN_IDX][output_nodes].to(self.llm.device)
         attention_mask = self._lm_node_feats[self.target_ntype][ATT_MASK_IDX][output_nodes].to(self.llm.device)
-        # TODO (qzhuamzn): modify input_ids into input_embeds=[graph_tokens, input_embeds] to support GPEFT
-        model_output = self.llm(input_ids=input_ids, attention_mask=attention_mask)
+        graph_tokens = self.encode_graph(blocks, node_feats)
+
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+        word_embeddings = self.llm.get_input_embeddings()
+        inputs_embeds = torch.cat([graph_tokens.unsqueeze(1), word_embeddings(input_ids)], dim=1)
+        attention_mask = torch.cat([torch.ones((input_shape[0], 1), device=self.llm.device), attention_mask], dim=1)
+        model_output = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
         # We use the last token in order to do the classification, as other causal models
-        h = self.out(model_output.last_hidden_state[:,-1,:])
+        masked_hidden_states = model_output.last_hidden_state * attention_mask.unsqueeze(-1)
+        last_token_indexes = attention_mask.sum(dim=1, dtype=torch.int64) - 1
+        last_token_embeddings = masked_hidden_states[torch.arange(last_token_indexes.size(0)), last_token_indexes, :]
+        h = self.out(last_token_embeddings)
 
         loss = self._loss_fn(h, labels[self.target_ntype])
 
         return loss
 
 
     def predict(self, blocks, node_feats, _, input_nodes, return_proba):
-        # TODO (qzhuamzn): use h as gnn token embeddings
-        h = {}
         output_nodes = blocks[-1].dstdata[dgl.NID]
 
         input_ids = self._lm_node_feats[self.target_ntype][TOKEN_IDX][output_nodes].to(self.llm.device)
         attention_mask = self._lm_node_feats[self.target_ntype][ATT_MASK_IDX][output_nodes].to(self.llm.device)
-        model_output = self.llm(input_ids=input_ids, attention_mask=attention_mask)
-        logits = self.out(model_output.last_hidden_state[:,-1,:])
+        graph_tokens = self.encode_graph(blocks, node_feats)
+
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+        word_embeddings = self.llm.get_input_embeddings()
+        inputs_embeds = torch.cat([graph_tokens.unsqueeze(1), word_embeddings(input_ids)], dim=1)
+        attention_mask = torch.cat([torch.ones((input_shape[0], 1), device=self.llm.device), attention_mask], dim=1)
+        model_output = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
+        masked_hidden_states = model_output.last_hidden_state * attention_mask.unsqueeze(-1)
+        last_token_indexes = attention_mask.sum(dim=1, dtype=torch.int64) - 1
+        last_token_embeddings = masked_hidden_states[torch.arange(last_token_indexes.size(0)), last_token_indexes, :]
+        logits = self.out(last_token_embeddings)
 
         if return_proba:
             return logits.argmax(dim=-1), torch.softmax(logits, 1)
         else:
@@ -154,7 +187,7 @@ def create_optimizer(self):
         return torch.optim.Adam(self.parameters(), lr=self.lr)
 
     def restore_model(self, restore_model_path):
-        self.llm = AutoModel.from_pretrained(restore_model_path, config=self.config)
+        self.llm = AutoPeftModel.from_pretrained(restore_model_path)
 
     def save_model(self, model_path):
         os.makedirs(model_path, exist_ok=True)
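Pulled out of the diff for readability, here is a minimal, self-contained sketch of the two mechanisms the new forward/predict paths rely on: prepending the projected graph token to the word embeddings, and pooling each row's last non-padding token under the extended mask. The toy shapes and the random tensors standing in for encode_graph and the LLM output are illustrative assumptions, not the PR's code.

import torch

# Toy shapes (assumptions for illustration only).
batch_size, seq_len, hidden_size = 2, 4, 8

graph_tokens = torch.randn(batch_size, hidden_size)           # stand-in for encode_graph(...) output
word_embeds = torch.randn(batch_size, seq_len, hidden_size)   # stand-in for word_embeddings(input_ids)
attention_mask = torch.tensor([[1., 1., 1., 0.],
                               [1., 1., 0., 0.]])              # 0 marks padding

# (1) Prepend the graph token at position 0 and extend the mask to cover it.
inputs_embeds = torch.cat([graph_tokens.unsqueeze(1), word_embeds], dim=1)
attention_mask = torch.cat([torch.ones(batch_size, 1), attention_mask], dim=1)

# (2) After the LLM forward pass, pick each row's last *real* token instead of
# index -1, which may point at padding.
last_hidden_state = torch.randn(batch_size, seq_len + 1, hidden_size)  # stand-in for model output
masked = last_hidden_state * attention_mask.unsqueeze(-1)
last_idx = attention_mask.sum(dim=1, dtype=torch.int64) - 1            # here: [3, 2]
pooled = masked[torch.arange(batch_size), last_idx]
print(pooled.shape)  # torch.Size([2, 8])

The gather on last_idx alone already lands on a real token; zeroing the padded positions first just keeps any later change to the pooling from silently reading padding.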
6 changes: 3 additions & 3 deletions examples/peft_llm_gnn/nc_config_Video_Games.yaml
@@ -11,13 +11,13 @@ gsf:
     verbose: false
   gnn:
     model_encoder_type: rgcn
-    fanout: 10, 5
+    fanout: "5,5"
    hidden_size: 768
     num_layers: 2
     use_mini_batch_infer: true
   hyperparam:
     batch_size: 4
-    dropout: 0.1
+    dropout: 0.0
     eval_batch_size: 4
     lr: 0.0001
     num_epochs: 10
@@ -31,7 +31,7 @@
   node_classification:
     eval_metric:
     - accuracy
-    label_field: label
+    label_field: pt_lvl3
     multilabel: false
     node_feat_name:
     - item:h
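A note on the fanout change above: the value is a comma-separated string with one neighbor budget per GNN layer, so "5,5" samples 5 neighbors at each of the two layers, matching num_layers: 2. Below is a small sketch of how such a string is typically turned into a DGL neighbor sampler; the parsing helper is our illustration, and GraphStorm's own config handling may differ.

import dgl

def parse_fanout(fanout_str):
    # "5,5" -> [5, 5]: one fanout per message-passing layer (illustrative helper).
    return [int(x) for x in fanout_str.split(",")]

fanouts = parse_fanout("5,5")
assert len(fanouts) == 2  # must match num_layers: 2 in the config above
sampler = dgl.dataloading.MultiLayerNeighborSampler(fanouts)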