Merge pull request #961 from rockerBOO/attention-processor

Add attention processor
kohya-ss authored Dec 3, 2023
2 parents 81a440c + c856ea4 commit 46cf41c
Showing 1 changed file with 43 additions and 4 deletions.
47 changes: 43 additions & 4 deletions library/original_unet.py
@@ -586,6 +586,9 @@ def __init__(
self.use_memory_efficient_attention_mem_eff = False
self.use_sdpa = False

# Attention processor
self.processor = None

def set_use_memory_efficient_attention(self, xformers, mem_eff):
self.use_memory_efficient_attention_xformers = xformers
self.use_memory_efficient_attention_mem_eff = mem_eff
@@ -607,7 +610,28 @@ def reshape_batch_dim_to_heads(self, tensor):
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor

# optional diffusers-style attention processor hook
def set_processor(self, processor):
    self.processor = processor

def get_processor(self):
    return self.processor

def forward(self, hidden_states, context=None, mask=None, **kwargs):
    if self.processor is not None:
        (
            hidden_states,
            encoder_hidden_states,
            attention_mask,
        ) = translate_attention_names_from_diffusers(
            hidden_states=hidden_states, context=context, mask=mask, **kwargs
        )
        # hand the translated names to the processor; the module itself is passed as `attn`
        return self.processor(
            attn=self,
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
        )
if self.use_memory_efficient_attention_xformers:
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
if self.use_memory_efficient_attention_mem_eff:
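Note on the new hook: when `processor` is set, the module hands off to it using the diffusers-style argument names produced by `translate_attention_names_from_diffusers` (added further down in this diff). Below is a minimal sketch of a compatible processor, not part of this commit; it assumes the CrossAttention module exposes the usual diffusers-style pieces (`to_q`, `to_k`, `to_v`, `scale`, and a `reshape_heads_to_batch_dim` counterpart to the `reshape_batch_dim_to_heads` helper shown above), and an additive attention mask.

import torch

class BasicAttnProcessor:
    # Illustrative processor, not part of this commit: plain scaled dot-product attention
    # written against the module's (assumed) diffusers-style attributes.
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
        # fall back to self-attention when no encoder states are given
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        query = attn.reshape_heads_to_batch_dim(attn.to_q(hidden_states))
        key = attn.reshape_heads_to_batch_dim(attn.to_k(context))
        value = attn.reshape_heads_to_batch_dim(attn.to_v(context))

        scores = torch.bmm(query, key.transpose(1, 2)) * attn.scale
        if attention_mask is not None:
            scores = scores + attention_mask  # assumes an additive mask
        out = torch.bmm(scores.softmax(dim=-1), value)

        out = attn.reshape_batch_dim_to_heads(out)
        return attn.to_out[0](out)

Such a processor could then be attached by assigning the module's `processor` attribute (or through `set_processor`).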
@@ -720,6 +744,21 @@ def forward_sdpa(self, x, context=None, mask=None):
out = self.to_out[0](out)
return out

def translate_attention_names_from_diffusers(
    hidden_states: torch.FloatTensor,
    context: Optional[torch.FloatTensor] = None,
    mask: Optional[torch.FloatTensor] = None,
    # HF naming
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
):
    # translate from Hugging Face diffusers naming to this module's naming
    context = context if context is not None else encoder_hidden_states
    mask = mask if mask is not None else attention_mask

    return hidden_states, context, mask

# feedforward
class GEGLU(nn.Module):
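For reference, the `translate_attention_names_from_diffusers` helper added in the hunk above only folds the Hugging Face keyword names back into the module's own `context`/`mask` names. A small hedged example; the import path assumes sd-scripts' `library/original_unet.py` layout, and the tensors are placeholders.

import torch
from library.original_unet import translate_attention_names_from_diffusers

hs = torch.randn(1, 4, 8)
text_emb = torch.randn(1, 4, 8)

hidden_states, context, mask = translate_attention_names_from_diffusers(
    hidden_states=hs, encoder_hidden_states=text_emb, attention_mask=None
)
assert context is text_emb and mask is None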
@@ -1350,7 +1389,7 @@ def __init__(
self.out_channels = OUT_CHANNELS

self.sample_size = sample_size
self.prepare_config(sample_size=sample_size)

# the state_dict format must not change, so the way modules are held cannot be changed

@@ -1437,8 +1476,8 @@ def __init__(
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)

# region diffusers compatibility
def prepare_config(self, *args, **kwargs):
    # store constructor kwargs (e.g. sample_size) for diffusers-style `config` access
    self.config = SimpleNamespace(**kwargs)

@property
def dtype(self) -> torch.dtype:
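The `prepare_config` change above, together with passing `sample_size` into it from `__init__`, makes the constructor kwargs visible on the SimpleNamespace `config` that diffusers-style callers read. A minimal sketch of the resulting behaviour; the value 64 and the 8x VAE scale factor are illustrative assumptions, not part of this commit.

from types import SimpleNamespace

config = SimpleNamespace(sample_size=64)  # what prepare_config(sample_size=64) now builds
print(config.sample_size * 8)             # 512, the usual diffusers-style resolution lookup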
