Add TimmWrapper #34564

Merged · 92 commits · Dec 11, 2024

Changes from 46 commits
Commits
39c42db
Add files
amyeroberts Sep 10, 2024
4b98375
Init
amyeroberts Sep 20, 2024
aa494f1
Add TimmWrapperModel
amyeroberts Sep 24, 2024
44d123e
Fix up
amyeroberts Sep 24, 2024
4b35ae2
Some fixes
amyeroberts Sep 27, 2024
2b5db8f
Fix up
amyeroberts Sep 27, 2024
b07a5c9
Remove old file
amyeroberts Sep 27, 2024
e3a88b6
Sort out import orders
amyeroberts Sep 27, 2024
baffbe2
Fix some model loading
amyeroberts Sep 27, 2024
6fc50cf
Compatible with pipeline and trainer
amyeroberts Sep 30, 2024
ed00b41
Fix up
amyeroberts Sep 30, 2024
50b507b
Delete test_timm_model_1/config.json
amyeroberts Sep 30, 2024
de87e54
Remove accidentally commited files
amyeroberts Sep 30, 2024
a907061
Delete src/transformers/models/modeling_timm_wrapper.py
amyeroberts Oct 1, 2024
1f55841
Remove empty imports; fix transformations applied
amyeroberts Oct 1, 2024
23b38af
Tidy up
amyeroberts Oct 1, 2024
2f0aee1
Add image classifcation model to special cases
amyeroberts Oct 1, 2024
bff2e98
Create pretrained model; enable device_map='auto'
amyeroberts Oct 1, 2024
0c80253
Enable most tests; fix init order
amyeroberts Oct 2, 2024
a32de6a
Sort imports
amyeroberts Oct 2, 2024
59f55d1
[run-slow] timm_wrapper
amyeroberts Oct 2, 2024
666419f
Pass num_classes into timm.create_model
qubvel Oct 11, 2024
5540d32
Remove train transforms from image processor
qubvel Nov 1, 2024
88f737c
Update timm creation with pretrained=False
qubvel Nov 1, 2024
496d38d
Fix gamma/beta issue for timm models
qubvel Nov 1, 2024
4cfa51f
Fixing gamma and beta renaming for timm models
qubvel Nov 1, 2024
3891975
Simplify config and model creation
qubvel Nov 1, 2024
f878ab4
Remove attn_implementation diff
qubvel Nov 1, 2024
b1f145f
Fixup
qubvel Nov 1, 2024
7d87a95
Docstrings
qubvel Nov 1, 2024
ae0425d
Fix warning msg text according to test case
qubvel Nov 1, 2024
ec9eade
Fix device_map auto
qubvel Nov 4, 2024
9767048
Set dtype and device for pixel_values in forward
qubvel Nov 4, 2024
d5478f6
Enable output hidden states
qubvel Nov 4, 2024
b61048b
Enable tests for hidden_states and model parallel
qubvel Nov 4, 2024
72236b5
Remove default scriptable arg
qubvel Nov 4, 2024
e290f22
Refactor inner model
qubvel Nov 4, 2024
0a08a1f
Update timm version
qubvel Nov 4, 2024
1faf9ec
Fix _find_mismatched_keys function
qubvel Nov 5, 2024
addbac8
Change inheritance for Classification model (fix weights loading with…
qubvel Nov 5, 2024
8ba4592
Minor bugfix
qubvel Nov 19, 2024
2b993a7
Disable save pretrained for image processor
qubvel Nov 19, 2024
0b00cef
Rename hook method for loaded keys correction
qubvel Nov 19, 2024
485fe7a
Rename state dict keys on save, remove `timm_model` prefix, make chec…
qubvel Nov 19, 2024
dede4e4
Managing num_labels <-> num_classes attributes
qubvel Nov 19, 2024
1ec132c
Enable loading checkpoints in Trainer to resume training
qubvel Nov 19, 2024
d96b257
Update error message for output_hidden_states
qubvel Nov 20, 2024
6b1b621
Add output hidden states test
qubvel Nov 20, 2024
bb7d465
Decouple base and classification models
qubvel Nov 20, 2024
4390213
Add more test cases
qubvel Nov 20, 2024
4233850
Add save-load-to-timm test
qubvel Nov 20, 2024
ef96b6d
Merge branch 'main' into timm-wrapper
qubvel Nov 20, 2024
c3bb39a
Fix test name
qubvel Nov 20, 2024
5444a6c
Fixup
qubvel Nov 20, 2024
06039d0
Add do_pooling
qubvel Nov 27, 2024
0b78735
Add test for do_pooling
qubvel Nov 27, 2024
5bb2950
Fix doc
qubvel Nov 27, 2024
38b9423
Add tests for TimmWrapperModel
qubvel Nov 27, 2024
c1cc1fa
Add validation for `num_classes=0` in timm config + test for DINO che…
qubvel Nov 27, 2024
3acafad
Adjust atol for test
qubvel Nov 27, 2024
8e8f41b
Fix docs
qubvel Nov 27, 2024
c5c288f
dev-ci
qubvel Nov 28, 2024
a4ae76f
dev-ci
qubvel Nov 28, 2024
464874f
Add tests for image processor
qubvel Nov 28, 2024
02fbd58
Update docs
qubvel Nov 28, 2024
eb9a66f
Update init to new format
qubvel Nov 28, 2024
e08fe70
Update docs in configuration
qubvel Nov 28, 2024
7734804
Fix some docs in image processor
qubvel Nov 28, 2024
051acee
Improve docs for modeling
qubvel Nov 28, 2024
185fab8
fix for is_timm_checkpoint
qubvel Nov 28, 2024
bca3279
Update code examples
qubvel Nov 28, 2024
672cc6d
Fix header
qubvel Nov 28, 2024
983b9b2
Fix typehint
qubvel Nov 28, 2024
9c128ce
Increase tolerance a bit
qubvel Nov 28, 2024
f2dba79
Merge branch 'main' into timm-wrapper
qubvel Nov 29, 2024
9bc887b
Fix Path
qubvel Nov 29, 2024
f92216f
Fixing model parallel tests
qubvel Dec 2, 2024
42278a7
Disable "parallel" tests
qubvel Dec 2, 2024
6b3ba3b
Add comment for metadata
qubvel Dec 2, 2024
ff6efde
Refactor AutoImageProcessor for timm wrapper loading
qubvel Dec 2, 2024
a476610
Remove custom test_model_outputs_equivalence
qubvel Dec 2, 2024
327095a
Add require_timm decorator
qubvel Dec 2, 2024
7e0d2c6
Fix comment
qubvel Dec 2, 2024
a10bc0d
Make image processor work with older timm versions and tensor input
qubvel Dec 3, 2024
a87fbf6
Save config instead of whole model in image processor tests
qubvel Dec 3, 2024
90c1c88
Add docstring for `image_processor_filename`
qubvel Dec 3, 2024
fd7b646
Sanitize kwargs for timm image processor
qubvel Dec 3, 2024
cdd3811
Fix doc style
qubvel Dec 3, 2024
3d1a76e
Update check for tensor input
qubvel Dec 3, 2024
cc0a330
Merge branch 'main' into timm-wrapper
qubvel Dec 5, 2024
4edfe90
Update normalize
qubvel Dec 11, 2024
8b82e2b
Remove _load_timm_model function
qubvel Dec 11, 2024
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -699,6 +699,8 @@
title: Swin2SR
- local: model_doc/table-transformer
title: Table Transformer
- local: model_doc/timm_wrapper
title: Timm Wrapper
- local: model_doc/upernet
title: UperNet
- local: model_doc/van
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -316,6 +316,7 @@ Flax), PyTorch, and/or TensorFlow.
| [TAPEX](model_doc/tapex) | ✅ | ✅ | ✅ |
| [Time Series Transformer](model_doc/time_series_transformer) | ✅ | ❌ | ❌ |
| [TimeSformer](model_doc/timesformer) | ✅ | ❌ | ❌ |
| [TimmWrapperModel](model_doc/timm_wrapper) | ✅ | ❌ | ❌ |
| [Trajectory Transformer](model_doc/trajectory_transformer) | ✅ | ❌ | ❌ |
| [Transformer-XL](model_doc/transfo-xl) | ✅ | ✅ | ❌ |
| [TrOCR](model_doc/trocr) | ✅ | ❌ | ❌ |
61 changes: 61 additions & 0 deletions docs/source/en/model_doc/timm_wrapper.md
@@ -0,0 +1,61 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# TimmWrapper

## Overview

Helper class that enables loading timm models so they can be used with the transformers library and its auto classes.

```python
import torch
from urllib.request import urlopen
from PIL import Image
from transformers import AutoConfig, AutoModelForImageClassification, AutoImageProcessor

checkpoint = "timm/resnet50.a1_in1k"
img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

image_processor = AutoImageProcessor.from_pretrained(checkpoint)
inputs = image_processor(img, return_tensors="pt")
model = AutoModelForImageClassification.from_pretrained(checkpoint)

with torch.no_grad():
    logits = model(**inputs).logits

top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)
```
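
Because the wrapper registers timm checkpoints with the auto classes, the same checkpoint should also work through the `pipeline` API (the PR explicitly targets pipeline and Trainer compatibility). A minimal sketch; the image path is hypothetical and the label strings depend on what the checkpoint's config provides:

```python
from transformers import pipeline

# Image-classification pipeline backed by a timm checkpoint via the wrapper.
classifier = pipeline("image-classification", model="timm/resnet50.a1_in1k")
predictions = classifier("path/to/image.png")  # hypothetical path; returns [{"label": ..., "score": ...}, ...]
```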

## TimmWrapperConfig

[[autodoc]] TimmWrapperConfig

## TimmWrapperImageProcessor

[[autodoc]] TimmWrapperImageProcessor
- preprocess

## TimmWrapperModel

[[autodoc]] TimmWrapperModel
- forward

## TimmWrapperForImageClassification

[[autodoc]] TimmWrapperForImageClassification
- forward
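
The PR also adds save/load round-trip tests (including saving checkpoints that can be reloaded by timm), so the standard `save_pretrained` / `from_pretrained` flow is expected to apply to wrapped models. A hedged sketch, with a hypothetical output directory:

```python
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained("timm/resnet50.a1_in1k")
model.save_pretrained("./resnet50-timm-wrapper")  # hypothetical local directory

# Reload through the auto class; config and weights round-trip through the wrapper.
reloaded = AutoModelForImageClassification.from_pretrained("./resnet50-timm-wrapper")
```
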
53 changes: 29 additions & 24 deletions examples/pytorch/image-classification/run_image_classification.py
@@ -42,6 +42,7 @@
AutoImageProcessor,
AutoModelForImageClassification,
HfArgumentParser,
TimmWrapperImageProcessor,
Trainer,
TrainingArguments,
set_seed,
@@ -329,31 +330,35 @@ def compute_metrics(p):
    )

    # Define torchvision transforms to be applied to each image.
-    if "shortest_edge" in image_processor.size:
-        size = image_processor.size["shortest_edge"]
+    if isinstance(image_processor, TimmWrapperImageProcessor):
+        _train_transforms = image_processor.train_transforms
+        _val_transforms = image_processor.val_transforms
    else:
-        size = (image_processor.size["height"], image_processor.size["width"])
-    normalize = (
-        Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
-        if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
-        else Lambda(lambda x: x)
-    )
-    _train_transforms = Compose(
-        [
-            RandomResizedCrop(size),
-            RandomHorizontalFlip(),
-            ToTensor(),
-            normalize,
-        ]
-    )
-    _val_transforms = Compose(
-        [
-            Resize(size),
-            CenterCrop(size),
-            ToTensor(),
-            normalize,
-        ]
-    )
+        if "shortest_edge" in image_processor.size:
+            size = image_processor.size["shortest_edge"]
+        else:
+            size = (image_processor.size["height"], image_processor.size["width"])
+        normalize = (
+            Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
+            if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
+            else Lambda(lambda x: x)
+        )
Review comment (Collaborator): let's do something explicit for this!

Reply (Member Author): I'm not sure what you meant here, but I made it a bit more readable in 4edfe90 IMO. It's actually not related to the timm wrapper; it's the same in the original code.

+        _train_transforms = Compose(
+            [
+                RandomResizedCrop(size),
+                RandomHorizontalFlip(),
+                ToTensor(),
+                normalize,
+            ]
+        )
+        _val_transforms = Compose(
+            [
+                Resize(size),
+                CenterCrop(size),
+                ToTensor(),
+                normalize,
+            ]
+        )

    def train_transforms(example_batch):
        """Apply _train_transforms across a batch."""
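
For context on the new branch above: `TimmWrapperImageProcessor` exposes the torchvision-style transforms resolved from the timm pretrained config, so the example script reuses them instead of rebuilding its own. A rough sketch of that usage, based only on the attributes visible in the diff (the exact transform contents depend on the checkpoint):

```python
from PIL import Image
from transformers import TimmWrapperImageProcessor

image_processor = TimmWrapperImageProcessor.from_pretrained("timm/resnet50.a1_in1k")

# Transforms resolved from the timm config, as used in the training script above.
train_tf = image_processor.train_transforms
val_tf = image_processor.val_transforms

pixel_values = val_tf(Image.new("RGB", (224, 224)))  # dummy image, just to show the call shape
```
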
2 changes: 1 addition & 1 deletion setup.py
@@ -179,7 +179,7 @@
"tf2onnx",
"timeout-decorator",
"tiktoken",
"timm<=0.9.16",
"timm<=1.0.11",
"tokenizers>=0.20,<0.21",
"torch",
"torchaudio",
30 changes: 30 additions & 0 deletions src/transformers/__init__.py
@@ -774,6 +774,7 @@
"models.time_series_transformer": ["TimeSeriesTransformerConfig"],
"models.timesformer": ["TimesformerConfig"],
"models.timm_backbone": ["TimmBackboneConfig"],
"models.timm_wrapper": ["TimmWrapperConfig"],
"models.trocr": [
"TrOCRConfig",
"TrOCRProcessor",
@@ -1259,6 +1260,18 @@
_import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
_import_structure["models.vit"].append("ViTImageProcessorFast")

try:
    if not is_torchvision_available() and not is_timm_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_timm_and_torchvision_objects

    _import_structure["utils.dummy_timm_and_torchvision_objects"] = [
        name for name in dir(dummy_timm_and_torchvision_objects) if not name.startswith("_")
    ]
else:
    _import_structure["models.timm_wrapper"].extend(["TimmWrapperImageProcessor"])

# PyTorch-backed objects
try:
    if not is_torch_available():
@@ -3494,6 +3507,9 @@
]
)
_import_structure["models.timm_backbone"].extend(["TimmBackbone"])
_import_structure["models.timm_wrapper"].extend(
["TimmWrapperForImageClassification", "TimmWrapperModel", "TimmWrapperPreTrainedModel"]
)
_import_structure["models.trocr"].extend(
[
"TrOCRForCausalLM",
@@ -5689,6 +5705,7 @@
TimesformerConfig,
)
from .models.timm_backbone import TimmBackboneConfig
from .models.timm_wrapper import TimmWrapperConfig
from .models.trocr import (
TrOCRConfig,
TrOCRProcessor,
@@ -6180,6 +6197,14 @@
from .image_processing_utils_fast import BaseImageProcessorFast
from .models.vit import ViTImageProcessorFast

    try:
        if not is_torchvision_available() and not is_timm_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from .utils.dummy_timm_and_torchvision_objects import *
    else:
        from .models.timm_wrapper import TimmWrapperImageProcessor

    # Modeling
    try:
        if not is_torch_available():
@@ -7971,6 +7996,11 @@
TimesformerPreTrainedModel,
)
from .models.timm_backbone import TimmBackbone
from .models.timm_wrapper import (
TimmWrapperForImageClassification,
TimmWrapperModel,
TimmWrapperPreTrainedModel,
)
from .models.trocr import (
TrOCRForCausalLM,
TrOCRPreTrainedModel,
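
The `TimmWrapperImageProcessor` import above is guarded on both timm and torchvision being installed, with dummy objects registered otherwise. A small sketch of how downstream code might apply the same guard, using the availability helpers referenced in the diff:

```python
from transformers.utils import is_timm_available, is_torchvision_available

if is_timm_available() and is_torchvision_available():
    from transformers import TimmWrapperImageProcessor

    processor = TimmWrapperImageProcessor.from_pretrained("timm/resnet50.a1_in1k")
else:
    raise RuntimeError("TimmWrapperImageProcessor requires both `timm` and `torchvision` to be installed.")
```
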
7 changes: 6 additions & 1 deletion src/transformers/configuration_utils.py
@@ -37,6 +37,7 @@
download_url,
extract_commit_hash,
is_remote_url,
is_timm_checkpoint,
is_torch_available,
logging,
)
@@ -548,7 +549,6 @@ def from_pretrained(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)

return cls.from_dict(config_dict, **kwargs)

@classmethod
@@ -686,6 +686,11 @@ def _get_config_dict(
config_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
config_dict["custom_pipelines"], pretrained_model_name_or_path
)

if "model_type" not in config_dict and is_timm_checkpoint(resolved_config_file):
# timm models are not saved with the model_type in the config file
config_dict["model_type"] = "timm_wrapper"

return config_dict, kwargs

@classmethod
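
The effect of the `_get_config_dict` change above is that a timm checkpoint whose config.json carries no `model_type` resolves to the wrapper config. A hedged sketch of what that enables with the auto classes:

```python
from transformers import AutoConfig

# config.json on timm Hub repos has no "model_type"; the hook above fills in "timm_wrapper".
config = AutoConfig.from_pretrained("timm/resnet50.a1_in1k")
print(type(config).__name__)  # expected: TimmWrapperConfig
```
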
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
@@ -84,7 +84,7 @@
"tf2onnx": "tf2onnx",
"timeout-decorator": "timeout-decorator",
"tiktoken": "tiktoken",
"timm": "timm<=0.9.16",
"timm": "timm<=1.0.11",
"tokenizers": "tokenizers>=0.20,<0.21",
"torch": "torch",
"torchaudio": "torchaudio",
7 changes: 4 additions & 3 deletions src/transformers/image_processing_base.py
@@ -295,6 +295,7 @@ def get_image_processor_dict(
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", "")
image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME)
Review comment (Member Author, @qubvel, Nov 28, 2024): The timm models store the image processor configuration in config.json instead of preprocessor_config.json. This is why we check if another name is provided.

Reply (Contributor): I understand - since it's a new kwarg that does not seem to have an equivalent in hub methods (like use_auth_token or revision), I'd add a small docstring to advertise it.

from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
@@ -321,15 +322,15 @@ def get_image_processor_dict(
        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        is_local = os.path.isdir(pretrained_model_name_or_path)
        if os.path.isdir(pretrained_model_name_or_path):
            image_processor_file = os.path.join(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME)
            image_processor_file = os.path.join(pretrained_model_name_or_path, image_processor_filename)
        if os.path.isfile(pretrained_model_name_or_path):
            resolved_image_processor_file = pretrained_model_name_or_path
            is_local = True
        elif is_remote_url(pretrained_model_name_or_path):
            image_processor_file = pretrained_model_name_or_path
            resolved_image_processor_file = download_url(pretrained_model_name_or_path)
        else:
            image_processor_file = IMAGE_PROCESSOR_NAME
            image_processor_file = image_processor_filename
            try:
                # Load from local folder or from cache or download from model Hub and cache
                resolved_image_processor_file = cached_file(
@@ -355,7 +356,7 @@
f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
" it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f" directory containing a {IMAGE_PROCESSOR_NAME} file"
f" directory containing a {image_processor_filename} file"
)

try:
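
To illustrate the new `image_processor_filename` kwarg discussed in the thread above: timm repos keep their preprocessing settings in config.json, so the wrapper's loader can point `get_image_processor_dict` at that file instead of the default preprocessor_config.json. A sketch under that assumption:

```python
from transformers import TimmWrapperImageProcessor

# Read the image-processor settings from config.json rather than preprocessor_config.json.
image_processor_dict, remaining_kwargs = TimmWrapperImageProcessor.get_image_processor_dict(
    "timm/resnet50.a1_in1k", image_processor_filename="config.json"
)
```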