From 03a1c57da8783f1e04b7e023aee0c5207f1e836b Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Fri, 19 Jan 2024 12:07:57 -0800
Subject: [PATCH] wip

---
 llmfoundry/callbacks/hf_checkpointer.py  |  4 +++-
 llmfoundry/models/hf/hf_causal_lm.py     | 21 +++++++++++++++------
 llmfoundry/models/hf/hf_fsdp.py          | 13 ++++++-------
 llmfoundry/models/hf/model_wrapper.py    |  8 +++++---
 tests/models/hf/test_hf_peft_wrapping.py | 12 +++++++++---
 5 files changed, 38 insertions(+), 20 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 43ff77f100..1e7b38f2ec 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -226,7 +226,9 @@ def _save_checkpoint(self, state: State, logger: Logger):
 
                 base_model = original_model.get_base_model()
                 new_base_model_instance = type(base_model)(copied_config)
-                new_model_instance = type(original_model)(new_base_model_instance, original_model.peft_config['default'])
+                new_model_instance = type(original_model)(
+                    new_base_model_instance,
+                    original_model.peft_config['default'])
             else:
                 new_model_instance = type(original_model)(copied_config)
 
diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index bb2137bf59..13fac845d1 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -6,7 +6,7 @@
 import logging
 import os
 import warnings
-from typing import Mapping, Union
+from typing import Mapping
 
 # required for loading a python model into composer
 import transformers
@@ -19,7 +19,6 @@
                                   LanguageCrossEntropy, LanguagePerplexity)
 from composer.utils import dist
 from omegaconf import DictConfig
-from torch import nn
 from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedModel,
                           PreTrainedTokenizerBase)
 
@@ -259,11 +258,21 @@ def _autoset_attn_implementation_monkeypatch(
             raise ValueError(
                 f'om_model_config must be either a DictConfig, PeftModel, or PreTrainedModel, but got {type(om_model_config)}'
             )
-        
+
         from peft import LoraConfig
-        peft_config = pop_config(om_model_config, 'peft_config', must_exist=False, convert=True)
-        peft_type = peft_config.pop('peft_type', None)
-        peft_config = LoraConfig(**peft_config)
+        peft_config = pop_config(om_model_config,
+                                 'peft_config',
+                                 must_exist=False,
+                                 convert=True)
+
+        if peft_config is not None:
+            peft_type = peft_config.get('peft_type', '')
+            if peft_type.upper() != 'LORA':
+                raise ValueError(f'Only LORA is supported for peft_type, but got {peft_type}.')
+            task_type = peft_config.get('task_type', '')
+            if task_type.upper() != 'CAUSAL_LM':
+                raise ValueError(f'Only CAUSAL_LM is supported for task_type, but got {task_type}.')
+            peft_config = LoraConfig(**peft_config)
 
         composer_model = super().__init__(
             model=model,
diff --git a/llmfoundry/models/hf/hf_fsdp.py b/llmfoundry/models/hf/hf_fsdp.py
index e4e1ad094b..a618c1f9c6 100644
--- a/llmfoundry/models/hf/hf_fsdp.py
+++ b/llmfoundry/models/hf/hf_fsdp.py
@@ -5,17 +5,14 @@
 # which is MIT licensed
 
 import functools
-from typing import Any, Iterable, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union
 
 import torch
 from transformers import PreTrainedModel
 from transformers.models.opt.modeling_opt import OPTDecoder
 
-try:
+if TYPE_CHECKING:
     from peft import PeftModel
-    peft_model_type = PeftModel
-except ImportError:
-    peft_model_type = None
 
 
 # helper functions
@@ -135,7 +132,8 @@ def prepare_hf_model_for_fsdp(model: PreTrainedModel,
         prepare_hf_causal_lm_model_for_fsdp(model, init_device)
 
 
-def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel, peft_model_type],
+def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel,
+                                                     'PeftModel'],
                                          init_device: Optional[str]) -> None:
     """FSDP wrap a HuggingFace decoder.
 
@@ -207,7 +205,8 @@ def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel, peft_model
         peft_type = model.peft_type.lower()
         active_adapters = [adapter.lower() for adapter in model.active_adapters]
         for name, module in model.named_modules():
-            if peft_type in name.lower() and any(adapter in name.lower() for adapter in active_adapters):
+            if peft_type in name.lower() and any(
+                    adapter in name.lower() for adapter in active_adapters):
                 has_parameters = any(True for _ in module.parameters())
                 has_buffers = any(True for _ in module.buffers())
                 if has_parameters or has_buffers:
diff --git a/llmfoundry/models/hf/model_wrapper.py b/llmfoundry/models/hf/model_wrapper.py
index d84b96b02f..7010a7c244 100644
--- a/llmfoundry/models/hf/model_wrapper.py
+++ b/llmfoundry/models/hf/model_wrapper.py
@@ -5,9 +5,8 @@
 
 from __future__ import annotations
 
-import inspect
 from collections import UserDict
-from typing import List, Mapping, Optional
+from typing import TYPE_CHECKING, List, Mapping, Optional
 
 import torch
 import transformers
@@ -18,6 +17,9 @@
 
 from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp
 
+if TYPE_CHECKING:
+    from peft import PeftConfig
+
 # HuggingFace hardcodes the ignore index to -100
 _HF_IGNORE_INDEX = -100
 
@@ -46,7 +48,7 @@ def __init__(self,
                  z_loss: float = 0.0,
                  shift_labels: bool = False,
                  init_device: Optional[str] = None,
-                 peft_config = None):
+                 peft_config: Optional['PeftConfig'] = None):
         super().__init__(model,
                          tokenizer,
                          use_logits=True,
diff --git a/tests/models/hf/test_hf_peft_wrapping.py b/tests/models/hf/test_hf_peft_wrapping.py
index 1a9b2a5a11..881a075cc4 100644
--- a/tests/models/hf/test_hf_peft_wrapping.py
+++ b/tests/models/hf/test_hf_peft_wrapping.py
@@ -1,10 +1,16 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
 import transformers
-from peft import get_peft_model, LoraConfig
+from peft import LoraConfig, get_peft_model
+
 from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp
 
+
 def test_peft_wraps():
-    mistral_cfg = transformers.AutoConfig.from_pretrained('mistralai/Mistral-7B-v0.1', num_hidden_layers=2)
+    mistral_cfg = transformers.AutoConfig.from_pretrained(
+        'mistralai/Mistral-7B-v0.1', num_hidden_layers=2)
     mistral = transformers.AutoModelForCausalLM.from_config(mistral_cfg)
     mistral = get_peft_model(mistral, LoraConfig())
     prepare_hf_model_for_fsdp(mistral, 'cpu')
-    assert False
\ No newline at end of file
+    assert False
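
Note (illustrative, not part of the patch): a minimal sketch of a `peft_config`
mapping that would pass the validation added to `ComposerHFCausalLM` above,
where `peft_type` must be LORA and `task_type` must be CAUSAL_LM, and the
remaining keys are forwarded to `peft.LoraConfig`. The specific LoRA
hyperparameters shown (r, lora_alpha, lora_dropout, target_modules) are assumed
example values, not values taken from this patch.

    from peft import LoraConfig

    # Example peft_config as it might appear after pop_config(...) returns a
    # plain dict; keys other than peft_type/task_type are LoraConfig fields.
    peft_config = {
        'peft_type': 'LORA',
        'task_type': 'CAUSAL_LM',
        'r': 16,
        'lora_alpha': 32,
        'lora_dropout': 0.05,
        'target_modules': ['q_proj', 'v_proj'],
    }

    # Mirrors the checks in hf_causal_lm.py: only LoRA for causal LM is
    # accepted, then the dict is turned into a LoraConfig object.
    assert peft_config.get('peft_type', '').upper() == 'LORA'
    assert peft_config.get('task_type', '').upper() == 'CAUSAL_LM'
    lora_config = LoraConfig(**peft_config)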