[shardformer] fix llama error when transformers upgraded. (hpcaitech#5055)

* fix-llama

* Update llama.py
flybird11111 authored Nov 16, 2023
1 parent 3e02154 commit 97cd0cd
1 changed file with 14 additions and 4 deletions: colossalai/shardformer/modeling/llama.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -13,6 +13,11 @@

 from colossalai.pipeline.stage_manager import PipelineStageManager
 
+try:
+    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
+    LATEST_VERSION = True
+except ImportError:
+    LATEST_VERSION = False
 
 class LlamaPipelineForwards:
     """
@@ -97,9 +102,14 @@ def llama_model_forward(
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-        )
+        if LATEST_VERSION:
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            )
+        else:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            )
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
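Note: the snippet below is a minimal, self-contained sketch of the compatibility pattern this commit applies, not code from the repository. Newer transformers releases expose `_prepare_4d_causal_attention_mask` as a module-level helper, while older releases only provide the model's `_prepare_decoder_attention_mask` method (the pre-existing call path in the diff), so an import guard selects whichever is available. The wrapper name `build_causal_mask` and the explicit `model` argument are illustrative placeholders, not part of the commit.

```python
# Sketch of the version-guard pattern used in the diff above.
try:
    # Module-level helper available in newer transformers releases.
    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask

    LATEST_VERSION = True
except ImportError:
    # Older releases only expose the mask helper as a private model method.
    LATEST_VERSION = False


def build_causal_mask(model, attention_mask, batch_size, seq_length, hidden_states, past_key_values_length):
    """Dispatch to whichever attention-mask helper the installed transformers provides."""
    if LATEST_VERSION:
        return _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
        )
    # Fallback path for older transformers, mirroring the original call in the diff.
    return model._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
    )
```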
