Commit

simplify attempt
dakinggg committed Jan 23, 2024
1 parent c26287d commit a0e217f
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions composer/models/huggingface.py
@@ -26,9 +26,9 @@

 try:
     from peft import PeftModel, get_peft_model
-    _peft_installed = True
+    peft_installed = True
 except:
-    _peft_installed = False
+    peft_installed = False

 if TYPE_CHECKING:
     import transformers
@@ -56,7 +56,9 @@ class HuggingFaceModel(ComposerModel):
         eval_metrics (list[Metric], optional): list of torchmetrics to compute on the eval_dataloader, or be accessible to :class:`Evaluator`s. Default: ``None``.
         shift_labels (bool, optional): If True, the batch's labels will be shifted before being used to calculate metrics. This should be set to true for CausalLM models and false otherwise. If not specified, `shift_labels` will be set automatically based on the model class name. Default: ``None``.
         allow_embedding_resizing (bool, optional): If True, the model's embeddings will be automatically resized when they are smaller than the tokenizer vocab size. Default: ``False``.
+        peft_config (PeftConfig, optional): Optional PEFT config to apply to the model. If provided, the model will be converted to a PEFT model. Only LoRA is currently supported.
+        peft_filter_state_dict_trainable (bool, optional): If True _and_ PEFT is active, the state dict will only contain the PEFT weights, not the frozen base model weights.

     .. note:: To ensure correct behavior, set `shift_labels` manually if using a custom model (i.e., if `model` is not
         an instance of a registered 🤗 Transformers class).

     .. warning:: This wrapper is designed to work with 🤗 datasets that define a `labels` column.
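For context (not part of the diff), a minimal sketch of how the two arguments added above could be used when wrapping a model. The model/tokenizer name and the LoRA hyperparameters are illustrative assumptions, not values from this commit.

import transformers
from peft import LoraConfig

from composer.models.huggingface import HuggingFaceModel

# Illustrative base model and tokenizer; any causal LM supported by PEFT/LoRA would do.
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')

# Only LoRA is currently supported, per the docstring above.
peft_config = LoraConfig(task_type='CAUSAL_LM', r=8, lora_alpha=16, lora_dropout=0.05)

composer_model = HuggingFaceModel(
    model,
    tokenizer=tokenizer,
    peft_config=peft_config,  # converts the wrapped model into a PEFT (LoRA) model
    peft_filter_state_dict_trainable=True,  # keep only the PEFT weights in the state dict
)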
@@ -95,16 +97,12 @@ def __init__(self,
         super().__init__()
         self.model = model
         self.config: PretrainedConfig = model.config
-        self.model_forward_args = inspect.getfullargspec(self.model.forward).args
-
-        if _peft_installed and self.model_forward_args == ['self']:
+        self.model_forward_args = inspect.signature(self.model.forward).parameters.keys()
+
+        if peft_installed:
             from peft import PeftModel
             if isinstance(self.model, PeftModel):
-                self.model_forward_args = inspect.getfullargspec(self.model.base_model.model.forward).args
-
-        # inspect.getfullargspec HuggingFace quantized model could not return args correctly
-        if not self.model_forward_args:
-            self.model_forward_args = inspect.signature(self.model.forward).parameters.keys()
+                self.model_forward_args = inspect.signature(self.model.base_model.model.forward).parameters.keys()

         if not self.model_forward_args:
             raise ValueError('Could not determine the forward arguments of the model. Please open a GitHub issue.')
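As a side note on the introspection swap in this hunk, a standalone sketch (not code from this file): inspect.signature follows functools.wraps-style __wrapped__ attributes while inspect.getfullargspec ignores them, which is one way a wrapped forward can hide its real arguments.

import functools
import inspect

class Toy:
    def forward(self, input_ids, attention_mask=None):
        return input_ids

def wrap(fn):
    # Mimics the kind of wrapping that quantized/PEFT models apply to forward.
    @functools.wraps(fn)
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)
    return inner

toy = Toy()
toy.forward = wrap(toy.forward)

print(inspect.getfullargspec(toy.forward).args)         # [] -- real args hidden behind *args
print(list(inspect.signature(toy.forward).parameters))  # ['input_ids', 'attention_mask']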
@@ -113,7 +111,7 @@ def __init__(self,

         self.peft_filter_state_dict_trainable = peft_filter_state_dict_trainable
         if peft_config is not None:
-            if not _peft_installed:
+            if not peft_installed:
                 raise MissingConditionalImportError(extra_deps_group='peft',
                                                     conda_package='peft',
                                                     conda_channel='conda-forge')
@@ -193,7 +191,7 @@ def __init__(self,
             log.info(f'PEFT model created. {self.model}')

         self.model_is_peft = False
-        if _peft_installed:
+        if peft_installed:
             from peft import PeftModel
             self.using_peft = isinstance(self.model, PeftModel)

@@ -576,7 +574,7 @@ def get_metadata(self):
         }

         # Also save PEFT config if the model is a peft model
-        if _peft_installed:
+        if peft_installed:
             from peft import PeftModel
             if isinstance(self.model, PeftModel):
                 active_adapter = self.model.active_adapter
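The rest of this hunk is collapsed in the view above. As a hedged sketch only (the composer_model variable is assumed, and the exact fields written by get_metadata are not shown here), the active adapter's PEFT config can be pulled off a PeftModel roughly like this, using peft's own attributes:

from peft import PeftModel

# `composer_model` is assumed to be a HuggingFaceModel instance as in the earlier example.
if isinstance(composer_model.model, PeftModel):
    active_adapter = composer_model.model.active_adapter
    peft_config_dict = composer_model.model.peft_config[active_adapter].to_dict()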
@@ -678,7 +676,7 @@ def _is_registered_causal_lm(model: Union[transformers.PreTrainedModel, 'PeftMod
                                               conda_package='transformers',
                                               conda_channel='conda-forge') from e

-    if _peft_installed and isinstance(model, PeftModel):
+    if peft_installed and isinstance(model, PeftModel):
         model_to_check = model.base_model.model
     else:
         model_to_check = model
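A small standalone sketch of the unwrapping pattern this hunk relies on (the model name is an assumption): for a PeftModel, base_model.model is the underlying 🤗 Transformers model, which is what gets checked against the registered causal LM classes.

import transformers
from peft import LoraConfig, PeftModel, get_peft_model

base = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
peft_model = get_peft_model(base, LoraConfig(task_type='CAUSAL_LM'))

# Unwrap the PEFT wrapper before the isinstance check, as in the diff above.
model_to_check = peft_model.base_model.model if isinstance(peft_model, PeftModel) else peft_model
print(type(model_to_check).__name__)  # GPT2LMHeadModel, not the PEFT wrapper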
