diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 6ae5725987..00bc61f7c4 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -236,7 +236,7 @@ def gen_flash_attn_padding_info(
         key_padding_mask = attention_mask
         if key_padding_mask is None:
             key_padding_mask = torch.ones((bsz, past_key_len + S),
-                                          dtype=torch.bool).to(device=device)
+                                          dtype=torch.bool, device=device)
         query_padding_mask = key_padding_mask[:, -S:]
         unpadding_function = bert_padding.unpad_input
     else:
@@ -245,12 +245,12 @@ def gen_flash_attn_padding_info(
         unpadding_function = bert_padding.unpad_input_for_concatenated_sequences
 
     _, indices_q, cu_seqlens_q, max_seqlen_q = unpadding_function(
-        torch.zeros(bsz, S, 1).to(device=device), query_padding_mask)
+        torch.zeros(bsz, S, 1, device=device), query_padding_mask)
     _, indices_k, cu_seqlens_k, max_seqlen_k = unpadding_function(
-        torch.zeros(bsz, past_key_len + S, 1).to(device=device),
+        torch.zeros(bsz, past_key_len + S, 1, device=device),
         key_padding_mask)
     _, indices_v, _, _ = unpadding_function(
-        torch.zeros(bsz, past_key_len + S, 1).to(device=device),
+        torch.zeros(bsz, past_key_len + S, 1, device=device),
         key_padding_mask)
 
     flash_attn_padding_info['indices_q'] = indices_q
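
Note on the pattern: the change replaces CPU-side tensor construction followed by a .to(device=device) copy with construction directly on the target device, which avoids the intermediate host allocation and transfer. Below is a minimal standalone sketch of the two patterns, not part of the patch; the shapes and the device selection are illustrative assumptions.

import torch

# Illustrative only: pick whichever device is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Before: the tensor is first materialized on the CPU, then copied to `device`.
mask_copy = torch.ones((4, 16), dtype=torch.bool).to(device=device)

# After: the tensor is allocated directly on `device`, skipping the
# intermediate CPU tensor and the host-to-device copy.
mask_direct = torch.ones((4, 16), dtype=torch.bool, device=device)

# Both forms produce equivalent tensors on the same device.
assert mask_copy.device == mask_direct.device
assert torch.equal(mask_copy, mask_direct)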