diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index b247d21dc8..91d5947676 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -203,8 +203,8 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
     if (sequence_id is not None) and attn_uses_sequence_id and (attn_impl == 'flash'):
         # Check if sequence has left padding. If yes, raise an error.
-        if (attention_mask is not None) and (attention_mask[:, 0].sum()
-                                             != attention_mask.shape[0]):
+        if (attention_mask is not None) and (attention_mask[:, 0].sum() !=
+                                             attention_mask.shape[0]):
             raise NotImplementedError(
                 'Left padding is not supported with flash attention when attn_uses_sequence_id is set to True.'
             )
@@ -471,8 +471,8 @@ def _attn_bias(
                 # clamp to 0 necessary for torch 2.0 compile()
                 _s_k = max(0, attn_bias.size(-1) - s_k)
                 attn_bias = attn_bias[:, :, :, _s_k:]
-            if prefix_mask is not None and (attention_mask.shape
-                                            != prefix_mask.shape):
+            if prefix_mask is not None and (attention_mask.shape !=
+                                            prefix_mask.shape):
                 raise ValueError(
                     f'attention_mask shape={attention_mask.shape} ' +
                     f'and prefix_mask shape={prefix_mask.shape} are not equal.')
@@ -610,8 +610,8 @@ def forward(
             past_position = past_key_values[0][0].size(3)
 
         if self.learned_pos_emb or self.rope:
-            if self.learned_pos_emb and (S + past_position
-                                         > self.config.max_seq_len):
+            if self.learned_pos_emb and (S + past_position >
+                                         self.config.max_seq_len):
                 raise ValueError(
                     f'Cannot forward input with past sequence length {past_position} and current sequence length ' +