Skip to content

Commit

Permalink
add_test
Browse files Browse the repository at this point in the history
  • Loading branch information
snarayan21 committed Dec 16, 2024
1 parent d10d426 commit 539e70d
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions tests/models/hf/test_hf_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import pytest
from llmfoundry.models.hf.hf_base import BaseHuggingFaceModel
from peft import PeftModel


def test_build_inner_model_fsdp():
Expand All @@ -23,3 +25,30 @@ def test_build_inner_model_fsdp():
)

assert model.fsdp_wrap_fn(model.model.layers[0])


@pytest.mark.parametrize('trainable', [True, False])
def test_pretrained_peft_trainable(trainable: bool):
model = BaseHuggingFaceModel.build_inner_model(
pretrained_model_name_or_path='facebook/opt-350m',
pretrained_lora_id_or_path='ybelkada/opt-350m-lora',
trust_remote_code=False,
init_device='cpu',
use_flash_attention_2=False,
use_auth_token=False,
config_overrides={},
load_in_8bit=False,
pretrained=True,
prepare_for_fsdp=True,
peft_is_trainable=trainable,
)

assert isinstance(model, PeftModel)

n_trainable, n_all = model.get_nb_trainable_parameters()
assert n_all > 0

if trainable:
assert n_trainable > 0
else:
assert n_trainable == 0

0 comments on commit 539e70d

Please sign in to comment.