Skip to content

Commit

Permalink
clean up config usage
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Jan 25, 2024
1 parent 176fdbe commit 320ff55
Showing 1 changed file with 3 additions and 11 deletions.
14 changes: 3 additions & 11 deletions composer/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,10 +501,7 @@ def eval_forward(self, batch, outputs: Optional[Any] = None):

# HF encoder decoder models like T5 expect either decoder_input_ids or labels,
# so we add decoder_input_ids to the batch if it is missing
- from transformers import PretrainedConfig
- assert isinstance(self.model.config, PretrainedConfig)
- model_config: PretrainedConfig = self.model.config
- if model_config.is_encoder_decoder and 'decoder_input_ids' not in batch:
+ if self.config.is_encoder_decoder and 'decoder_input_ids' not in batch:
if hasattr(self.model, 'prepare_decoder_input_ids_from_labels'):
batch['decoder_input_ids'] = self.model.prepare_decoder_input_ids_from_labels(labels=self.labels)
else:
Expand Down Expand Up @@ -561,9 +558,7 @@ def get_metadata(self):
model_dir = tmp_dir / 'model'
tokenizer_dir = tmp_dir / 'tokenizer'

- from transformers import PretrainedConfig
- assert isinstance(self.model.config, PretrainedConfig)
- original_model_config: PretrainedConfig = self.model.config
+ original_model_config: PretrainedConfig = self.config
original_model_config.save_pretrained(model_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(tokenizer_dir)
Expand Down Expand Up @@ -645,10 +640,7 @@ def generate(self, input_ids: torch.Tensor, **kwargs):
if not using_torch_2() and not self.dummy_forward_called and is_model_fsdp(self.model):
with torch.no_grad():
maybe_decoder_input_ids = {}
- from transformers import PretrainedConfig
- assert isinstance(self.model.config, PretrainedConfig)
- model_config: PretrainedConfig = self.model.config
- if model_config.is_encoder_decoder:
+ if self.config.is_encoder_decoder:
maybe_decoder_input_ids['decoder_input_ids'] = torch.tensor([[0]],
dtype=torch.long,
device=input_ids.device)
Expand Down

0 comments on commit 320ff55

Please sign in to comment.