Skip to content

Commit

Permalink
refactor underlying model get
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Jan 26, 2024
1 parent 320ff55 commit 67e3cb2
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions composer/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,8 @@ def __init__(self,
self.model = model
self.config: PretrainedConfig = model.config

self.model_forward_args = inspect.signature(self.model.forward).parameters.keys()
if peft_installed:
from peft import PeftModel
if isinstance(self.model, PeftModel):
self.model_forward_args = inspect.signature(self.model.base_model.model.forward).parameters.keys()
model_for_forward = maybe_get_underlying_model(model)
self.model_forward_args = inspect.signature(model_for_forward.forward).parameters.keys()

if not self.model_forward_args:
raise ValueError('Could not determine the forward arguments of the model. Please open a GitHub issue.')
Expand Down Expand Up @@ -662,6 +659,22 @@ def generate(self, input_ids: torch.Tensor, **kwargs):
return self.model.generate(input_ids=input_ids, pad_token_id=pad_token_id, **kwargs)


def maybe_get_underlying_model(
model: Union[transformers.PreTrainedModel, 'PeftModel']) -> Union[transformers.PreTrainedModel, 'PeftModel']:
"""Get the underlying PreTrainedModel from a model if it is a PEFT model
Args:
model (Union[transformers.PreTrainedModel, 'PeftModel']): The model to get the underlying model from
Returns:
Union[transformers.PreTrainedModel]: The underlying transformers model
"""
if peft_installed and isinstance(model, PeftModel):
return model.base_model.model
else:
return model


def _is_registered_causal_lm(model: Union[transformers.PreTrainedModel, 'PeftModel']) -> bool:
"""Return True if model class is either a registered 🤗 Causal LM or a subclass of one"""
try:
Expand All @@ -671,10 +684,7 @@ def _is_registered_causal_lm(model: Union[transformers.PreTrainedModel, 'PeftMod
conda_package='transformers',
conda_channel='conda-forge') from e

if peft_installed and isinstance(model, PeftModel):
model_to_check = model.base_model.model
else:
model_to_check = model
model_to_check = maybe_get_underlying_model(model)

# This try/except is needed until https://github.com/huggingface/transformers/issues/26778
# is resolved in a release. This means that this attempt to automatically detect causal LMs
Expand Down

0 comments on commit 67e3cb2

Please sign in to comment.