Skip to content

Commit

Permalink
Add MPT support
Browse files Browse the repository at this point in the history
  • Loading branch information
0cc4m committed May 6, 2023
1 parent 6d87688 commit 01a5990
Showing 1 changed file with 74 additions and 0 deletions.
74 changes: 74 additions & 0 deletions gptq/offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from transformers.models.opt.modeling_opt import OPTModel
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXModel
from transformers.models.gptj.modeling_gptj import GPTJModel
from hf_bleeding_edge.mpt.modeling_mpt import MPTModel
from transformers.modeling_outputs import BaseModelOutputWithPast
from typing import List, Optional, Tuple, Union

Expand Down Expand Up @@ -870,6 +871,77 @@ def custom_forward(*inputs):
attentions=all_self_attns,
)

def mpt_offload_forward(
    self,
    input_ids: torch.LongTensor,
    past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
    attention_mask: Optional[torch.ByteTensor] = None,
    prefix_mask: Optional[torch.ByteTensor] = None,
    sequence_id: Optional[torch.LongTensor] = None,
    return_dict: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    use_cache: Optional[bool] = None,
):
    """Forward pass for an MPT model that routes every block through the offload hooks.

    Installed as ``MPTModel.forward`` by the offload loader. The body follows
    the usual embed -> blocks -> final norm sequence, but each block iteration
    is bracketed by ``offload_loop_start`` / ``offload_loop_end`` and the final
    hidden state is passed through ``offload_cleanup`` before ``norm_f``.

    Args:
        input_ids: Token ids; ``input_ids.size(1)`` is taken as the current
            sequence length.
        past_key_values: Optional per-layer cache. When given without ALiBi it
            must contain one entry per layer. The list is updated in place
            with the value each block returns.
        attention_mask: Optional mask, converted to ``bool`` before use.
        prefix_mask: Optional mask, converted to ``bool``; required when
            ``self.prefix_lm`` is set.
        sequence_id: Optional ids forwarded to ``self._attn_bias``.
        return_dict: Defaults to ``self.config.return_dict``; must be truthy.
        output_attentions: Must be falsy.
        output_hidden_states: When truthy, the input of every block is
            collected and returned as ``hidden_states``.
        use_cache: Defaults to ``self.config.use_cache``. When truthy and no
            cache was supplied, an empty per-layer cache is created.

    Returns:
        ``BaseModelOutputWithPast`` with ``last_hidden_state``,
        ``past_key_values`` and ``hidden_states``.

    Raises:
        NotImplementedError: ``return_dict`` is false, ``output_attentions``
            is requested, or left padding is detected in training mode.
        ValueError: ``prefix_mask`` / ``sequence_id`` is missing when
            required, ``past_key_values`` has the wrong length, or past plus
            current length exceeds ``self.config.max_seq_len``.
        AssertionError: the current sequence length alone exceeds
            ``self.config.max_seq_len``.
    """
    # Local import keeps this function self-contained regardless of what the
    # module header imports.
    import warnings

    return_dict = return_dict if return_dict is not None else self.config.return_dict
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    if attention_mask is not None:
        attention_mask = attention_mask.bool()
    if prefix_mask is not None:
        prefix_mask = prefix_mask.bool()

    # ---- argument validation ------------------------------------------------
    if not return_dict:
        raise NotImplementedError('return_dict False is not implemented yet for MPT')
    if output_attentions:
        raise NotImplementedError('output_attentions is not implemented yet for MPT')
    # A first column that is not entirely set means at least one row starts
    # with a masked position, i.e. left padding.
    if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
        raise NotImplementedError('MPT does not support training with left padding.')
    if self.prefix_lm and prefix_mask is None:
        raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
    if self.training:
        if self.attn_uses_sequence_id and sequence_id is None:
            raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
        elif self.attn_uses_sequence_id is False and sequence_id is not None:
            warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')

    S = input_ids.size(1)
    assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'

    # ---- embeddings -----------------------------------------------------------
    tok_emb = self.wte(input_ids)
    if self.alibi:
        # With ALiBi no learned position embedding is added here.
        x = tok_emb
    else:
        past_position = 0
        if past_key_values is not None:
            if len(past_key_values) != self.config.n_layers:
                raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
            # Length of the cached sequence, read from the first layer's
            # first cached tensor.
            past_position = past_key_values[0][0].size(1)
        if S + past_position > self.config.max_seq_len:
            # Report the real current length (the message used to print S + 1).
            raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S}, this model only supports total sequence length <= {self.config.max_seq_len}.')
        pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        if attention_mask is not None:
            # Shift positions down by the running count of masked tokens so
            # that padding does not consume position indices; never go < 0.
            pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
        pos_emb = self.wpe(pos)
        x = tok_emb + pos_emb

    if self.embedding_fraction == 1:
        x = self.emb_drop(x)
    else:
        # Only `embedding_fraction` of the embedding keeps its gradient path;
        # the remainder is detached.
        x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
        assert isinstance(self.emb_drop, torch.nn.Module)
        x = self.emb_drop(x_shrunk)

    (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
    # NOTE(review): attn_bias is built on x's device before the offload hooks
    # run; confirm it stays on the right device if the hooks relocate x.

    if use_cache and past_key_values is None:
        past_key_values = [() for _ in range(self.config.n_layers)]
    all_hidden_states = () if output_hidden_states else None

    # ---- transformer blocks ---------------------------------------------------
    for (b_idx, block) in enumerate(self.blocks):
        # The hook hands back the block and tensors to use for this step.
        # MPT has no position_ids, so None is passed and the returned slot is
        # discarded.
        (
            block,
            x,
            attention_mask,
            _,
        ) = offload_loop_start(self, b_idx, x, attention_mask, None)

        if output_hidden_states:
            assert all_hidden_states is not None
            all_hidden_states = all_hidden_states + (x,)
        past_key_value = past_key_values[b_idx] if past_key_values is not None else None
        (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
        if past_key_values is not None:
            past_key_values[b_idx] = past_key_value

        offload_loop_end(
            self, b_idx
        )

    # Use the tensor returned by the cleanup hook; it was previously bound to
    # an unused name, so its result never reached norm_f.
    # NOTE(review): assumes offload_cleanup returns the hidden state to
    # normalize — confirm against its definition.
    x = offload_cleanup(self, x)
    x = self.norm_f(x)
    return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)


def find_layers(module):
if "0" in dict(module.named_children()):
Expand Down Expand Up @@ -918,6 +990,8 @@ def load_quant_offload(
type(m).forward = gptj_offload_forward
elif type(m) == OPTModel:
type(m).forward = opt_offload_forward
elif type(m) == MPTModel:
type(m).forward = mpt_offload_forward
else:
raise RuntimeError(f"Model type {type(m)} not supported by CPU offloader")

Expand Down

0 comments on commit 01a5990

Please sign in to comment.