diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 407ba5e80b6..a75b995f05b 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -306,6 +306,7 @@ def eager_attention_forward( qkv: torch.Tensor, attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor], + local_attention: Tuple[int, int], bs: int, dim: int, output_attentions: Optional[bool] = False, @@ -320,9 +321,21 @@ def eager_attention_forward( scale = module.head_dim**-0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key.shape[-2]] - attn_weights = attn_weights + causal_mask + if attention_mask is not None: + expanded_mask = _prepare_4d_attention_mask(attention_mask, attn_weights.dtype, tgt_len=key.shape[-2]) + + if local_attention != (-1, -1): + # Create position indices + rows = torch.arange(expanded_mask.shape[2]).unsqueeze(0) + # Calculate distance between positions + distance = torch.abs(rows - rows.T) + + # Create sliding window mask (1 for positions within window, 0 outside) + window_mask = (distance <= local_attention[0]).unsqueeze(0).unsqueeze(0).to(attention_mask.device) + # Combine with existing mask + expanded_mask.masked_fill_(window_mask.logical_not(), float("-inf")) + + attn_weights = attn_weights + expanded_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) @@ -385,6 +398,7 @@ def flex_attention_forward( rotary_emb: ModernBertUnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, block_mask: "BlockMask", + local_attention: Tuple[int, int], max_seqlen: int, bs: int, dim: int, @@ -401,7 +415,7 @@ def flex_attention_forward( query, key, value, - block_mask=block_mask, + block_mask=block_mask if local_attention != (-1, -1) else None, enable_gqa=False, scale=None, return_lse=False, @@ -416,6 +430,7 @@ def sdpa_attention_forward( qkv: torch.Tensor, attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor], + local_attention: Tuple[int, int], bs: int, dim: int, **_kwargs, @@ -427,7 +442,22 @@ def sdpa_attention_forward( query, key = apply_rotary_pos_emb(query, key, cos, sin) if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key.shape[-2]] + attention_mask = attention_mask[:, None, None, :].expand( + attention_mask.shape[0], 1, attention_mask.shape[1], attention_mask.shape[1] + ) + + if local_attention != (-1, -1): + # Create position indices + rows = torch.arange(attention_mask.shape[2]).unsqueeze(0) + # Calculate distance between positions + distance = torch.abs(rows - rows.T) + + # Create sliding window mask (1 for positions within window, 0 outside) + window_mask = (distance <= local_attention[0]).unsqueeze(0).unsqueeze(0).to(attention_mask.device) + # Combine with existing mask + attention_mask = torch.logical_and(attention_mask, window_mask) + + attention_mask = attention_mask.to(torch.bool) attn_output = F.scaled_dot_product_attention( query, @@ -893,7 +923,6 @@ def offsets_to_sequence_ids_tensor(cls, offsets): counts = offsets[1:] - offsets[:-1] return torch.repeat_interleave(torch.arange(len(counts), device=device, dtype=torch.int32), counts) - @torch.compile(dynamic=False) def create_attention_mask(self, sequence_ids, cu_seqlens, window_size): """ Creates a block mask combining sequence masking and local/or global attention masking. @@ -1053,23 +1082,12 @@ def forward( hidden_states = self.embeddings(input_ids) - # expand attention_mask - if self.config._attn_implementation != "flash_attention_2" and attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) - # create block mask + block_mask = None if self.config._attn_implementation == "flex_attention": sequence_ids = self.offsets_to_sequence_ids_tensor(cu_seqlens) - - if self.config.local_attention != (-1, -1): - window_size = self.config.local_attention // 2 - else: - window_size = max_seqlen - + window_size = self.config.local_attention // 2 block_mask = self.create_attention_mask(sequence_ids, cu_seqlens, window_size) - else: - block_mask = None for encoder_layer in self.layers: if output_hidden_states: @@ -1082,6 +1100,7 @@ def forward( attention_mask, position_ids, cu_seqlens, + block_mask, max_seqlen, output_attentions, ) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 0b9ec85a045..d19419f6b93 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -518,6 +518,7 @@ def eager_attention_forward( qkv: torch.Tensor, attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor], + local_attention: Tuple[int, int], bs: int, dim: int, output_attentions: Optional[bool] = False, @@ -532,9 +533,21 @@ def eager_attention_forward( scale = module.head_dim**-0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key.shape[-2]] - attn_weights = attn_weights + causal_mask + if attention_mask is not None: + expanded_mask = _prepare_4d_attention_mask(attention_mask, attn_weights.dtype, tgt_len=key.shape[-2]) + + if local_attention != (-1, -1): + # Create position indices + rows = torch.arange(expanded_mask.shape[2]).unsqueeze(0) + # Calculate distance between positions + distance = torch.abs(rows - rows.T) + + # Create sliding window mask (1 for positions within window, 0 outside) + window_mask = (distance <= local_attention[0]).unsqueeze(0).unsqueeze(0).to(attention_mask.device) + # Combine with existing mask + expanded_mask.masked_fill_(window_mask.logical_not(), float("-inf")) + + attn_weights = attn_weights + expanded_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) @@ -597,6 +610,7 @@ def flex_attention_forward( rotary_emb: ModernBertUnpaddedRotaryEmbedding, cu_seqlens: torch.Tensor, block_mask: "BlockMask", + local_attention: Tuple[int, int], max_seqlen: int, bs: int, dim: int, @@ -613,7 +627,7 @@ def flex_attention_forward( query, key, value, - block_mask=block_mask, + block_mask=block_mask if local_attention != (-1, -1) else None, enable_gqa=False, scale=None, return_lse=False, @@ -628,6 +642,7 @@ def sdpa_attention_forward( qkv: torch.Tensor, attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor], + local_attention: Tuple[int, int], bs: int, dim: int, **_kwargs, @@ -639,7 +654,22 @@ def sdpa_attention_forward( query, key = apply_rotary_pos_emb(query, key, cos, sin) if attention_mask is not None: - attention_mask = attention_mask[:, :, :, : key.shape[-2]] + attention_mask = attention_mask[:, None, None, :].expand( + attention_mask.shape[0], 1, attention_mask.shape[1], attention_mask.shape[1] + ) + + if local_attention != (-1, -1): + # Create position indices + rows = torch.arange(attention_mask.shape[2]).unsqueeze(0) + # Calculate distance between positions + distance = torch.abs(rows - rows.T) + + # Create sliding window mask (1 for positions within window, 0 outside) + window_mask = (distance <= local_attention[0]).unsqueeze(0).unsqueeze(0).to(attention_mask.device) + # Combine with existing mask + attention_mask = torch.logical_and(attention_mask, window_mask) + + attention_mask = attention_mask.to(torch.bool) attn_output = F.scaled_dot_product_attention( query, @@ -1033,7 +1063,6 @@ def offsets_to_sequence_ids_tensor(cls, offsets): counts = offsets[1:] - offsets[:-1] return torch.repeat_interleave(torch.arange(len(counts), device=device, dtype=torch.int32), counts) - @torch.compile(dynamic=False) def create_attention_mask(self, sequence_ids, cu_seqlens, window_size): """ Creates a block mask combining sequence masking and local/or global attention masking. @@ -1193,23 +1222,12 @@ def forward( hidden_states = self.embeddings(input_ids) - # expand attention_mask - if self.config._attn_implementation != "flash_attention_2" and attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) - # create block mask + block_mask = None if self.config._attn_implementation == "flex_attention": sequence_ids = self.offsets_to_sequence_ids_tensor(cu_seqlens) - - if self.config.local_attention != (-1, -1): - window_size = self.config.local_attention // 2 - else: - window_size = max_seqlen - + window_size = self.config.local_attention // 2 block_mask = self.create_attention_mask(sequence_ids, cu_seqlens, window_size) - else: - block_mask = None for encoder_layer in self.layers: if output_hidden_states: @@ -1222,6 +1240,7 @@ def forward( attention_mask, position_ids, cu_seqlens, + block_mask, max_seqlen, output_attentions, )