diff --git a/llmfoundry/models/hf/hf_base.py b/llmfoundry/models/hf/hf_base.py
index 10b49d1ec8..f5b40a6631 100644
--- a/llmfoundry/models/hf/hf_base.py
+++ b/llmfoundry/models/hf/hf_base.py
@@ -392,6 +392,7 @@ def build_inner_model(
             model = PeftModelForCausalLM.from_pretrained(
                 model,
                 pretrained_lora_id_or_path,
+                is_trainable=True,
             )

         if prepare_for_fsdp:
diff --git a/tests/models/hf/test_hf_base.py b/tests/models/hf/test_hf_base.py
index 4b0fb34e53..67e9750a8a 100644
--- a/tests/models/hf/test_hf_base.py
+++ b/tests/models/hf/test_hf_base.py
@@ -1,6 +1,8 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0

+from peft import PeftModel
+
 from llmfoundry.models.hf.hf_base import BaseHuggingFaceModel


@@ -23,3 +25,24 @@ def test_build_inner_model_fsdp():
     )

     assert model.fsdp_wrap_fn(model.model.layers[0])
+
+
+def test_pretrained_peft_trainable():
+    model = BaseHuggingFaceModel.build_inner_model(
+        pretrained_model_name_or_path='facebook/opt-350m',
+        pretrained_lora_id_or_path='ybelkada/opt-350m-lora',
+        trust_remote_code=False,
+        init_device='cpu',
+        use_flash_attention_2=False,
+        use_auth_token=False,
+        config_overrides={},
+        load_in_8bit=False,
+        pretrained=True,
+        prepare_for_fsdp=True,
+    )
+
+    assert isinstance(model, PeftModel)
+
+    n_trainable, n_all = model.get_nb_trainable_parameters()
+    assert n_all > 0
+    assert n_trainable > 0
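
Context for the one-line fix: peft's PeftModel.from_pretrained defaults to is_trainable=False, which loads the adapter in inference mode and freezes all of its parameters, so a pretrained LoRA adapter loaded via pretrained_lora_id_or_path could not be fine-tuned further. The following standalone sketch (illustrative only, not part of the patch) demonstrates that default behavior with the same public model and adapter IDs the new test uses:

    # Sketch: peft freezes a loaded adapter unless is_trainable=True is passed.
    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    # Default load: the adapter comes up in inference mode, fully frozen.
    base = AutoModelForCausalLM.from_pretrained('facebook/opt-350m')
    frozen = PeftModel.from_pretrained(base, 'ybelkada/opt-350m-lora')
    n_trainable, _ = frozen.get_nb_trainable_parameters()
    assert n_trainable == 0  # nothing would receive gradients during training

    # With is_trainable=True, the LoRA weights stay trainable, which is what
    # the new test asserts for build_inner_model.
    base = AutoModelForCausalLM.from_pretrained('facebook/opt-350m')
    trainable = PeftModel.from_pretrained(
        base,
        'ybelkada/opt-350m-lora',
        is_trainable=True,
    )
    n_trainable, _ = trainable.get_nb_trainable_parameters()
    assert n_trainable > 0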