Commit: wip
dakinggg committed Jan 19, 2024
1 parent beaa86c commit 03a1c57
Showing 5 changed files with 38 additions and 20 deletions.
4 changes: 3 additions & 1 deletion llmfoundry/callbacks/hf_checkpointer.py
@@ -226,7 +226,9 @@ def _save_checkpoint(self, state: State, logger: Logger):
                 base_model = original_model.get_base_model()
                 new_base_model_instance = type(base_model)(copied_config)

-                new_model_instance = type(original_model)(new_base_model_instance, original_model.peft_config['default'])
+                new_model_instance = type(original_model)(
+                    new_base_model_instance,
+                    original_model.peft_config['default'])
             else:
                 new_model_instance = type(original_model)(copied_config)

21 changes: 15 additions & 6 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -6,7 +6,7 @@
 import logging
 import os
 import warnings
-from typing import Mapping, Union
+from typing import Mapping

 # required for loading a python model into composer
 import transformers
@@ -19,7 +19,6 @@
                                   LanguageCrossEntropy, LanguagePerplexity)
 from composer.utils import dist
 from omegaconf import DictConfig
-from torch import nn
 from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedModel,
                           PreTrainedTokenizerBase)

@@ -259,11 +258,21 @@ def _autoset_attn_implementation_monkeypatch(
             raise ValueError(
                 f'om_model_config must be either a DictConfig, PeftModel, or PreTrainedModel, but got {type(om_model_config)}'
             )

         from peft import LoraConfig
-        peft_config = pop_config(om_model_config, 'peft_config', must_exist=False, convert=True)
-        peft_type = peft_config.pop('peft_type', None)
-        peft_config = LoraConfig(**peft_config)
+        peft_config = pop_config(om_model_config,
+                                 'peft_config',
+                                 must_exist=False,
+                                 convert=True)
+
+        if peft_config is not None:
+            peft_type = peft_config.get('peft_type', None)
+            if peft_type.upper() != 'LORA':
+                raise ValueError(f'Only LORA is supported for peft_type, but got {peft_type}.')
+            task_type = peft_config.get('task_type', None)
+            if task_type.upper() != 'CAUSAL_LM':
+                raise ValueError(f'Only CAUSAL_LM is supported for task_type, but got {task_type}.')
+            peft_config = LoraConfig(**peft_config)

         composer_model = super().__init__(
             model=model,
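Note: as a rough illustration of the peft_config handling added above, the sketch below builds a LoraConfig from a mapping shaped like the one the pop_config call would return. Only 'peft_type' and 'task_type' are constrained by the new checks; the remaining LoRA hyperparameters are illustrative assumptions, not values taken from this commit.

    # Hypothetical peft_config mapping (illustrative values).
    from peft import LoraConfig

    peft_config_dict = {
        'peft_type': 'LORA',
        'task_type': 'CAUSAL_LM',
        'r': 16,
        'lora_alpha': 32,
        'lora_dropout': 0.05,
        'target_modules': ['q_proj', 'v_proj'],
    }

    # Mirrors the validation in the hunk above (a sketch, not the exact code path).
    peft_type = peft_config_dict.get('peft_type', None)
    if peft_type.upper() != 'LORA':
        raise ValueError(f'Only LORA is supported for peft_type, but got {peft_type}.')
    task_type = peft_config_dict.get('task_type', None)
    if task_type.upper() != 'CAUSAL_LM':
        raise ValueError(f'Only CAUSAL_LM is supported for task_type, but got {task_type}.')

    lora_config = LoraConfig(**peft_config_dict)  # same style of construction as in the diff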
13 changes: 6 additions & 7 deletions llmfoundry/models/hf/hf_fsdp.py
@@ -5,17 +5,14 @@
 # which is MIT licensed

 import functools
-from typing import Any, Iterable, List, Optional, Union
+from typing import Any, Iterable, List, Optional, Union, TYPE_CHECKING

 import torch
 from transformers import PreTrainedModel
 from transformers.models.opt.modeling_opt import OPTDecoder

-try:
+if TYPE_CHECKING:
     from peft import PeftModel
-    peft_model_type = PeftModel
-except ImportError:
-    peft_model_type = None


 # helper functions
@@ -135,7 +132,8 @@ def prepare_hf_model_for_fsdp(model: PreTrainedModel,
         prepare_hf_causal_lm_model_for_fsdp(model, init_device)


-def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel, peft_model_type],
+def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel,
+                                                      'PeftModel'],
                                         init_device: Optional[str]) -> None:
     """FSDP wrap a HuggingFace decoder.
@@ -207,7 +205,8 @@ def prepare_hf_causal_lm_model_for_fsdp(model: Union[PreTrainedModel, peft_model
     peft_type = model.peft_type.lower()
     active_adapters = [adapter.lower() for adapter in model.active_adapters]
     for name, module in model.named_modules():
-        if peft_type in name.lower() and any(adapter in name.lower() for adapter in active_adapters):
+        if peft_type in name.lower() and any(
+                adapter in name.lower() for adapter in active_adapters):
             has_parameters = any(True for _ in module.parameters())
             has_buffers = any(True for _ in module.buffers())
             if has_parameters or has_buffers:
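For context on the adapter-matching loop above: with a LoRA-wrapped model, peft names its adapter submodules after the peft type and the adapter name, so only those modules end up flagged for FSDP wrapping. The module names in this standalone sketch are typical of peft's LoRA wrapping but are assumptions, not taken from this diff.

    # Standalone sketch of the matching condition used in the hunk above.
    peft_type = 'lora'
    active_adapters = ['default']

    example_names = [
        'base_model.model.model.layers.0.self_attn.q_proj.lora_A.default',  # assumed LoRA adapter module
        'base_model.model.model.layers.0.self_attn.q_proj.base_layer',      # assumed frozen base layer
    ]

    for name in example_names:
        matches = peft_type in name.lower() and any(
            adapter in name.lower() for adapter in active_adapters)
        print(f'{name} -> {matches}')
    # Only the first name matches, i.e. only active-adapter LoRA modules that
    # own parameters or buffers would be considered for wrapping.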
8 changes: 5 additions & 3 deletions llmfoundry/models/hf/model_wrapper.py
@@ -5,9 +5,8 @@

 from __future__ import annotations

-import inspect
 from collections import UserDict
-from typing import List, Mapping, Optional
+from typing import List, Mapping, Optional, TYPE_CHECKING

 import torch
 import transformers
@@ -18,6 +17,9 @@

 from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp

+if TYPE_CHECKING:
+    from peft import PeftConfig
+
 # HuggingFace hardcodes the ignore index to -100
 _HF_IGNORE_INDEX = -100

@@ -46,7 +48,7 @@ def __init__(self,
                  z_loss: float = 0.0,
                  shift_labels: bool = False,
                  init_device: Optional[str] = None,
-                 peft_config = None):
+                 peft_config: Optional['PeftConfig'] = None):
         super().__init__(model,
                          tokenizer,
                          use_logits=True,
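The if TYPE_CHECKING guard used here (and in hf_fsdp.py above) keeps peft an optional runtime dependency while still letting annotations name PeftConfig. A minimal standalone sketch of the pattern, with a hypothetical function name:

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # Only evaluated by static type checkers, never at runtime, so the
        # module imports cleanly even when peft is not installed.
        from peft import PeftConfig


    def build_wrapper(peft_config: Optional['PeftConfig'] = None) -> None:
        # The quoted annotation is resolved lazily (or via
        # `from __future__ import annotations`, which model_wrapper.py already uses).
        print('peft_config provided:', peft_config is not None)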
12 changes: 9 additions & 3 deletions tests/models/hf/test_hf_peft_wrapping.py
@@ -1,10 +1,16 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0

 import transformers
-from peft import get_peft_model, LoraConfig
+from peft import LoraConfig, get_peft_model

 from llmfoundry.models.hf.hf_fsdp import prepare_hf_model_for_fsdp


 def test_peft_wraps():
-    mistral_cfg = transformers.AutoConfig.from_pretrained('mistralai/Mistral-7B-v0.1', num_hidden_layers=2)
+    mistral_cfg = transformers.AutoConfig.from_pretrained(
+        'mistralai/Mistral-7B-v0.1', num_hidden_layers=2)
     mistral = transformers.AutoModelForCausalLM.from_config(mistral_cfg)
     mistral = get_peft_model(mistral, LoraConfig())
     prepare_hf_model_for_fsdp(mistral, 'cpu')
-    assert False
+    assert False
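The trailing assert False is a WIP placeholder. One way the test might eventually check the wrapping behavior, assuming the truncated hf_fsdp hunk above marks active-adapter LoRA modules with composer's _fsdp_wrap attribute (an assumption; the attribute name is not shown in this diff):

    # Hypothetical completion of the placeholder assertion (sketch only).
    def _check_lora_modules_marked(model) -> None:
        for name, module in model.named_modules():
            if 'lora' in name and 'default' in name:
                has_parameters = next(module.parameters(), None) is not None
                has_buffers = next(module.buffers(), None) is not None
                if has_parameters or has_buffers:
                    assert getattr(module, '_fsdp_wrap', False)  # assumed attribute name

    # e.g. inside test_peft_wraps, after prepare_hf_model_for_fsdp(mistral, 'cpu'):
    #     _check_lora_modules_marked(mistral)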
