FIX Adaption prompt error after transformers 35235
The changes in huggingface/transformers#35235 caused a couple of adaption prompt tests to fail.
This PR fixes these failures while maintaining compatibility with older transformers versions.

Required changes (the compatibility lookup is sketched right after this list):

- hidden_size attribute removed from model, now config.hidden_size
- num_heads attribute removed from model, now config.num_attention_heads
- forward now returns 2 outputs instead of 3; rewritten to be agnostic to the number of outputs
  (see the unpacking illustration after the layer.py diff)
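
For reference, a minimal sketch of the version-agnostic attribute lookup the patch applies in both
files. The helper name `_resolve_attn_geometry` is illustrative only and not part of PEFT; the
attribute names mirror the diff below.

# Illustrative helper (not part of PEFT): mirrors the hasattr fallback added in this commit.
def _resolve_attn_geometry(model):
    # Older transformers expose these directly on the attention module.
    if hasattr(model, "hidden_size"):
        hidden_size = model.hidden_size
    else:  # after huggingface/transformers#35235 only the config carries them
        hidden_size = model.config.hidden_size
    if hasattr(model, "num_heads"):
        num_heads = model.num_heads
    else:
        num_heads = model.config.num_attention_heads
    return hidden_size, num_heads
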
BenjaminBossan committed Jan 8, 2025
1 parent 3d2bf9a commit b07c1e6
Showing 2 changed files with 19 additions and 9 deletions.
18 changes: 13 additions & 5 deletions src/peft/tuners/adaption_prompt/layer.py
@@ -48,8 +48,12 @@ def __init__(self, model_type: str, adapter_len: int, model):
         target_dtype = (
             model.q_proj.weight.dtype if model.q_proj.weight.dtype not in [torch.int8, torch.uint8] else torch.float32
         )
+        if hasattr(self.model, "hidden_size"):
+            hidden_size = self.model.hidden_size
+        else:  # changed in https://github.com/huggingface/transformers/pull/35235
+            hidden_size = self.model.config.hidden_size
         self.adaption_prompt = nn.Parameter(
-            torch.empty(1, adapter_len, self.model.hidden_size, device=device, dtype=target_dtype).normal_()
+            torch.empty(1, adapter_len, hidden_size, device=device, dtype=target_dtype).normal_()
         )
         # Initialize the gate to 0 as this is "zero-init".
         self.adaption_gate = nn.Parameter(torch.zeros(1, device=device, dtype=target_dtype))
@@ -67,7 +71,7 @@ def forward(self, **kwargs):
         if kwargs.get("output_attention", False):
             raise NotImplementedError("output_attention is not currently supported.")
 
-        output, _, past_key_value = self.model(**kwargs)
+        output, *_ = self.model(**kwargs)
         bsz = output.shape[0]
         q_len = output.shape[1]
         embed_dim = output.shape[2]
@@ -84,14 +88,18 @@ def forward(self, **kwargs):
         key = getattr(self.model, k_proj_layer)(self.adaption_prompt)
         value = getattr(self.model, v_proj_layer)(self.adaption_prompt)
 
+        if hasattr(self.model, "num_heads"):
+            num_heads = self.model.num_heads
+        else:  # changed in https://github.com/huggingface/transformers/pull/35235
+            num_heads = self.model.config.num_attention_heads
         # (bsz, num_key_value_heads, adapter_len, head_dim)
         adapter_k = (
-            key.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim)
+            key.view(1, self.adapter_len, (num_heads // factor), self.model.head_dim)
             .repeat(bsz, 1, 1, 1)
             .transpose(1, 2)
         )
         adapter_v = (
-            value.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim)
+            value.view(1, self.adapter_len, (num_heads // factor), self.model.head_dim)
             .repeat(bsz, 1, 1, 1)
             .transpose(1, 2)
         )
@@ -125,4 +133,4 @@ def forward(self, **kwargs):
 
         # Restore original dtype.
         output = output.to(previous_dtype)
-        return output, None, past_key_value
+        return output, *_
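
To see why the star-unpacking above is agnostic to the number of outputs, here is a small
standalone illustration; old_attention and new_attention are toy stand-ins, not real transformers
modules, and only their return arity matters.

# Toy stand-ins (not PEFT or transformers code): only the return arity matters here.
def old_attention():
    return "output", "attn_weights", "past_key_value"  # 3 outputs (older transformers)

def new_attention():
    return "output", "attn_weights"  # 2 outputs (after transformers#35235)

for attn in (old_attention, new_attention):
    output, *rest = attn()       # rest collects whatever else was returned
    result = (output, *rest)     # mirrors `return output, *_` in forward()
    print(len(result))           # prints 3, then 2
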
10 changes: 6 additions & 4 deletions src/peft/tuners/adaption_prompt/utils.py
@@ -67,12 +67,14 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
     position_ids = kwargs.get("position_ids")
     past_key_value = kwargs.get("past_key_value")
     bsz, q_len, _ = hidden_states.size()
-    query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
+    if hasattr(model, "num_heads"):
+        num_heads = model.num_heads
+    else:  # changed in https://github.com/huggingface/transformers/pull/35235
+        num_heads = model.config.num_attention_heads
+    query_states = model.q_proj(hidden_states).view(bsz, q_len, num_heads, model.head_dim).transpose(1, 2)
 
     factor = model.k_proj.in_features // model.k_proj.out_features
-    value_states = (
-        model.v_proj(hidden_states).view(bsz, q_len, (model.num_heads // factor), model.head_dim).transpose(1, 2)
-    )
+    value_states = model.v_proj(hidden_states).view(bsz, q_len, (num_heads // factor), model.head_dim).transpose(1, 2)
 
     seq_len = q_len
 
