Add TimmWrapper #34564
Conversation
This is starting to look nice!
>>> # Load model and image processor
>>> checkpoint = "timm/resnet50.a1_in1k"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()
Really nice that there is no kwargs or whatever to load the model
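For completeness, a minimal inference sketch continuing the snippet above, following the usual transformers image-classification conventions; the sample image URL and the assumption that the wrapper config exposes an id2label mapping are mine, not part of the quoted docstring:

>>> import requests
>>> import torch
>>> from PIL import Image

>>> # Fetch a sample image and run a forward pass through the wrapped timm model
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = image_processor(images=image, return_tensors="pt")
>>> with torch.no_grad():
...     logits = model(**inputs).logits
>>> predicted_id = logits.argmax(-1).item()
>>> print(model.config.id2label[predicted_id])  # assumes an id2label mapping is populated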
@@ -502,7 +502,7 @@ def load_state_dict(
        # Check format of the archive
        with safe_open(checkpoint_file, framework="pt") as f:
            metadata = f.metadata()
        if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
Ok with this
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
renamed_keys = {}
renamed_gamma = {}
renamed_beta = {}
warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` "
for key in state_dict.keys():
    new_key = None
    if "gamma" in key:
        # We add only the first key as an example
        new_key = key.replace("gamma", "weight")
        renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
    if "beta" in key:
        # We add only the first key as an example
        new_key = key.replace("beta", "bias")
        renamed_beta[key] = new_key if not renamed_beta else renamed_beta
    if new_key:
        old_keys.append(key)
        new_keys.append(new_key)
renamed_keys = {**renamed_gamma, **renamed_beta}
if renamed_keys:
    warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
    for old_key, new_key in renamed_keys.items():
        warning_msg += f"* `{old_key}` -> `{new_key}`\n"
    warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
    logger.info_once(warning_msg)
for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key] = state_dict.pop(old_key)
This should eventually be completely removed, cc @ArthurZucker
src/transformers/modeling_utils.py
if metadata is None:
    pass
Can you please add a comment saying that in case of no metadata, it's seen as a pytorch checkpoint
Added 7e0d2c6
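For illustration, the requested inline comment could look something like this (exact wording is an assumption, not the committed text):

if metadata is None:
    # No metadata: treat the archive as a plain PyTorch checkpoint
    pass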
default_image_processor_filename = (
    "config.json" if is_timm_checkpoint(pretrained_model_name_or_path) else IMAGE_PROCESSOR_NAME
)
kwargs["image_processor_filename"] = kwargs.get("image_processor_filename", default_image_processor_filename)
You can use CONFIG_NAME instead.
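A minimal sketch of the suggested change, assuming CONFIG_NAME ("config.json") and IMAGE_PROCESSOR_NAME are imported from transformers.utils and is_timm_checkpoint is the helper from this PR:

from transformers.utils import CONFIG_NAME, IMAGE_PROCESSOR_NAME

# Use the CONFIG_NAME constant instead of the hard-coded "config.json" string
default_image_processor_filename = (
    CONFIG_NAME if is_timm_checkpoint(pretrained_model_name_or_path) else IMAGE_PROCESSOR_NAME
)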
src/transformers/utils/generic.py
@@ -867,3 +869,48 @@ class LossKwargs(TypedDict, total=False):
    """

    num_items_in_batch: Optional[int]


def is_timm_hub_checkpoint(pretrained_model_name_or_path: str) -> bool:
Given the method's objective, I would have it accept only a pretrained_model_name
function removed with ff6efde (see comment below)
src/transformers/utils/generic.py
if os.path.isfile(pretrained_model_name_or_path) or os.path.isdir(pretrained_model_name_or_path):
    return False
and I'd therefore remove this
function removed with ff6efde (see comment below)
src/transformers/utils/generic.py
return pretrained_model_name_or_path.startswith("hf-hub:timm/") or pretrained_model_name_or_path.startswith(
    "timm/"
)
Good as long as we don't expect community checkpoints, but there are community checkpoints already, for example see the following: https://huggingface.co/prov-gigapath/prov-gigapath
I think we'll need a more robust check here
Indeed, this is a weak assumption. This function is only needed in the image processors auto class to load it from the config. In AutoImageProcessor.from_pretrained the only information we have is the model name, so the only way to make a robust check is to load files from the Hub.
I removed this function entirely and instead made a fallback for loading the timm image processor dict in the auto class of the image processor. To avoid loading the config for every model, I did it in the following way:
- Try to load the image processor config as usual - most of the models will be fine, and we won't have any overhead here.
- In case of an exception, try loading config.json and check if it's a timm checkpoint.
See ff6efde for details. I documented it in the code, let me know if you have doubts about this approach.
Works with
image_processor = AutoImageProcessor.from_pretrained("prov-gigapath/prov-gigapath")
This looks better to me!
    def test_model_is_small(self):
        pass

    # Overriding as output_attentions is not supported by TimmWrapper
This should be removed, no?
Thanks, removed in a476610
@require_torch
@require_vision
class TimmWrapperModelIntegrationTest(unittest.TestCase):
Should require timm as well
Added in 327095a
Thanks, this looks good! cc @molbap can you give the processor code a quick look just to double check?
Hey, took a quick look at the processor and ran it, found some stuff which I commented! Also looked at the whole PR, real nice work!
@@ -295,6 +295,7 @@ def get_image_processor_dict(
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", "")
        image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME)
I understand - since it's a new kwarg that does not seem to have an equivalent in hub methods (like use_auth_token or revision), I'd add a small docstring to advertise it.
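For illustration, a hedged sketch of what such a docstring entry could look like (the wording is an assumption, not the final text):

image_processor_filename (`str`, *optional*, defaults to `"preprocessor_config.json"`):
    The name of the file in the model repository that holds the image processor configuration.
    Mainly useful for timm checkpoints, which store their preprocessing settings in `config.json`.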
try:
    # Main path for all transformers models and local TimmWrapper checkpoints
    config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
        pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
    )
except Exception as initial_exception:
    # Fallback path for Hub TimmWrapper checkpoints. Timm models' image processing is saved in `config.json`
    # instead of `preprocessor_config.json`. Because this is an Auto class and we don't have any information
    # except the model name, the only way to check if a remote checkpoint is a timm model is to try to
    # load `config.json` and if it fails with some error, we raise the initial exception.
    try:
        config_dict, _ = ImageProcessingMixin.get_image_processor_dict(
            pretrained_model_name_or_path, image_processor_filename=CONFIG_NAME, **kwargs
        )
    except Exception:
        raise initial_exception
A remark here - that's true for any processor with a preprocessor_config.json, but would it make sense to sanitize the inputs a bit? On community checkpoints, there are extra keys that are unused, for instance model_args, which contains duplicated information. (I'm asking because we're already using a try/catch pattern here, so that allows some branching.)
Not sure I got it, can you provide more details, please?
Basically, if I load https://huggingface.co/prov-gigapath/prov-gigapath/blob/main/config.json, I'll end up with an ImageProcessor object that has keys I can't make much use of in transformers, such as model_args, so I wondered if it made sense to filter the contents of that config.json to an expected schema.
Ok, it makes sense now. I handled it at the image-processor init level, see fd7b646. Is this what you had in mind?
>>> from transformers import AutoImageProcessor, AutoConfig
>>> image_processor = AutoImageProcessor.from_pretrained("prov-gigapath/prov-gigapath")
>>> print(image_processor)
TimmWrapperImageProcessor {
"architecture": "vit_giant_patch14_dinov2",
"data_config": {
"crop_mode": "center",
"crop_pct": 1.0,
"input_size": [
3,
224,
224
],
"interpolation": "bicubic",
"mean": [
0.485,
0.456,
0.406
],
"std": [
0.229,
0.224,
0.225
]
},
"image_processor_type": "TimmWrapperImageProcessor"
}
yes, exactly! that looks good
if isinstance(images, torch.Tensor):
    images = self.val_transforms(images)
    # Add batch dimension if a single image
    images = images.unsqueeze(0) if images.ndim == 3 else images
Here, if val_transforms is for instance

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)

then the ToTensor op will fail, since F.to_tensor expects a PIL image IIRC.
Nice catch! Since timm>=1.0.8 it's MaybeToTensor and works fine, but on timm<1.0.8 it indeed raises an error:

TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>

I can add something like:

if timm < 1.0.8 and isinstance(images, torch.Tensor):
    images = images.cpu().numpy()
Added in a10bc0d
perfect!
@qubvel version checks tend to be brittle, can you do a hasattr check for the MaybeToTensor class existing instead? I think that should line up with its use in the transforms?
@rwightman, thanks for the note. I added a check for the class name; it doesn't look that elegant, but I hope it's more robust 3d1a76e
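A rough sketch of what such a class-name based check could look like; the helper name and surrounding structure are assumptions, not the committed code:

import torch
from torchvision.transforms import Compose

def has_maybe_to_tensor(transforms: Compose) -> bool:
    # timm >= 1.0.8 replaces ToTensor with MaybeToTensor, which accepts tensors as well as PIL images
    return any(t.__class__.__name__ == "MaybeToTensor" for t in transforms.transforms)

# Inside the image processor's preprocessing step (sketch):
# if isinstance(images, torch.Tensor) and not has_maybe_to_tensor(self.val_transforms):
#     images = images.cpu().numpy()  # older timm: ToTensor only handles PIL images / numpy arrays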
@LysandreJik @molbap @rwightman Thanks for the reviews! I believe all comments have been addressed. Do you have anything else in mind? It would be nice to move it forward.
Ok awesome! At this point just a second quick look from @ArthurZucker and we're good.
Super nice! This is kind of a perfect integration with the Auto API, congrats!
normalize = (
    Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
    if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
    else Lambda(lambda x: x)
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's do something explicit for this!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure what you meant here, but I made it a bit more readable in 4edfe90 IMO. It's actually not related to the timm wrapper, it's the same in the original code.
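For illustration, a sketch of a more explicit version (this is a guess at the spirit of 4edfe90, not the exact diff):

from torchvision.transforms import Lambda, Normalize

# Build the normalization step explicitly instead of via a nested conditional expression
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std"):
    normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
else:
    # No normalization statistics available: fall back to an identity transform
    normalize = Lambda(lambda x: x)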
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
renamed_keys = {}
renamed_gamma = {}
renamed_beta = {}
warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` "
for key in state_dict.keys():
    new_key = None
    if "gamma" in key:
        # We add only the first key as an example
        new_key = key.replace("gamma", "weight")
        renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
    if "beta" in key:
        # We add only the first key as an example
        new_key = key.replace("beta", "bias")
        renamed_beta[key] = new_key if not renamed_beta else renamed_beta
    if new_key:
        old_keys.append(key)
        new_keys.append(new_key)
renamed_keys = {**renamed_gamma, **renamed_beta}
if renamed_keys:
    warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
    for old_key, new_key in renamed_keys.items():
        warning_msg += f"* `{old_key}` -> `{new_key}`\n"
    warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
    logger.info_once(warning_msg)
for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key] = state_dict.pop(old_key)
yep there is #33192 kinda related
What does this PR do?
Adds a TimmWrapper set of classes so that timm models can be loaded into the library as transformers models.
Continuation of
General Usage
Pipeline
Timm models can now be used in the image classification (if a classification model) and image feature extraction pipelines; a minimal usage sketch follows.
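As a rough sketch (the checkpoint name reuses the one from the documentation example above; the output format follows the standard pipeline conventions):

from transformers import pipeline

# Image classification with a timm checkpoint through the standard pipeline API
classifier = pipeline("image-classification", model="timm/resnet50.a1_in1k")
predictions = classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
print(predictions)  # list of {"label": ..., "score": ...} dicts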
Trainer
Timm models can now be loaded and trained with the trainer class.
Example model trained with the trainer running the script command below:
https://huggingface.co/qubvel-hf/vit-base-beans
Other features enabled
- output_hidden_states=True or output_hidden_states=[1, 2, 3] (to select specific hidden states); see the sketch after this list

TODO
- output_hidden_states tests
- transformers instead of timm, which architectures are affected?
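For the output_hidden_states feature mentioned above, a hedged sketch of how the call could look (the dummy input shape and the outputs.hidden_states attribute are assumptions based on the description in this PR):

import torch
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained("timm/resnet50.a1_in1k").eval()
pixel_values = torch.randn(1, 3, 224, 224)  # dummy batch just to illustrate the call

with torch.no_grad():
    # Request only hidden states 1, 2 and 3 instead of all of them
    outputs = model(pixel_values, output_hidden_states=[1, 2, 3])

print(len(outputs.hidden_states))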