diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index e2ddb0a012..43067f5e47 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -15,7 +15,10 @@
 from accelerate import init_empty_weights
 from composer.core.precision import Precision, get_precision_context
 from composer.distributed.dist_strategy import prepare_fsdp_module
-from composer.models.huggingface import maybe_get_underlying_model
+from composer.models.huggingface import (
+    HuggingFaceModel,
+    maybe_get_underlying_model,
+)
 from composer.optim import DecoupledAdamW
 from composer.utils import (
     FSDPConfig,
@@ -39,7 +42,6 @@
 
 from llmfoundry import ComposerHFCausalLM
 from llmfoundry.layers_registry import norms
-from llmfoundry.models.hf.hf_base import BaseHuggingFaceModel
 from llmfoundry.models.layers import build_alibi_bias
 from llmfoundry.models.layers.attention import (
     check_alibi_support,
@@ -2560,7 +2562,7 @@ def test_hf_init(
         False,
     )
 
-    model = BaseHuggingFaceModel(model, tokenizer)
+    model = HuggingFaceModel(model, tokenizer)
 
     batch = gen_random_batch(batch_size, test_cfg)
 
@@ -2609,7 +2611,7 @@ def test_head_dim_8_flash_mqa_attn(batch_size: int = 2):
 
     mpt = MPTForCausalLM(hf_config)
 
-    model = BaseHuggingFaceModel(mpt, tokenizer, shift_labels=True)
+    model = HuggingFaceModel(mpt, tokenizer, shift_labels=True)
 
     model = model.to(test_cfg.device)
     batch = gen_random_batch(batch_size, test_cfg)
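
Context for reviewers: the tests stop importing llm-foundry's BaseHuggingFaceModel and instead construct Composer's HuggingFaceModel directly. The sketch below illustrates that wrapping pattern outside the test harness; it is a minimal example, not part of this change, and the 'gpt2' checkpoint and toy batch are assumptions for illustration only. In the MQA test the wrapper is additionally given shift_labels=True so that Composer handles the label shift needed for the causal-LM loss and metrics.

```python
# Illustrative sketch only (not part of this diff): the wrapping pattern the
# updated tests rely on. 'gpt2' and the toy batch are assumed for the example.
from composer.models.huggingface import HuggingFaceModel
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
hf_model = AutoModelForCausalLM.from_pretrained('gpt2')

# Composer's wrapper takes the raw transformers model plus its tokenizer and
# exposes the ComposerModel forward/loss interface over dict-style batches.
model = HuggingFaceModel(hf_model, tokenizer)

batch = tokenizer('composer wraps huggingface models', return_tensors='pt')
batch['labels'] = batch['input_ids'].clone()  # causal-LM loss needs labels

outputs = model(batch)             # forwards only the keys the HF model accepts
loss = model.loss(outputs, batch)  # loss computed inside the HF model's forward
```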