diff --git a/composer/models/huggingface.py b/composer/models/huggingface.py
index 2dcb7402e8..5fddf42510 100644
--- a/composer/models/huggingface.py
+++ b/composer/models/huggingface.py
@@ -501,10 +501,7 @@ def eval_forward(self, batch, outputs: Optional[Any] = None):
 
             # HF encoder decoder models like T5 expect either decoder_input_ids or labels,
             # so we add decoder_input_ids to the batch if it is missing
-            from transformers import PretrainedConfig
-            assert isinstance(self.model.config, PretrainedConfig)
-            model_config: PretrainedConfig = self.model.config
-            if model_config.is_encoder_decoder and 'decoder_input_ids' not in batch:
+            if self.config.is_encoder_decoder and 'decoder_input_ids' not in batch:
                 if hasattr(self.model, 'prepare_decoder_input_ids_from_labels'):
                     batch['decoder_input_ids'] = self.model.prepare_decoder_input_ids_from_labels(labels=self.labels)
                 else:
@@ -561,9 +558,7 @@ def get_metadata(self):
             model_dir = tmp_dir / 'model'
             tokenizer_dir = tmp_dir / 'tokenizer'
 
-            from transformers import PretrainedConfig
-            assert isinstance(self.model.config, PretrainedConfig)
-            original_model_config: PretrainedConfig = self.model.config
+            original_model_config: PretrainedConfig = self.config
             original_model_config.save_pretrained(model_dir)
             if self.tokenizer is not None:
                 self.tokenizer.save_pretrained(tokenizer_dir)
@@ -645,10 +640,7 @@ def generate(self, input_ids: torch.Tensor, **kwargs):
         if not using_torch_2() and not self.dummy_forward_called and is_model_fsdp(self.model):
             with torch.no_grad():
                 maybe_decoder_input_ids = {}
-                from transformers import PretrainedConfig
-                assert isinstance(self.model.config, PretrainedConfig)
-                model_config: PretrainedConfig = self.model.config
-                if model_config.is_encoder_decoder:
+                if self.config.is_encoder_decoder:
                     maybe_decoder_input_ids['decoder_input_ids'] = torch.tensor([[0]],
                                                                                  dtype=torch.long,
                                                                                  device=input_ids.device)