diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index b8fda2592a..fecd79553f 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -670,8 +670,7 @@ def forward(
         extra_attn_kwargs = {}
         if self.attn_impl == 'flash':
-            if flash_attn_padding_info is not None:
-                key_padding_mask = None
+            key_padding_mask = None
             extra_attn_kwargs = {
                 'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
                 'sliding_window_size': self.sliding_window_size,
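
For context, a minimal sketch of the behavioral change (a hypothetical standalone helper, not the module's actual code path): before this patch, key_padding_mask was cleared only when flash_attn_padding_info was provided; after it, the flash path always clears the dense mask, on the assumption that padding information flows through flash_attn_padding_info instead.

# Hypothetical illustration of the change above; the names mirror the diff,
# but this helper is not part of llm-foundry.
def select_padding_inputs(attn_impl, key_padding_mask, flash_attn_padding_info,
                          after_patch=True):
    if attn_impl == 'flash':
        if after_patch:
            # New behavior: the flash path never consumes the dense key_padding_mask.
            key_padding_mask = None
        elif flash_attn_padding_info is not None:
            # Old behavior: the mask was dropped only when padding info was passed.
            key_padding_mask = None
    return key_padding_mask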