Skip to content

Commit

Permalink
Set dtype and device for pixel_values in forward
Browse files Browse the repository at this point in the history
  • Loading branch information
qubvel committed Nov 4, 2024
1 parent ec9eade commit 9767048
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def forward(
if output_hidden_states is not None or output_attentions is not None:
raise ValueError("Cannot set output_attentions or output_hidden_states for timm models")

pixel_values = pixel_values.to(self.device, self.dtype)
features = self.timm_model(pixel_values, **kwargs)

if not return_dict:
Expand Down Expand Up @@ -152,6 +153,7 @@ def forward(
if output_hidden_states is not None or output_attentions is not None:
raise ValueError("Cannot set `output_attentions` or `output_hidden_states` for timm models")

pixel_values = pixel_values.to(self.device, self.dtype)
logits = self.timm_model(pixel_values, **kwargs)

loss = None
Expand Down

0 comments on commit 9767048

Please sign in to comment.