Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] fix error of peft lora when xformers enabled #5697

Merged
merged 3 commits into from
Nov 8, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 21 additions & 19 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,9 @@ def __call__(
scale: float = 1.0,
) -> torch.Tensor:
residual = hidden_states

args = () if USE_PEFT_BACKEND else (scale,)

hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape

Expand All @@ -891,17 +894,17 @@ def __call__(

hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states, scale=scale)
query = attn.to_q(hidden_states, *args)
query = attn.head_to_batch_dim(query)

encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

if not attn.only_cross_attention:
key = attn.to_k(hidden_states, scale=scale)
value = attn.to_v(hidden_states, scale=scale)
key = attn.to_k(hidden_states, *args)
value = attn.to_v(hidden_states, *args)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
Expand All @@ -915,7 +918,7 @@ def __call__(
hidden_states = attn.batch_to_head_dim(hidden_states)

# linear proj
hidden_states = attn.to_out[0](hidden_states, scale=scale)
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)

Expand Down Expand Up @@ -946,6 +949,9 @@ def __call__(
scale: float = 1.0,
) -> torch.Tensor:
residual = hidden_states

args = () if USE_PEFT_BACKEND else (scale,)

hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape

Expand All @@ -958,7 +964,7 @@ def __call__(

hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states, scale=scale)
query = attn.to_q(hidden_states, *args)
query = attn.head_to_batch_dim(query, out_dim=4)

encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
Expand All @@ -967,8 +973,8 @@ def __call__(
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)

if not attn.only_cross_attention:
key = attn.to_k(hidden_states, scale=scale)
value = attn.to_v(hidden_states, scale=scale)
key = attn.to_k(hidden_states, *args)
value = attn.to_v(hidden_states, *args)
key = attn.head_to_batch_dim(key, out_dim=4)
value = attn.head_to_batch_dim(value, out_dim=4)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
Expand All @@ -985,7 +991,7 @@ def __call__(
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

# linear proj
hidden_states = attn.to_out[0](hidden_states, scale=scale)
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)

Expand Down Expand Up @@ -1177,6 +1183,8 @@ def __call__(
) -> torch.FloatTensor:
residual = hidden_states

args = () if USE_PEFT_BACKEND else (scale,)

if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

Expand Down Expand Up @@ -1207,12 +1215,8 @@ def __call__(
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = (
attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
)
value = (
attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
Expand All @@ -1232,9 +1236,7 @@ def __call__(
hidden_states = hidden_states.to(query.dtype)

# linear proj
hidden_states = (
attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
)
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)

Expand Down
Loading