Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Oct 14, 2024
1 parent fe00863 commit 0f27850
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions tests/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from accelerate import init_empty_weights
from composer.core.precision import Precision, get_precision_context
from composer.distributed.dist_strategy import prepare_fsdp_module
from composer.models.huggingface import maybe_get_underlying_model
from composer.models.huggingface import (
HuggingFaceModel,
maybe_get_underlying_model,
)
from composer.optim import DecoupledAdamW
from composer.utils import (
FSDPConfig,
Expand All @@ -39,7 +42,6 @@

from llmfoundry import ComposerHFCausalLM
from llmfoundry.layers_registry import norms
from llmfoundry.models.hf.hf_base import BaseHuggingFaceModel
from llmfoundry.models.layers import build_alibi_bias
from llmfoundry.models.layers.attention import (
check_alibi_support,
Expand Down Expand Up @@ -2560,7 +2562,7 @@ def test_hf_init(
False,
)

model = BaseHuggingFaceModel(model, tokenizer)
model = HuggingFaceModel(model, tokenizer)

batch = gen_random_batch(batch_size, test_cfg)

Expand Down Expand Up @@ -2609,7 +2611,7 @@ def test_head_dim_8_flash_mqa_attn(batch_size: int = 2):

mpt = MPTForCausalLM(hf_config)

model = BaseHuggingFaceModel(mpt, tokenizer, shift_labels=True)
model = HuggingFaceModel(mpt, tokenizer, shift_labels=True)

model = model.to(test_cfg.device)
batch = gen_random_batch(batch_size, test_cfg)
Expand Down

0 comments on commit 0f27850

Please sign in to comment.