RWForCausalLM/Falcon Support #15

Open
wants to merge 3 commits into base: latestmerge
2 changes: 1 addition & 1 deletion gptq/__init__.py
@@ -1 +1 @@
-from . import bigcode, gptj, gptneox, llama, opt, offload
+from . import bigcode, gptj, gptneox, llama, opt, rw, offload
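The diff for the new `gptq/rw.py` module itself is not shown on this page. As a loose sketch only, assuming it mirrors the other per-architecture backends such as `gptj` or `gptneox` (all names below are illustrative, not taken from the PR), the module would locate the RW/Falcon linear layers so they can be quantized:

```python
# Hypothetical sketch, not the PR's actual gptq/rw.py.
import torch.nn as nn


def find_rw_linear_layers(model):
    # RW/Falcon checkpoints keep their decoder blocks under model.transformer.h;
    # collect every nn.Linear inside them as a quantization target.
    targets = {}
    for i, block in enumerate(model.transformer.h):
        for name, module in block.named_modules():
            if isinstance(module, nn.Linear):
                targets[f"transformer.h.{i}.{name}"] = module
    return targets
```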
2 changes: 1 addition & 1 deletion gptq/modelutils.py
@@ -14,7 +14,7 @@


def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
-    if type(module) in layers:
+    if any([isinstance(module, t) for t in layers]):
        return {name: module}
    res = {}
    for name1, child in module.named_children():
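The `find_layers` change above replaces an exact `type()` match with `isinstance`, so subclasses of the listed layer types (for example wrapped or quantized linear layers) are still picked up. A minimal standalone illustration (the `QuantLinear` subclass is hypothetical, only there to show the difference):

```python
import torch.nn as nn


class QuantLinear(nn.Linear):
    """Hypothetical stand-in for a quantized layer that subclasses nn.Linear."""


layer = QuantLinear(4, 4)
layers = [nn.Conv2d, nn.Linear]

# Exact type matching misses subclasses of nn.Linear:
print(type(layer) in layers)                      # False
# isinstance also accepts subclasses, so the layer is found:
print(any(isinstance(layer, t) for t in layers))  # True
```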
162 changes: 159 additions & 3 deletions gptq/offload.py
@@ -5,11 +5,12 @@
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXModel
from transformers.models.gptj.modeling_gptj import GPTJModel
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeModel
-mpt_support = True
+hf_bleeding_edge_found = True
try:
    from hf_bleeding_edge.mpt.modeling_mpt import MPTModel
+    from hf_bleeding_edge.rw.modelling_RW import RWModel
except ImportError:
-    mpt_support = False
+    hf_bleeding_edge_found = False
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions
from typing import List, Optional, Tuple, Union

@@ -1150,6 +1151,159 @@ def mpt_offload_forward(self, input_ids: torch.LongTensor, past_key_values: Opti
    return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)


def rw_offload_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    **deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
    if deprecated_arguments.pop("position_ids", False) is not False:
        # `position_ids` could have been `torch.Tensor` or `None`, so defaulting pop to `False`
        # lets us detect whether users were explicitly passing `None`
        warnings.warn(
            "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
            " passing `position_ids`.",
            FutureWarning,
        )
    if len(deprecated_arguments) > 0:
        raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if past_key_values is None:
        past_key_values = tuple([None] * len(self.h))

    # Prepare head mask if needed
    # 1.0 in head_mask indicates we keep the head
    # attention_probs has shape batch_size x num_heads x N x N
    # head_mask has shape n_layer x batch x num_heads x N x N
    head_mask = self.get_head_mask(head_mask, self.config.n_layer)

    if inputs_embeds is None:
        inputs_embeds = self.word_embeddings(input_ids)

    hidden_states = inputs_embeds

    presents = () if use_cache else None
    all_self_attentions = () if output_attentions else None
    all_hidden_states = () if output_hidden_states else None

    # Compute alibi tensor: check build_alibi_tensor documentation
    seq_length_with_past = seq_length
    past_key_values_length = 0
    if past_key_values[0] is not None:
        past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length_with_past + past_key_values_length
    if attention_mask is None:
        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
    else:
        attention_mask = attention_mask.to(hidden_states.device)

    if self.alibi:
        alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
    else:
        alibi = None

    causal_mask = self._prepare_attn_mask(
        attention_mask,
        input_shape=(batch_size, seq_length),
        past_key_values_length=past_key_values_length,
    )

    for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
        (
            block,
            hidden_states,
            attention_mask,
            alibi,
            layer_past,
        ) = offload_loop_start(self, self.h, i, hidden_states, attention_mask, alibi, layer_past)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

                return custom_forward

            outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                alibi,
                causal_mask,
                head_mask[i],
            )
        else:
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=causal_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
                output_attentions=output_attentions,
                alibi=alibi,
            )

        offload_loop_end(
            self, self.h, i
        )

        hidden_states = outputs[0]
        if use_cache is True:
            presents = presents + (outputs[1],)

        if output_attentions:
            all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

    # offload_cleanup is expected to hand back the hidden states once the last block
    # has been moved off the GPU; keep its result so the final norm sees them.
    hidden_states = offload_cleanup(self, self.h, hidden_states)

    # Add last hidden state
    hidden_states = self.ln_f(hidden_states)

    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    if not return_dict:
        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=presents,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
    )


def find_layers(module):
    if "0" in dict(module.named_children()):
        return None, module, []
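`offload_loop_start`, `offload_loop_end`, and `offload_cleanup`, used by `rw_offload_forward` above, are pre-existing helpers in `gptq/offload.py` and are not part of this diff. As a rough mental model only (an assumption, not the repository's actual implementation), they implement a per-layer streaming pattern along these lines:

```python
# Generic per-layer offload sketch, for orientation only. The real helpers take
# extra arguments and decide which blocks stay resident on the GPU.
import torch


def run_with_layer_offload(blocks, hidden_states: torch.Tensor, device="cuda:0", offload_device="cpu"):
    # Assumes each block maps a hidden-state tensor to a hidden-state tensor.
    hidden_states = hidden_states.to(device)
    for block in blocks:
        block.to(device)              # stream the next decoder block onto the GPU
        hidden_states = block(hidden_states)
        block.to(offload_device)      # evict it again so peak VRAM stays bounded
    return hidden_states
```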
@@ -1199,8 +1353,10 @@ def load_quant_offload(
        type(m).forward = opt_offload_forward
    elif type(m) == GPTBigCodeModel:
        type(m).forward = bigcode_offload_forward
-    elif mpt_support and type(m) == MPTModel:
+    elif hf_bleeding_edge_found and type(m) == MPTModel:
        type(m).forward = mpt_offload_forward
+    elif hf_bleeding_edge_found and type(m) == RWModel:
+        type(m).forward = rw_offload_forward
    else:
        raise RuntimeError(f"Model type {type(m)} not supported by CPU offloader")

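Worth noting about the dispatch in `load_quant_offload`: the offload-aware forward is installed by assigning to the class attribute (`type(m).forward = ...`) rather than to the instance. A small self-contained example of that patching pattern (the toy module below is made up for illustration):

```python
import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    def forward(self, x):
        return x + 1


def patched_forward(self, x):
    # Stand-in for an offload-aware forward such as rw_offload_forward.
    return x + 2


m = TinyBlock()
print(m(torch.zeros(1)))   # tensor([1.])

# Patch at the class level, mirroring `type(m).forward = rw_offload_forward`:
type(m).forward = patched_forward
print(m(torch.zeros(1)))   # tensor([2.])
```

Because the assignment targets the class rather than the instance, it affects every instance of that model class in the process, which is why the dispatch only patches the class whose type actually matches `m`.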