Commit

Always load pretrained PEFT adapters as trainable
snarayan21 committed Dec 16, 2024
1 parent 539e70d commit 7b12371
Showing 3 changed files with 3 additions and 16 deletions.
6 changes: 1 addition & 5 deletions llmfoundry/models/hf/hf_base.py
@@ -72,7 +72,6 @@ def __init__(
         additional_train_metrics: Optional[list] = None,
         additional_eval_metrics: Optional[list] = None,
         should_save_peft_only: bool = True,
-        peft_is_trainable: bool = False,
     ):
         config_overrides = config_overrides or {}

@@ -86,7 +85,6 @@ def __init__(
             config_overrides=config_overrides,
             load_in_8bit=load_in_8bit,
             pretrained=pretrained,
-            peft_is_trainable=peft_is_trainable,
         )

         model = self.transform_model(model)

@@ -209,7 +207,6 @@ def build_inner_model(
         pretrained: bool,
         model_cls: Optional[Union[_BaseAutoModelClass, PreTrainedModel]] = None,
         prepare_for_fsdp: bool = False,
-        peft_is_trainable: bool = False,
     ) -> Union[PreTrainedModel, 'PeftModel']:
         """Builds the inner model for the ComposerHFCausalLM.

@@ -225,7 +222,6 @@ def build_inner_model(
             pretrained (bool): Whether the model is pretrained.
             model_cls (Union[Type, Type[PreTrainedModel]]): Kept for backwards compatibility.
             prepare_for_fsdp (bool, optional): Kept for backwards compatibility.
-            peft_is_trainable (bool): Whether loaded PEFT adapters are trainable. Default: ``False``.

         Returns:
             Union[PreTrainedModel, 'PeftModel']: The built inner model.

@@ -396,7 +392,7 @@ def build_inner_model(
             model = PeftModelForCausalLM.from_pretrained(
                 model,
                 pretrained_lora_id_or_path,
-                is_trainable=peft_is_trainable,
+                is_trainable=True,
             )

         if prepare_for_fsdp:
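For context, `is_trainable` is the peft library's switch between loading an adapter frozen for inference (the peft default) and loading it with its weights unfrozen for further training; this commit hardcodes the latter. A minimal standalone sketch of the new behavior, reusing the model and adapter IDs from the test changed below:

from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained('facebook/opt-350m')

# is_trainable=True (now hardcoded above) loads the LoRA weights with
# requires_grad=True so the adapter can be fine-tuned further; the base
# model weights stay frozen.
model = PeftModel.from_pretrained(
    base,
    'ybelkada/opt-350m-lora',
    is_trainable=True,
)

n_trainable, n_all = model.get_nb_trainable_parameters()
assert 0 < n_trainable < n_all  # only the adapter weights are trainable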
3 changes: 0 additions & 3 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -54,7 +54,6 @@ class ComposerHFCausalLM(BaseHuggingFaceModel):
         init_device (str, optional): Which device to initialize the model on. Default: ``'cpu'``.
         use_flash_attention_2 (bool, optional): Whether to use flash-attention 2. Default: ``False``.
         tokenizer (PreTrainedTokenizer): The tokenizer that the model will use.
-        peft_is_trainable (bool): Whether loaded PEFT adapters are trainable. Default: ``False``.
     """

     model_cls: Union[_BaseAutoModelClass,

@@ -80,7 +79,6 @@ def __init__(
         additional_train_metrics: Optional[list] = None,
         additional_eval_metrics: Optional[list] = None,
         should_save_peft_only: bool = True,
-        peft_is_trainable: bool = False,
     ):
         super().__init__(
             pretrained_model_name_or_path,

@@ -100,5 +98,4 @@ def __init__(
             additional_train_metrics=additional_train_metrics,
             additional_eval_metrics=additional_eval_metrics,
             should_save_peft_only=should_save_peft_only,
-            peft_is_trainable=peft_is_trainable,
         )
10 changes: 2 additions & 8 deletions tests/models/hf/test_hf_base.py
@@ -27,8 +27,7 @@ def test_build_inner_model_fsdp():
     assert model.fsdp_wrap_fn(model.model.layers[0])


-@pytest.mark.parametrize('trainable', [True, False])
-def test_pretrained_peft_trainable(trainable: bool):
+def test_pretrained_peft_trainable():
     model = BaseHuggingFaceModel.build_inner_model(
         pretrained_model_name_or_path='facebook/opt-350m',
         pretrained_lora_id_or_path='ybelkada/opt-350m-lora',

@@ -40,15 +39,10 @@ def test_pretrained_peft_trainable():
         load_in_8bit=False,
         pretrained=True,
         prepare_for_fsdp=True,
-        peft_is_trainable=trainable,
     )

     assert isinstance(model, PeftModel)

     n_trainable, n_all = model.get_nb_trainable_parameters()
     assert n_all > 0

-    if trainable:
-        assert n_trainable > 0
-    else:
-        assert n_trainable == 0
+    assert n_trainable > 0
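The `trainable=False` branch is dropped because the flag no longer exists; loading with the peft default (`is_trainable=False`) is the path the deleted `else` asserted against. A quick sketch of that old inference-only behavior, using the same checkpoints as the test:

from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained('facebook/opt-350m')

# Omitting is_trainable (peft defaults to False) loads the adapter frozen
# for inference, the case the removed else branch used to cover.
frozen = PeftModel.from_pretrained(base, 'ybelkada/opt-350m-lora')

n_trainable, _ = frozen.get_nb_trainable_parameters()
assert n_trainable == 0  # everything, adapter included, is frozen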
