fix format
cli99 committed Feb 6, 2024
1 parent 459700a commit 3b12441
Showing 1 changed file: llmfoundry/models/mpt/modeling_mpt.py (6 additions, 6 deletions).
@@ -203,8 +203,8 @@ def gen_attention_mask_in_length(sequence_id: Union[None, torch.Tensor], S: int,
     if (sequence_id is not None) and attn_uses_sequence_id and (attn_impl
                                                                  == 'flash'):
         # Check if sequence has left padding. If yes, raise an error.
-        if (attention_mask is not None) and (attention_mask[:, 0].sum()
-                                             != attention_mask.shape[0]):
+        if (attention_mask is not None) and (attention_mask[:, 0].sum() !=
+                                             attention_mask.shape[0]):
             raise NotImplementedError(
                 'Left padding is not supported with flash attention when attn_uses_sequence_id is set to True.'
             )
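
The lines reformatted in this hunk wrap a left-padding check. As a minimal standalone sketch (not part of the commit; the tensor values are invented for illustration), the condition works because a left-padded row starts with 0 in the attention mask, so the sum of the first column falls short of the batch size:

    import torch

    # Hypothetical batch of two sequences; the second row is left-padded.
    attention_mask = torch.tensor([
        [1, 1, 1, 1],  # no padding
        [0, 0, 1, 1],  # left padding: mask starts with 0
    ])

    # Same condition as in the hunk above.
    has_left_padding = attention_mask[:, 0].sum() != attention_mask.shape[0]
    print(has_left_padding)  # tensor(True)
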
@@ -471,8 +471,8 @@ def _attn_bias(
                 # clamp to 0 necessary for torch 2.0 compile()
                 _s_k = max(0, attn_bias.size(-1) - s_k)
                 attn_bias = attn_bias[:, :, :, _s_k:]
-            if prefix_mask is not None and (attention_mask.shape
-                                            != prefix_mask.shape):
+            if prefix_mask is not None and (attention_mask.shape !=
+                                            prefix_mask.shape):
                 raise ValueError(
                     f'attention_mask shape={attention_mask.shape} ' +
                     f'and prefix_mask shape={prefix_mask.shape} are not equal.')
@@ -610,8 +610,8 @@ def forward(
                 past_position = past_key_values[0][0].size(3)

         if self.learned_pos_emb or self.rope:
-            if self.learned_pos_emb and (S + past_position
-                                         > self.config.max_seq_len):
+            if self.learned_pos_emb and (S + past_position >
+                                         self.config.max_seq_len):
                 raise ValueError(
                     f'Cannot forward input with past sequence length {past_position} and current sequence length '
                     +
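
The last hunk rewraps the length guard used with learned positional embeddings. Below is a minimal standalone sketch of that length condition only; the helper name and error message are invented here (the real message is truncated in the diff above), and the learned_pos_emb gating is omitted:

    # Hypothetical helper: S is the current input length, past_position the
    # length of the cached prefix, and max_seq_len the model's configured limit.
    def check_total_seq_len(S: int, past_position: int, max_seq_len: int) -> None:
        if S + past_position > max_seq_len:
            raise ValueError(
                f'Cannot forward input with past sequence length {past_position} '
                f'and current sequence length {S}: total length exceeds '
                f'max_seq_len={max_seq_len}.')

    check_total_seq_len(S=100, past_position=2048, max_seq_len=2048)  # raises
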
