From f41d5d8f747f48849005d18dd1c04d5889f31c1b Mon Sep 17 00:00:00 2001
From: Jacky Lee <39754370+jla524@users.noreply.github.com>
Date: Mon, 2 Dec 2024 06:03:36 -0800
Subject: [PATCH] Add type hints for forward functions in Gemma2 (#35034)

* feat: add gemma2 type hints

* fix: mask is optional
---
 .../models/gemma2/modeling_gemma2.py | 38 +++++++++++++++++--
 .../models/gemma2/modular_gemma2.py  | 38 +++++++++++++++++--
 2 files changed, 68 insertions(+), 8 deletions(-)

diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 5504caf1484139..5dd4ffe0c8ac75 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -170,7 +170,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-def eager_attention_forward(config, query, key, value, mask, **_kwargs):
+def eager_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    **_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     key_states = repeat_kv(key, config.num_key_value_groups)
     value_states = repeat_kv(value, config.num_key_value_groups)

@@ -192,7 +199,15 @@ def eager_attention_forward(config, query, key, value, mask, **_kwargs):
     return attn_output, attn_weights


-def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
+def flash_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    target_dtype: torch.dtype = torch.float16,
+    **_kwargs,
+) -> Tuple[torch.Tensor, None]:
     if mask is not None:
         seq_len = mask.shape[1]
         query = query[:, :, :seq_len]
@@ -229,7 +244,15 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.
     return attn_output, None


-def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
+def flex_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    output_attentions: bool = False,
+    **_kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     def tanh_softcap(score, b, h, q_idx, kv_idx):
         soft_cap = config.attn_logit_softcapping
         score = soft_cap * torch.tanh(score / soft_cap)
@@ -255,7 +278,14 @@ def tanh_softcap(score, b, h, q_idx, kv_idx):
     return attn_output, attn_weights


-def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
+def sdpa_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    **_kwargs,
+) -> Tuple[torch.Tensor, None]:
     key = repeat_kv(key, config.num_key_value_groups)
     value = repeat_kv(value, config.num_key_value_groups)

diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index 87e090aa8195cb..7236ae2f5c9f87 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -213,7 +213,14 @@ class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
     pass


-def eager_attention_forward(config, query, key, value, mask, **_kwargs):
+def eager_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    **_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     key_states = repeat_kv(key, config.num_key_value_groups)
     value_states = repeat_kv(value, config.num_key_value_groups)

@@ -235,7 +242,15 @@ def eager_attention_forward(config, query, key, value, mask, **_kwargs):
     return attn_output, attn_weights


-def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
+def flash_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    target_dtype: torch.dtype = torch.float16,
+    **_kwargs,
+) -> Tuple[torch.Tensor, None]:
     if mask is not None:
         seq_len = mask.shape[1]
         query = query[:, :, :seq_len]
@@ -272,7 +287,15 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.
     return attn_output, None


-def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
+def flex_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    output_attentions: bool = False,
+    **_kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     def tanh_softcap(score, b, h, q_idx, kv_idx):
         soft_cap = config.attn_logit_softcapping
         score = soft_cap * torch.tanh(score / soft_cap)
@@ -298,7 +321,14 @@ def tanh_softcap(score, b, h, q_idx, kv_idx):
     return attn_output, attn_weights


-def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
+def sdpa_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    **_kwargs,
+) -> Tuple[torch.Tensor, None]:
     key = repeat_kv(key, config.num_key_value_groups)
     value = repeat_kv(value, config.num_key_value_groups)
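For context, the hunks above apply one pattern throughout: annotate the tensor arguments, mark the attention mask as Optional[torch.Tensor] (the "mask is optional" fix), and declare the Tuple return type. Below is a minimal, self-contained sketch of that annotation pattern; toy_attention_forward and its scaling argument are illustrative stand-ins, not part of the patch or of transformers.

    from typing import Optional, Tuple

    import torch


    def toy_attention_forward(
        query: torch.Tensor,           # (batch, heads, q_len, head_dim)
        key: torch.Tensor,             # (batch, heads, kv_len, head_dim)
        value: torch.Tensor,           # (batch, heads, kv_len, head_dim)
        mask: Optional[torch.Tensor],  # additive mask, or None (mask is optional)
        scaling: float = 1.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Plain scaled dot-product attention; the point here is the annotations,
        # which mirror the shape of the signatures added in the patch.
        attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
        if mask is not None:
            attn_weights = attn_weights + mask
        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights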