Fix typing in load_balancing_loss_func function of `modeling_mixtral.py`. (#33641)

* fix return type

* update to union

* fix gate_logits typing

* fix num_experts type

* fix typing

* run fix-copies

* add doc for top_k

* run fix-copies

* empty commit to trigger CI
PhilipMay authored Sep 27, 2024
1 parent d3821c4 commit 2e24ee4
Showing 6 changed files with 64 additions and 31 deletions.
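For readers skimming the diff: the new annotations reflect how the helper is actually used. The gate/router logits arrive as a tuple of per-layer tensors (or None when routing outputs were not collected), and the helper can return a plain `0` when there is nothing to score, which is presumably why the return annotation becomes `Union[torch.Tensor, int]` rather than `float`. Below is a minimal illustrative stub of that control flow, not the library's exact body:

from typing import Optional, Tuple, Union

import torch


def load_balancing_loss_func_stub(
    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    # Stub only: shows the two return branches that motivate Union[torch.Tensor, int].
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0  # int branch: no router logits were collected
    # The real helper concatenates the per-layer logits and computes the Switch
    # Transformer auxiliary loss; a placeholder tensor stands in for that here.
    concatenated = torch.cat([layer_gate for layer_gate in gate_logits], dim=0)
    return concatenated.new_zeros(())  # tensor branch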
16 changes: 11 additions & 5 deletions src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -105,8 +105,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
 
 # Copied from transformers.models.jetmoe.modeling_jetmoe.load_balancing_loss_func
 def load_balancing_loss_func(
-    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
-) -> float:
+    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
+    num_experts: Optional[int] = None,
+    top_k=2,
+    attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
     r"""
     Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
@@ -115,14 +118,17 @@ def load_balancing_loss_func(
     experts is too unbalanced.
     Args:
-        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
+        gate_logits:
             Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
             shape [batch_size X sequence_length, num_experts].
+        num_experts:
+            Number of experts
+        top_k:
+            The number of experts to route per-token, can be also interpreted as the `top-k` routing
+            parameter.
         attention_mask (`torch.Tensor`, *optional*):
             The attention_mask used in forward function
             shape [batch_size X sequence_length] if not None.
-        num_experts (`int`, *optional*):
-            Number of experts
     Returns:
         The auxiliary loss.
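The docstring points at equations (4)-(6) of the Switch Transformer paper. As a rough, self-contained sketch of what such an auxiliary loss computes, ignoring the attention mask and not a line-for-line copy of the library code: the fraction of tokens dispatched to each expert is multiplied by the mean router probability for that expert, summed, and scaled by the number of experts, so the loss is smallest when routing is uniform.

from typing import Tuple

import torch
import torch.nn.functional as F


def simplified_aux_loss(gate_logits: Tuple[torch.Tensor, ...], num_experts: int, top_k: int = 2) -> torch.Tensor:
    """Switch-Transformer-style load balancing loss sketch, with no padding/attention mask."""
    # Stack per-layer logits into one [num_layers * batch * seq_len, num_experts] tensor.
    concatenated = torch.cat([layer_gate.reshape(-1, num_experts) for layer_gate in gate_logits], dim=0)
    routing_weights = F.softmax(concatenated, dim=-1)

    # Which experts each token is actually routed to (top-k routing).
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    expert_mask = F.one_hot(selected_experts, num_experts).float()  # [tokens, top_k, num_experts]

    # Fraction of tokens dispatched to each expert, and mean router probability per expert.
    tokens_per_expert = expert_mask.mean(dim=0)                     # [top_k, num_experts]
    router_prob_per_expert = routing_weights.mean(dim=0)            # [num_experts]

    # Minimized when dispatch counts and router probabilities are both uniform across experts.
    return torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) * num_experts


# Example with random logits for a 2-layer model, batch*seq_len = 8 tokens, 4 experts:
logits = (torch.randn(8, 4), torch.randn(8, 4))
print(simplified_aux_loss(logits, num_experts=4, top_k=2))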
15 changes: 9 additions & 6 deletions src/transformers/models/jamba/modeling_jamba.py
@@ -83,11 +83,11 @@
 
 # Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func with gate->router
 def load_balancing_loss_func(
-    router_logits: torch.Tensor,
-    num_experts: torch.Tensor = None,
+    router_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
+    num_experts: Optional[int] = None,
     top_k=2,
     attention_mask: Optional[torch.Tensor] = None,
-) -> float:
+) -> Union[torch.Tensor, int]:
     r"""
     Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
@@ -96,14 +96,17 @@ def load_balancing_loss_func(
     experts is too unbalanced.
     Args:
-        router_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
+        router_logits:
             Logits from the `router`, should be a tuple of model.config.num_hidden_layers tensors of
             shape [batch_size X sequence_length, num_experts].
+        num_experts:
+            Number of experts
+        top_k:
+            The number of experts to route per-token, can be also interpreted as the `top-k` routing
+            parameter.
         attention_mask (`torch.Tensor`, *optional*):
             The attention_mask used in forward function
             shape [batch_size X sequence_length] if not None.
-        num_experts (`int`, *optional*):
-            Number of experts
     Returns:
         The auxiliary loss.
16 changes: 11 additions & 5 deletions src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -110,8 +110,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
 
 # Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
 def load_balancing_loss_func(
-    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
-) -> float:
+    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
+    num_experts: Optional[int] = None,
+    top_k=2,
+    attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
     r"""
     Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
@@ -120,14 +123,17 @@ def load_balancing_loss_func(
     experts is too unbalanced.
     Args:
-        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
+        gate_logits:
             Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
             shape [batch_size X sequence_length, num_experts].
+        num_experts:
+            Number of experts
+        top_k:
+            The number of experts to route per-token, can be also interpreted as the `top-k` routing
+            parameter.
         attention_mask (`torch.Tensor`, *optional*):
             The attention_mask used in forward function
             shape [batch_size X sequence_length] if not None.
-        num_experts (`int`, *optional*):
-            Number of experts
     Returns:
         The auxiliary loss.
16 changes: 11 additions & 5 deletions src/transformers/models/mixtral/modeling_mixtral.py
@@ -124,8 +124,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
 
 
 def load_balancing_loss_func(
-    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
-) -> float:
+    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
+    num_experts: Optional[int] = None,
+    top_k=2,
+    attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
     r"""
     Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
@@ -134,14 +137,17 @@ def load_balancing_loss_func(
     experts is too unbalanced.
     Args:
-        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
+        gate_logits:
             Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
             shape [batch_size X sequence_length, num_experts].
+        num_experts:
+            Number of experts
+        top_k:
+            The number of experts to route per-token, can be also interpreted as the `top-k` routing
+            parameter.
         attention_mask (`torch.Tensor`, *optional*):
             The attention_mask used in forward function
             shape [batch_size X sequence_length] if not None.
-        num_experts (`int`, *optional*):
-            Number of experts
     Returns:
         The auxiliary loss.
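As a usage sketch for the file the commit title names: a Mixtral-style model exposes a tuple of per-layer router logits that can be fed straight into the helper under the new annotations. The attribute and argument names below (`num_local_experts`, `num_experts_per_tok`, `output_router_logits`, `router_logits`) follow the Mixtral config and outputs as commonly documented; treat the exact names as assumptions rather than guaranteed API.

# Hypothetical call site (not part of the diff): a tuple of per-layer router logits
# plus an int expert count, matching the new annotations.
import torch
from transformers import MixtralConfig, MixtralForCausalLM
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func

config = MixtralConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                       num_attention_heads=4, num_key_value_heads=2,
                       num_local_experts=4, num_experts_per_tok=2)
model = MixtralForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
outputs = model(input_ids, output_router_logits=True)

# outputs.router_logits is a tuple of [batch*seq_len, num_experts] tensors, one per layer,
# i.e. the Tuple[torch.Tensor] branch of the new Union annotation.
aux_loss = load_balancing_loss_func(
    outputs.router_logits,
    num_experts=config.num_local_experts,   # Optional[int]
    top_k=config.num_experts_per_tok,
)
print(aux_loss)  # a scalar torch.Tensor; would be 0 (int) if the logits were None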
16 changes: 11 additions & 5 deletions src/transformers/models/olmoe/modeling_olmoe.py
@@ -107,8 +107,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
 
 # Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
 def load_balancing_loss_func(
-    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
-) -> float:
+    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
+    num_experts: Optional[int] = None,
+    top_k=2,
+    attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
     r"""
     Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
@@ -117,14 +120,17 @@ def load_balancing_loss_func(
     experts is too unbalanced.
     Args:
-        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
+        gate_logits:
             Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
             shape [batch_size X sequence_length, num_experts].
+        num_experts:
+            Number of experts
+        top_k:
+            The number of experts to route per-token, can be also interpreted as the `top-k` routing
+            parameter.
         attention_mask (`torch.Tensor`, *optional*):
             The attention_mask used in forward function
             shape [batch_size X sequence_length] if not None.
-        num_experts (`int`, *optional*):
-            Number of experts
     Returns:
         The auxiliary loss.
16 changes: 11 additions & 5 deletions src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -117,8 +117,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
 
 # Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
 def load_balancing_loss_func(
-    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
-) -> float:
+    gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
+    num_experts: Optional[int] = None,
+    top_k=2,
+    attention_mask: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, int]:
     r"""
     Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
@@ -127,14 +130,17 @@ def load_balancing_loss_func(
     experts is too unbalanced.
     Args:
-        gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
+        gate_logits:
             Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
             shape [batch_size X sequence_length, num_experts].
+        num_experts:
+            Number of experts
+        top_k:
+            The number of experts to route per-token, can be also interpreted as the `top-k` routing
+            parameter.
         attention_mask (`torch.Tensor`, *optional*):
             The attention_mask used in forward function
             shape [batch_size X sequence_length] if not None.
-        num_experts (`int`, *optional*):
-            Number of experts
     Returns:
         The auxiliary loss.
