diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index 01917cafa22f4f..e3593af20b5ed3 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -112,6 +112,7 @@ def forward( if output_hidden_states is not None or output_attentions is not None: raise ValueError("Cannot set output_attentions or output_hidden_states for timm models") + pixel_values = pixel_values.to(self.device, self.dtype) features = self.timm_model(pixel_values, **kwargs) if not return_dict: @@ -152,6 +153,7 @@ def forward( if output_hidden_states is not None or output_attentions is not None: raise ValueError("Cannot set `output_attentions` or `output_hidden_states` for timm models") + pixel_values = pixel_values.to(self.device, self.dtype) logits = self.timm_model(pixel_values, **kwargs) loss = None