diff --git a/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py b/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py index b4450e4cde..99feec864e 100644 --- a/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/btlm_attn_hijack_flash.py @@ -36,7 +36,9 @@ def flashattn_attn( head_mask: Optional[torch.Tensor] = None, position_bias: Optional[torch.Tensor] = None, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - softmax_scale = 1 / (key.size(-1) ** 0.5) if self.scale_attn_weights else None + softmax_scale = ( + 1 / (key.size(-1) ** self.attn_scale_power) if self.scale_attn_weights else None + ) query = query.permute(0, 2, 1, 3) key = key.permute(0, 2, 1, 3)