diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py
index 237fba6f645..167ccd15580 100644
--- a/src/transformers/models/modernbert/modeling_modernbert.py
+++ b/src/transformers/models/modernbert/modeling_modernbert.py
@@ -307,7 +307,7 @@ def eager_attention_forward(
     dim: int,
     output_attentions: Optional[bool] = False,
     **_kwargs,
-) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]:
+) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
     # qkv: [batch_size, seqlen, 3, nheads, headdim]
     cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
     query, key, value = qkv.transpose(3, 1).unbind(dim=2)
diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py
index dac356146f3..4424e8b2fea 100644
--- a/src/transformers/models/modernbert/modular_modernbert.py
+++ b/src/transformers/models/modernbert/modular_modernbert.py
@@ -532,7 +532,7 @@ def eager_attention_forward(
     dim: int,
     output_attentions: Optional[bool] = False,
     **_kwargs,
-) -> Tuple[torch.Tensor, torch.Tensor] | Tuple[torch.Tensor]:
+) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
     # qkv: [batch_size, seqlen, 3, nheads, headdim]
     cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
     query, key, value = qkv.transpose(3, 1).unbind(dim=2)
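
The patch swaps the PEP 604 "X | Y" union spelling for typing.Union in the return annotation of eager_attention_forward, in both the generated modeling file and its modular source (assuming Union is already imported from typing in these modules, as is typical for transformers modeling files). The diff does not state the motivation, but a likely one is Python 3.9 compatibility: the | operator on typing generics only exists on Python 3.10+, and a return annotation is evaluated eagerly when the def statement runs (unless from __future__ import annotations is in effect). A minimal, illustrative sketch of that failure mode, not part of the patch:

# Illustrative sketch: why the "X | Y" spelling can break at import time
# on Python 3.9, while Union[...] works on all supported versions.
import sys
from typing import Tuple, Union

try:
    # PEP 604 gave typing generics an __or__ only in Python 3.10; this
    # expression runs when the annotation is evaluated at definition time.
    ret = Tuple[int, int] | Tuple[int]
    print(f"Python {sys.version_info[:2]}: PEP 604 union works: {ret}")
except TypeError as exc:
    # On 3.9 this raises: unsupported operand type(s) for |: ...
    print(f"Python {sys.version_info[:2]}: PEP 604 union failed: {exc}")

# The backward-compatible spelling the patch switches to:
ret = Union[Tuple[int, int], Tuple[int]]
print(f"Union spelling works everywhere: {ret}")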