
Add timm_wrapper support to AutoFeatureExtractor #35764

Open · wants to merge 3 commits into main
Conversation

@Factral commented on Jan 18, 2025

What does this PR do?

A few days ago, the PR that adds timm_wrapper (#34564, see the accompanying blog post) was merged, enabling the use of timm models directly with Hugging Face interfaces, especially the Auto* ones. However, the AutoFeatureExtractor interface currently doesn't work with these models. This PR addresses that gap.

This PR adds timm_wrapper compatibility to AutoFeatureExtractor.from_pretrained(), enabling it to work with fine-tuned/trained timm model checkpoints.

Currently, when using a checkpoint from a trained/fine-tuned timm model (e.g., using examples/pytorch/image-classification/run_image_classification.py), AutoFeatureExtractor.from_pretrained() fails because timm_wrapper is not included in the interface.

Such checkpoints currently emit a warning about a missing preprocessor_config.json, but users can add the file manually, following examples like https://huggingface.co/Factral/vit_large-model/blob/main/preprocessor_config.json. This PR ensures AutoFeatureExtractor works properly when that file is present.
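For reference, here is a minimal sketch of the usage this PR is meant to enable, assuming a fine-tuned timm checkpoint that ships a preprocessor_config.json (the checkpoint name below is a placeholder, not a real repo):

from PIL import Image
from transformers import AutoFeatureExtractor

# placeholder: a fine-tuned timm checkpoint that includes preprocessor_config.json
checkpoint = "your-username/your-finetuned-timm-model"

feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)

# load your image here
image = Image.new("RGB", (224, 224), (255, 0, 0))

inputs = feature_extractor(image, return_tensors="pt")

for k, v in inputs.items():
    print(k, v.shape)

# typically: pixel_values torch.Size([1, 3, 224, 224])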

Changes

  • Added timm_wrapper to the AutoFeatureExtractor interface
  • Enables compatibility with timm model checkpoints when preprocessor_config.json is present
  • Added an is_timm kwarg to the from_dict function

Before submitting

  • Read contributor guidelines
  • Updated documentation to reflect changes
  • Added necessary tests for timm_wrapper functionality

Who can review?

@amyeroberts @qubvel, as this relates to vision models and the timm integration

@qubvel (Member) left a comment

Hi @Factral, thanks for submitting the PR!

I would recommend using AutoModel + AutoProcessor to get features from any timm model. It works without adding a preprocessor_config.json. Otherwise, we need to come up with a scheme to load the FeatureExtractor the same way, without adding preprocessor_config.json to the repo on the Hub, because the preprocessing config is stored in config.json for timm models (please see how this was enabled for the other Auto* processors in the original PR #34564).

@qubvel commented on Jan 20, 2025

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel

checkpoint = "timm/resnet18.a1_in1k"

model = AutoModel.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)

# load your image here
image = Image.new("RGB", (224, 224), (255, 0, 0))

inputs = processor(image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

for k, v in outputs.items():
    print(k, v.shape)

# last_hidden_state torch.Size([1, 512, 7, 7])
# pooler_output torch.Size([1, 512])

@qubvel added the Vision label on Jan 21, 2025