diff --git a/examples/peft_llm_gnn/llm_gnn_model.py b/examples/peft_llm_gnn/llm_gnn_model.py
index 87ffe027e7..082aefad8b 100644
--- a/examples/peft_llm_gnn/llm_gnn_model.py
+++ b/examples/peft_llm_gnn/llm_gnn_model.py
@@ -25,8 +25,9 @@
 from graphstorm import model as gsmodel
 from graphstorm.model.lm_model import TOKEN_IDX, ATT_MASK_IDX
+from dgl.nn import GraphConv, HeteroGraphConv
-from peft import LoraConfig, get_peft_model
+from peft import LoraConfig, get_peft_model, AutoPeftModel
 from transformers import (
     AutoConfig,
     AutoModel,
@@ -115,32 +116,69 @@ def __init__(self, g, node_lm_configs, h_dim, out_dim, num_layers,
             self._lm_node_feats[ntype] = feats
         self.out = nn.Linear(self.config.hidden_size, out_dim)
         self._loss_fn = gsmodel.ClassifyLossFunc(multilabel=False)
-        #TODO (@qzhuamzn): add initialization for gnn encoding
+        self.gnn = nn.ModuleList()
+        for _ in range(num_layers):
+            self.gnn.append(HeteroGraphConv({
+                _etype: GraphConv(h_dim, h_dim) for _etype in g.etypes
+            }))
+        self.projection = nn.Linear(h_dim, self.config.hidden_size)
+
+    def encode_graph(self, blocks, h):
+        for layer, block in zip(self.gnn, blocks):
+            h = layer(block, h)
+            h = {k: F.relu(v) for k, v in h.items()}
+        src_type, dst_type = blocks[0].ntypes
+        graph_tokens = self.projection(h[dst_type])
+        return graph_tokens
 
     def forward(self, blocks, node_feats, edge_feats, labels, input_nodes):
-        # TODO (qzhuamzn): use GNNs to generate graph tokens
-        h = {}
         output_nodes = blocks[-1].dstdata[dgl.NID]
-
         input_ids = self._lm_node_feats[self.target_ntype][TOKEN_IDX][output_nodes].to(self.llm.device)
         attention_mask = self._lm_node_feats[self.target_ntype][ATT_MASK_IDX][output_nodes].to(self.llm.device)
-        # TODO (qzhuamzn): modify input_ids into input_embeds=[graph_tokens, input_embeds] to support GPEFT
-        model_output = self.llm(input_ids=input_ids, attention_mask=attention_mask)
+
+        graph_tokens = self.encode_graph(blocks, node_feats)
+
+        input_shape = input_ids.size()
+        # make sure input_ids are batch_size X seq_len
+        input_ids = input_ids.view(-1, input_shape[-1])
+        word_embeddings = self.llm.get_input_embeddings()
+        # assuming graph_tokens has a shape of [batch_size, embedding_dim] and
+        # word_embeddings(input_ids) has a shape of [batch_size, seq_len, embedding_dim],
+        # only one graph token is inserted ahead of the input word embeddings
+        inputs_embeds = torch.cat([graph_tokens.unsqueeze(1), word_embeddings(input_ids)], dim=1)
+
+        # enable attention computation on the inserted graph token
+        attention_mask = torch.cat([torch.ones((input_shape[0],1), device=self.llm.device), attention_mask], dim=1)
+        model_output = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
         # We use the last token in order to do the classification, as other causal models
-        h = self.out(model_output.last_hidden_state[:,-1,:])
+        masked_hidden_states = model_output.last_hidden_state * attention_mask.unsqueeze(-1)
+        last_token_indexes = (attention_mask.sum(dim=1, dtype=torch.int64) - 1)
+        last_token_embeddings = masked_hidden_states[torch.arange(last_token_indexes.size(0)),last_token_indexes,:]
+        h = self.out(last_token_embeddings)
+
         loss = self._loss_fn(h, labels[self.target_ntype])
         return loss
 
     def predict(self, blocks, node_feats, _, input_nodes, return_proba):
-        # TODO (qzhuamzn): use h as gnn token embeddings
-        h = {}
         output_nodes = blocks[-1].dstdata[dgl.NID]
+
         input_ids = self._lm_node_feats[self.target_ntype][TOKEN_IDX][output_nodes].to(self.llm.device)
         attention_mask = self._lm_node_feats[self.target_ntype][ATT_MASK_IDX][output_nodes].to(self.llm.device)
-        model_output = self.llm(input_ids=input_ids, attention_mask=attention_mask)
-        logits = self.out(model_output.last_hidden_state[:,-1,:])
+        graph_tokens = self.encode_graph(blocks, node_feats)
+
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+        word_embeddings = self.llm.get_input_embeddings()
+        inputs_embeds = torch.cat([graph_tokens.unsqueeze(1), word_embeddings(input_ids)], dim=1)
+        attention_mask = torch.cat([torch.ones((input_shape[0],1), device=self.llm.device), attention_mask], dim=1)
+        model_output = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
+        masked_hidden_states = model_output.last_hidden_state * attention_mask.unsqueeze(-1)
+        last_token_indexes = (attention_mask.sum(dim=1, dtype=torch.int64) - 1)
+        last_token_embeddings = masked_hidden_states[torch.arange(last_token_indexes.size(0)),last_token_indexes,:]
+        logits = self.out(last_token_embeddings)
+
         if return_proba:
             return logits.argmax(dim=-1), torch.softmax(logits, 1)
         else:
@@ -154,7 +192,7 @@ def create_optimizer(self):
         return torch.optim.Adam(self.parameters(), lr=self.lr)
 
     def restore_model(self, restore_model_path):
-        self.llm = AutoModel.from_pretrained(restore_model_path, config=self.config)
+        self.llm = AutoPeftModel.from_pretrained(restore_model_path)
 
     def save_model(self, model_path):
         os.makedirs(model_path, exist_ok=True)
diff --git a/examples/peft_llm_gnn/nc_config_Video_Games.yaml b/examples/peft_llm_gnn/nc_config_Video_Games.yaml
index 626553d8c9..0e5f6f9768 100644
--- a/examples/peft_llm_gnn/nc_config_Video_Games.yaml
+++ b/examples/peft_llm_gnn/nc_config_Video_Games.yaml
@@ -11,13 +11,13 @@ gsf:
     verbose: false
   gnn:
     model_encoder_type: rgcn
-    fanout: 10, 5
+    fanout: "5,5"
     hidden_size: 768
     num_layers: 2
     use_mini_batch_infer: true
   hyperparam:
     batch_size: 4
-    dropout: 0.1
+    dropout: 0.0
     eval_batch_size: 4
     lr: 0.0001
     num_epochs: 10
@@ -31,7 +31,7 @@ gsf:
   node_classification:
     eval_metric:
       - accuracy
-    label_field: label
+    label_field: pt_lvl3
     multilabel: false
     node_feat_name:
       - item:h
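Not part of the patch: the sketch below is a minimal, self-contained illustration of the GPEFT-style input construction that the new forward/predict code performs, i.e. prepend one GNN-derived graph token per node to the LLM word embeddings, extend the attention mask, and pool the last non-padded position for classification. It uses plain PyTorch only; the embedding table, linear classifier, and random tensors are assumed stand-ins for self.llm.get_input_embeddings(), self.out, and the tokenized node text.

# Standalone sketch of the graph-token prepending and last-token pooling shown
# in the diff above. All modules and tensors here are stand-ins (assumptions),
# not GraphStorm APIs.
import torch
import torch.nn as nn

batch_size, seq_len, hidden_size, vocab_size, num_classes = 4, 8, 16, 100, 3

word_embeddings = nn.Embedding(vocab_size, hidden_size)  # stands in for self.llm.get_input_embeddings()
classifier = nn.Linear(hidden_size, num_classes)         # stands in for self.out

graph_tokens = torch.randn(batch_size, hidden_size)              # GNN output after projection: [B, D]
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))  # tokenized node text: [B, L]
attention_mask = torch.ones(batch_size, seq_len)
attention_mask[:, 5:] = 0                                        # pretend the last positions are padding

# 1) Insert one graph token ahead of the word embeddings: [B, L+1, D].
inputs_embeds = torch.cat([graph_tokens.unsqueeze(1), word_embeddings(input_ids)], dim=1)

# 2) The graph token must be attended to, so prepend a column of ones: [B, L+1].
attention_mask = torch.cat([torch.ones(batch_size, 1), attention_mask], dim=1)

# A causal LLM would consume inputs_embeds/attention_mask here; to keep the
# sketch runnable we reuse inputs_embeds as if it were last_hidden_state.
last_hidden_state = inputs_embeds

# 3) Pool the last non-padded position of each sequence for classification.
last_token_index = attention_mask.sum(dim=1, dtype=torch.int64) - 1                    # [B]
last_token_embeddings = last_hidden_state[torch.arange(batch_size), last_token_index]  # [B, D]
logits = classifier(last_token_embeddings)                                             # [B, num_classes]
print(logits.shape)  # torch.Size([4, 3])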