Add type hints for forward functions in Gemma2 (#35034)
* feat: add gemma2 type hints

* fix: mask is optional
jla524 authored Dec 2, 2024
1 parent 7b5f76e commit f41d5d8
Showing 2 changed files with 68 additions and 8 deletions.
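The "mask is optional" note in the commit message refers to annotating the mask parameter as Optional[torch.Tensor] rather than torch.Tensor, since callers may pass None when no attention mask is needed (for example, unpadded inputs). A minimal, hypothetical illustration of why the Optional annotation matters to a type checker — not code from this diff:

from typing import Optional

import torch


def apply_mask(scores: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
    # With Optional[...], a type checker accepts both a tensor and None here.
    if mask is not None:
        scores = scores + mask
    return scores


scores = torch.zeros(1, 1, 4, 4)
apply_mask(scores, None)                     # valid: no mask
apply_mask(scores, torch.zeros(1, 1, 4, 4))  # valid: additive attention mask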
38 changes: 34 additions & 4 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -170,7 +170,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-def eager_attention_forward(config, query, key, value, mask, **_kwargs):
+def eager_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    **_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     key_states = repeat_kv(key, config.num_key_value_groups)
     value_states = repeat_kv(value, config.num_key_value_groups)

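The hunk above retypes eager_attention_forward; its body (elided here) is the standard eager attention computation. The sketch below is an independent illustration of that pattern, consistent with the new Tuple[torch.Tensor, torch.Tensor] return annotation; the explicit soft_cap argument stands in for the model's config value and is an assumption, not the file's code:

import math
from typing import Optional, Tuple

import torch


def eager_attention_sketch(
    query: torch.Tensor,           # (batch, num_heads, q_len, head_dim)
    key: torch.Tensor,             # (batch, num_heads, kv_len, head_dim), already repeated
    value: torch.Tensor,           # (batch, num_heads, kv_len, head_dim), already repeated
    mask: Optional[torch.Tensor],  # additive mask, or None
    soft_cap: Optional[float] = 50.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    scores = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(query.shape[-1])
    if soft_cap is not None:
        # Smoothly cap the logits, mirroring the tanh_softcap helper visible further down.
        scores = soft_cap * torch.tanh(scores / soft_cap)
    if mask is not None:
        scores = scores + mask[:, :, :, : key.shape[-2]]
    attn_weights = torch.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights  # both tensors, hence Tuple[torch.Tensor, torch.Tensor]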
@@ -192,7 +199,15 @@ def eager_attention_forward(config, query, key, value, mask, **_kwargs):
     return attn_output, attn_weights


-def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
+def flash_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    target_dtype: torch.dtype = torch.float16,
+    **_kwargs,
+) -> Tuple[torch.Tensor, None]:
     if mask is not None:
         seq_len = mask.shape[1]
         query = query[:, :, :seq_len]
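flash_attention_forward is annotated to return Tuple[torch.Tensor, None]: flash-attention kernels never materialize the attention matrix, so the weights slot is always None. The target_dtype parameter exists because those kernels only run in half precision; a hedged sketch of that cast (a standalone helper, not the file's implementation):

from typing import Tuple

import torch


def cast_for_flash(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    target_dtype: torch.dtype = torch.float16,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # fp32 inputs (e.g. produced by upcasting layers) are cast down before the kernel call.
    if query.dtype == torch.float32:
        query, key, value = (t.to(target_dtype) for t in (query, key, value))
    return query, key, value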
@@ -229,7 +244,15 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
     return attn_output, None


-def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
+def flex_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    output_attentions: bool = False,
+    **_kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     def tanh_softcap(score, b, h, q_idx, kv_idx):
         soft_cap = config.attn_logit_softcapping
         score = soft_cap * torch.tanh(score / soft_cap)
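The tanh_softcap closure visible in this hunk is shaped like a score_mod for PyTorch's FlexAttention, and flex_attention_forward takes an output_attentions flag, so its weights slot is annotated Optional[torch.Tensor]. A sketch of how such a score modifier plugs into the flex_attention API (assumes PyTorch >= 2.5; the soft_cap constant stands in for config.attn_logit_softcapping):

import torch
from torch.nn.attention.flex_attention import flex_attention  # PyTorch >= 2.5

soft_cap = 50.0  # stand-in for config.attn_logit_softcapping


def tanh_softcap(score, b, h, q_idx, kv_idx):
    # Same score-mod shape as in the diff: smoothly cap each attention logit.
    return soft_cap * torch.tanh(score / soft_cap)


q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
out = flex_attention(q, k, v, score_mod=tanh_softcap)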
@@ -255,7 +278,14 @@ def tanh_softcap(score, b, h, q_idx, kv_idx):
     return attn_output, attn_weights


-def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
+def sdpa_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    **_kwargs,
+) -> Tuple[torch.Tensor, None]:
     key = repeat_kv(key, config.num_key_value_groups)
     value = repeat_kv(value, config.num_key_value_groups)

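sdpa_attention_forward is likewise annotated Tuple[torch.Tensor, None]: torch.nn.functional.scaled_dot_product_attention returns only the attention output, never the weights. A hedged sketch of that call pattern (standalone helper, not the file's exact body; key/value are assumed already expanded by repeat_kv):

from typing import Optional, Tuple

import torch
import torch.nn.functional as F


def sdpa_sketch(
    query: torch.Tensor,           # (batch, num_heads, q_len, head_dim)
    key: torch.Tensor,             # (batch, num_heads, kv_len, head_dim)
    value: torch.Tensor,
    mask: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, None]:
    attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
    # SDPA never exposes the attention matrix, hence the fixed None weights slot.
    return attn_output, None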
38 changes: 34 additions & 4 deletions src/transformers/models/gemma2/modular_gemma2.py
@@ -213,7 +213,14 @@ class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
     pass


-def eager_attention_forward(config, query, key, value, mask, **_kwargs):
+def eager_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    **_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     key_states = repeat_kv(key, config.num_key_value_groups)
     value_states = repeat_kv(value, config.num_key_value_groups)

@@ -235,7 +242,15 @@ def eager_attention_forward(config, query, key, value, mask, **_kwargs):
     return attn_output, attn_weights


-def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
+def flash_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    target_dtype: torch.dtype = torch.float16,
+    **_kwargs,
+) -> Tuple[torch.Tensor, None]:
     if mask is not None:
         seq_len = mask.shape[1]
         query = query[:, :, :seq_len]
@@ -272,7 +287,15 @@ def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs):
     return attn_output, None


-def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
+def flex_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    output_attentions: bool = False,
+    **_kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     def tanh_softcap(score, b, h, q_idx, kv_idx):
         soft_cap = config.attn_logit_softcapping
         score = soft_cap * torch.tanh(score / soft_cap)
@@ -298,7 +321,14 @@ def tanh_softcap(score, b, h, q_idx, kv_idx):
     return attn_output, attn_weights


-def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
+def sdpa_attention_forward(
+    config: Gemma2Config,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    mask: Optional[torch.Tensor],
+    **_kwargs,
+) -> Tuple[torch.Tensor, None]:
     key = repeat_kv(key, config.num_key_value_groups)
     value = repeat_kv(value, config.num_key_value_groups)

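modular_gemma2.py mirrors the same four signatures as modeling_gemma2.py. One payoff of keeping them identical, and now identically typed, is that the functions can sit behind a single dispatch table and be called interchangeably. The mapping below is illustrative only; the file keeps a similar lookup, but the exact name and keys here are assumptions:

from typing import Callable, Dict

# Assumes the four *_attention_forward functions above are in scope.
ATTENTION_FUNCTIONS: Dict[str, Callable] = {
    "eager": eager_attention_forward,
    "flash_attention_2": flash_attention_forward,
    "flex_attention": flex_attention_forward,
    "sdpa": sdpa_attention_forward,
}

# config, query, key, value, mask as in the shared signature above.
attention_fn = ATTENTION_FUNCTIONS["sdpa"]
attn_output, attn_weights = attention_fn(config, query, key, value, mask)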
