Add TimmWrapper #34564

Merged (92 commits) on Dec 11, 2024

Conversation

@qubvel (Member) commented Nov 1, 2024

What does this PR do?

Adds a TimmWrapper set of classes so that timm models can be loaded and used as transformers models in the library.

Continuation of

General Usage

import torch
from urllib.request import urlopen
from PIL import Image
from transformers import AutoConfig, AutoModelForImageClassification, AutoImageProcessor

checkpoint = "timm/resnet50.a1_in1k"
img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

image_processor = AutoImageProcessor.from_pretrained(checkpoint)
inputs = image_processor(img, return_tensors="pt")
model = AutoModelForImageClassification.from_pretrained(checkpoint)

with torch.no_grad():
    logits = model(**inputs).logits

top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)

Pipeline

Timm models can now be used in the image-classification pipeline (for classification models) and the image-feature-extraction pipeline; a feature-extraction sketch follows the classification example below.

import torch
from urllib.request import urlopen
from PIL import Image

from transformers import pipeline

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))
pipe = pipeline("image-classification", model="timm/resnet18.a1_in1k")
print(pipe(img))
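
The image-feature-extraction counterpart is analogous; a minimal sketch, reusing img from above (the return_tensors behaviour and output shape depend on the model and pipeline version):

from transformers import pipeline

# Extract pooled features instead of class probabilities from the same timm backbone
feature_extractor = pipeline("image-feature-extraction", model="timm/resnet18.a1_in1k")
features = feature_extractor(img, return_tensors=True)  # a torch.Tensor of features
print(features.shape)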

Trainer

Timm models can now be loaded and trained with the Trainer class.

Example model trained with the Trainer using the command below:
https://huggingface.co/qubvel-hf/vit-base-beans

python run_image_classification.py \
    --dataset_name beans \
    --output_dir ./beans_outputs/ \
    --remove_unused_columns False \
    --label_column_name labels \
    --do_train \
    --do_eval \
    --push_to_hub \
    --push_to_hub_model_id vit-base-beans \
    --learning_rate 2e-5 \
    --num_train_epochs 5 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --logging_strategy steps \
    --logging_steps 10 \
    --eval_strategy epoch \
    --save_strategy epoch \
    --load_best_model_at_end True \
    --save_total_limit 3 \
    --seed 1337 \
    --model_name_or_path timm/resnet18.a1_in1k \
    --ignore_mismatched_sizes

Other features enabled

  • Device map:
model = TimmWrapperForImageClassification.from_pretrained(checkpoint, device_map="auto")
  • Torch dtype:
model = TimmWrapperForImageClassification.from_pretrained(checkpoint, torch_dtype="bfloat16")
  • Quantization:
model = TimmWrapperForImageClassification.from_pretrained(checkpoint, load_in_4bit=True)
  • Intermediate hidden states: output_hidden_states=True or output_hidden_states=[1, 2, 3] (to select specific hidden states); see the sketch after this list
model = TimmWrapperForImageClassification.from_pretrained(checkpoint)
output = model(**inputs, output_hidden_states=True)
  • Transformers TimmWrapper checkpoints are compatible with timm:
model = timm.create_model("hf-hub:qubvel-hf/vit-base-beans", pretrained=True)
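
For illustration, a minimal sketch of inspecting the intermediate hidden states (reusing checkpoint and inputs from the usage example above; exact shapes depend on the backbone):

import torch
from transformers import TimmWrapperForImageClassification

model = TimmWrapperForImageClassification.from_pretrained(checkpoint)

with torch.no_grad():
    # Request all intermediate feature maps; pass a list of indices to select specific ones
    output = model(**inputs, output_hidden_states=True)

print(len(output.hidden_states))       # number of returned feature maps
print(output.hidden_states[-1].shape)  # e.g. (batch, channels, height, width) for a CNN backbone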

TODO

  • Gamma/beta renaming issue
  • Update timm in CI from 0.9.6 to 1.0.11 to enable output_hidden_states tests
    • CI images for the slow run take longer to update
  • Weights are loaded by transformers instead of timm; which architectures are affected?
  • Tests for image processor

@qubvel marked this pull request as draft on November 1, 2024 15:52
@qubvel marked this pull request as ready for review on December 2, 2024 10:58

@LysandreJik (Member) left a comment:

This is starting to look nice!

>>> # Load model and image processor
>>> checkpoint = "timm/resnet50.a1_in1k"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()

Member:
Really nice that there are no kwargs or anything special needed to load the model

@@ -502,7 +502,7 @@ def load_state_dict(
        # Check format of the archive
        with safe_open(checkpoint_file, framework="pt") as f:
            metadata = f.metadata()
        if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:

Member:
Ok with this

Comment on lines -636 to -684
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
renamed_keys = {}
renamed_gamma = {}
renamed_beta = {}
warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` "
for key in state_dict.keys():
    new_key = None
    if "gamma" in key:
        # We add only the first key as an example
        new_key = key.replace("gamma", "weight")
        renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
    if "beta" in key:
        # We add only the first key as an example
        new_key = key.replace("beta", "bias")
        renamed_beta[key] = new_key if not renamed_beta else renamed_beta
    if new_key:
        old_keys.append(key)
        new_keys.append(new_key)
renamed_keys = {**renamed_gamma, **renamed_beta}
if renamed_keys:
    warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
    for old_key, new_key in renamed_keys.items():
        warning_msg += f"* `{old_key}` -> `{new_key}`\n"
    warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
    logger.info_once(warning_msg)
for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key] = state_dict.pop(old_key)

Member:
This should eventually be completely removed, cc @ArthurZucker

Comment on lines 3910 to 3911
if metadata is None:
    pass

Member:
Can you please add a comment saying that in case of no metadata, it's seen as a pytorch checkpoint

Member Author:
Added 7e0d2c6
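
For reference, a sketch of such a clarifying comment (hypothetical wording, not necessarily what landed in 7e0d2c6):

if metadata is None:
    # In case of no metadata, the checkpoint is treated as a regular PyTorch ("pt") checkpoint
    pass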

Comment on lines 417 to 420
        default_image_processor_filename = (
            "config.json" if is_timm_checkpoint(pretrained_model_name_or_path) else IMAGE_PROCESSOR_NAME
        )
        kwargs["image_processor_filename"] = kwargs.get("image_processor_filename", default_image_processor_filename)

Member:
You can use CONFIG_NAME instead

@@ -867,3 +869,48 @@ class LossKwargs(TypedDict, total=False):
"""

num_items_in_batch: Optional[int]


def is_timm_hub_checkpoint(pretrained_model_name_or_path: str) -> bool:

Member:
Given the method's objective, I would have it accept only a pretrained_model_name

Member Author:
function removed with ff6efde (see comment below)

Comment on lines 879 to 881
    if os.path.isfile(pretrained_model_name_or_path) or os.path.isdir(pretrained_model_name_or_path):
        return False

Member:
and I'd therefore remove this

Member Author (@qubvel, Dec 2, 2024):
function removed with ff6efde (see comment below)

Comment on lines 882 to 884
    return pretrained_model_name_or_path.startswith("hf-hub:timm/") or pretrained_model_name_or_path.startswith(
        "timm/"
    )

Member:
Good as long as we don't expect community checkpoints, but community checkpoints already exist; see for example: https://huggingface.co/prov-gigapath/prov-gigapath

I think we'll need a more robust check here

Member Author (@qubvel, Dec 2, 2024):

Indeed, this is a weak assumption. This function is only needed in the image processor auto class to load the processor from the config. In AutoImageProcessor.from_pretrained the only information we have is the model name, so the only way to make a robust check is to load files from the Hub.

I removed this function entirely and instead added a fallback for loading the timm image processor dict in the image processor auto class. To avoid loading the config for every model, it works as follows:

  1. Try to load the image processor config as usual - most of the models will be fine, and we won't have any overhead here.
  2. In case of an exception, try loading config.json and check if it's a timm checkpoint.

See ff6efde for details. I documented it in the code, let me know if you have doubts about this approach.

Works with

image_processor = AutoImageProcessor.from_pretrained("prov-gigapath/prov-gigapath")

Member:
This looks better to me!

    def test_model_is_small(self):
        pass

    # Overriding as output_attentions is not supported by TimmWrapper

Member:
This should be removed no?

Member Author:
Thanks, removed in a476610


@require_torch
@require_vision
class TimmWrapperModelIntegrationTest(unittest.TestCase):

Member:
Should require timm as well

Member Author:
Added in 327095a

@LysandreJik (Member):
Thanks, this looks good! cc @molbap can you give the processor code a quick look just to double check?

@molbap (Contributor) left a comment:

Hey, took a quick look at the processor and ran it, found some stuff which I commented on! Also looked at the whole PR, real nice work!

@@ -295,6 +295,7 @@ def get_image_processor_dict(
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", "")
        image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME)

Contributor:
I understand - since it's a new kwarg that does not seem to have an equivalent in Hub methods (like use_auth_token or revision), I'd add a small docstring to advertise it
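
A sketch of what such a docstring entry might say (hypothetical wording):

image_processor_filename (`str`, *optional*, defaults to `"preprocessor_config.json"`):
    The name of the file in the model repository that holds the image processor configuration.
    TimmWrapper checkpoints keep this configuration in `config.json` instead of `preprocessor_config.json`.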

Comment on lines +428 to +444
        try:
            # Main path for all transformers models and local TimmWrapper checkpoints
            config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
                pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
            )
        except Exception as initial_exception:
            # Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json`
            # instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information
            # except the model name, the only way to check if a remote checkpoint is a timm model is to try to
            # load `config.json` and if it fails with some error, we raise the initial exception.
            try:
                config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
                    pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs
                )
            except Exception:
                raise initial_exception

Contributor:
A remark here - that's true for any processor with a preprocessor_config.json, but would it make sense to sanitize the inputs a bit? On community checkpoints, there are extra keys that are unused, for instance model_args, which contains duplicated information.
(I'm asking because we're already using a try/catch pattern here so that allows some branching)

Member Author:
Not sure I got it, can you provide more details, please?

Contributor:
Basically, if I load https://huggingface.co/prov-gigapath/prov-gigapath/blob/main/config.json, I'll end up with an ImageProcessor object that has keys I can't make much use of in transformers, such as model_args, so I wondered if it made sense to filter the contents of that config.json to an expected schema.

Member Author:
Ok, it makes sense now. I made it on the image-processor init level, see fd7b646

Is it what you had in mind?

>>> from transformers import AutoImageProcessor, AutoConfig

>>> image_processor = AutoImageProcessor.from_pretrained("prov-gigapath/prov-gigapath")
>>> print(image_processor)

TimmWrapperImageProcessor {
  "architecture": "vit_giant_patch14_dinov2",
  "data_config": {
    "crop_mode": "center",
    "crop_pct": 1.0,
    "input_size": [
      3,
      224,
      224
    ],
    "interpolation": "bicubic",
    "mean": [
      0.485,
      0.456,
      0.406
    ],
    "std": [
      0.229,
      0.224,
      0.225
    ]
  },
  "image_processor_type": "TimmWrapperImageProcessor"
}
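
As an aside, a minimal sketch of the kind of filtering this implies (hypothetical helper, not the exact code in fd7b646):

def _filter_timm_image_processor_dict(config_dict: dict) -> dict:
    # Keep only the keys TimmWrapperImageProcessor understands; extra entries from a
    # community checkpoint's config.json (e.g. `model_args`) are dropped
    allowed_keys = {"architecture", "data_config", "image_processor_type"}
    return {key: value for key, value in config_dict.items() if key in allowed_keys}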

Contributor:
yes, exactly! that looks good

Comment on lines +96 to +99
        if isinstance(images, torch.Tensor):
            images = self.val_transforms(images)
            # Add batch dimension if a single image
            images = images.unsqueeze(0) if images.ndim == 3 else images

Contributor:
Here, if val_transforms is, for instance:

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)

Then the ToTensor op will fail, since F.to_tensor does expect a PIL image IIRC
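
A minimal repro of that failure mode (assuming torchvision's plain ToTensor in the stack):

import torch
from torchvision.transforms import ToTensor

# ToTensor expects a PIL Image or ndarray, so a torch.Tensor input raises:
# TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>
ToTensor()(torch.rand(3, 224, 224))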

Member Author:
Nice catch! Since timm>=1.0.8 it's MaybeToTensor and works fine, but on timm<1.0.8 it indeed raises:

TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>

I can add something like

from packaging import version  # assuming `timm` is already imported in this module

if version.parse(timm.__version__) < version.parse("1.0.8") and isinstance(images, torch.Tensor):
    images = images.cpu().numpy()

Member Author:
Added in a10bc0d

Contributor:
perfect!

Contributor (@rwightman):
@qubvel version checks tend to be brittle; can you instead check whether the MaybeToTensor class exists? I think that should line up with its use in the transforms.

Member Author:
@rwightman, thanks for the note. I added a check for the class name; it looks not that elegant, but I hope it's more robust: 3d1a76e
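
For illustration, a sketch of one way such a check could look (hypothetical helper; the actual check is in 3d1a76e):

def transforms_accept_tensor_input(transforms) -> bool:
    # Older timm builds compose torchvision's ToTensor, which rejects torch.Tensor inputs;
    # newer builds use MaybeToTensor, which passes tensors through unchanged
    transform_names = {type(t).__name__ for t in getattr(transforms, "transforms", [])}
    return "ToTensor" not in transform_names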

@qubvel (Member Author) commented Dec 4, 2024

@LysandreJik @molbap @rwightman Thanks for the reviews! I believe all comments have been addressed. Do you have anything else in mind? It would be nice to move it forward.

@LysandreJik (Member):
Ok awesome! At this point just a second quick look from @ArthurZucker and we're good

@ArthurZucker (Collaborator) left a comment:

Super nice! This is kind of a perfect integration with the Auto API, congrats!

Comment on lines 341 to 345
    normalize = (
        Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
        if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
        else Lambda(lambda x: x)
    )

Collaborator:
let's do something explicit for this!

Member Author:
I'm not sure what you meant here, but I made it a bit more readable (IMO) in 4edfe90. It's actually not related to the timm wrapper; it's the same in the original code.
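
For reference, a sketch of a more explicit formulation (assumed; the actual change is whatever landed in 4edfe90):

from torchvision.transforms import Lambda, Normalize

if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std"):
    normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
else:
    # No-op when the processor does not define normalization statistics
    normalize = Lambda(lambda x: x)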

Comment on lines -636 to -684
(Same removed gamma/beta renaming block as quoted in the earlier comment above.)

Collaborator:
yep there is #33192 kinda related

src/transformers/models/auto/configuration_auto.py (outdated comment, resolved)
@qubvel merged commit 5fcf628 into huggingface:main on Dec 11, 2024
27 checks passed