Skip to content

Commit

Permalink
Exploring use of kwargs for timm model and transforms creation
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Jan 21, 2025
1 parent 3df9010 commit da30662
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
from typing import Any, Dict, Optional, Tuple, Union

Expand All @@ -34,6 +34,8 @@

logger = logging.get_logger(__name__)

_DATA_ARG_KEYS = ("input_size", "img_size", "interpolation", "in_chans", "mean", "std", "use_train_size")


class TimmWrapperImageProcessor(BaseImageProcessor):
"""
Expand All @@ -58,11 +60,22 @@ def __init__(
requires_backends(self, "timm")
super().__init__(architecture=architecture)

self.data_config = timm.data.resolve_data_config(pretrained_cfg, model=None, verbose=False)
self.val_transforms = timm.data.create_transform(**self.data_config, is_training=False)
data_arg_overrides = {}
for k in _DATA_ARG_KEYS:
if k in kwargs:
data_arg_overrides[k] = kwargs.pop(k)
self.data_config = timm.data.resolve_data_config(
args=data_arg_overrides, # will override values in pretrained_cfg
pretrained_cfg=pretrained_cfg,
model=None,
use_test_size=not data_arg_overrides.get("use_train_size", False),
verbose=False,
)

self.val_transforms = timm.data.create_transform(**self.data_config, is_training=False, **kwargs)

# useful for training, see examples/pytorch/image-classification/run_image_classification.py
self.train_transforms = timm.data.create_transform(**self.data_config, is_training=True)
self.train_transforms = timm.data.create_transform(**self.data_config, is_training=True, **kwargs)

# If `ToTensor` is in the transforms, then the input should be numpy array or PIL image.
# Otherwise, the input can be a tensor. In later timm versions, `MaybeToTensor` is used
Expand All @@ -88,11 +101,26 @@ def get_image_processor_dict(
"""
Get the image processor dict for the model.
"""
requires_backends(cls, "timm")

image_processor_filename = kwargs.pop("image_processor_filename", "config.json")
return super().get_image_processor_dict(
image_processor_dict, kwargs = super().get_image_processor_dict(
pretrained_model_name_or_path, image_processor_filename=image_processor_filename, **kwargs
)

# Only pass through architecture and pretrained_cfg from config.json
image_processor_dict = {
"architecture": image_processor_dict["architecture"],
"pretrained_cfg": image_processor_dict["pretrained_cfg"],
}

# Merge kwargs that should be passed through to timm transform factory into image_processor_dict
for k in _DATA_ARG_KEYS + tuple(inspect.signature(timm.data.create_transform).parameters.keys()):
if k in kwargs:
image_processor_dict[k] = kwargs.pop(k)

return image_processor_dict, kwargs

def preprocess(
self,
images: ImageInput,
Expand Down
10 changes: 6 additions & 4 deletions src/transformers/models/timm_wrapper/modeling_timm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ class TimmWrapperModel(TimmWrapperPreTrainedModel):
Wrapper class for timm models to be used in transformers.
"""

def __init__(self, config: TimmWrapperConfig):
def __init__(self, config: TimmWrapperConfig, **kwargs):
super().__init__(config)
# using num_classes=0 to avoid creating classification head
self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0)
self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0, **kwargs)
self.post_init()

@add_start_docstrings_to_model_forward(TIMM_WRAPPER_INPUTS_DOCSTRING)
Expand Down Expand Up @@ -240,7 +240,7 @@ class TimmWrapperForImageClassification(TimmWrapperPreTrainedModel):
Wrapper class for timm models to be used in transformers for image classification.
"""

def __init__(self, config: TimmWrapperConfig):
def __init__(self, config: TimmWrapperConfig, **kwargs):
super().__init__(config)

if config.num_labels == 0:
Expand All @@ -250,7 +250,9 @@ def __init__(self, config: TimmWrapperConfig):
"or use `TimmWrapperModel` for feature extraction."
)

self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=config.num_labels)
self.timm_model = timm.create_model(
config.architecture, pretrained=False, num_classes=config.num_labels, **kwargs
)
self.num_labels = config.num_labels
self.post_init()

Expand Down

0 comments on commit da30662

Please sign in to comment.