diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index a99698b9b5e24e..2e53f66072b3f2 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -659,6 +659,8 @@
title: GLPN
- local: model_doc/hiera
title: Hiera
+ - local: model_doc/ijepa
+ title: I-JEPA
- local: model_doc/imagegpt
title: ImageGPT
- local: model_doc/levit
diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 3da28bb54c83c9..783f575dd5242c 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -169,6 +169,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Hiera](model_doc/hiera) | ✅ | ❌ | ❌ |
| [Hubert](model_doc/hubert) | ✅ | ✅ | ❌ |
| [I-BERT](model_doc/ibert) | ✅ | ❌ | ❌ |
+| [I-JEPA](model_doc/ijepa) | ✅ | ❌ | ❌ |
| [IDEFICS](model_doc/idefics) | ✅ | ✅ | ❌ |
| [Idefics2](model_doc/idefics2) | ✅ | ❌ | ❌ |
| [Idefics3](model_doc/idefics3) | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/model_doc/ijepa.md b/docs/source/en/model_doc/ijepa.md
new file mode 100644
index 00000000000000..9a0cd368a8188f
--- /dev/null
+++ b/docs/source/en/model_doc/ijepa.md
@@ -0,0 +1,78 @@
+
+
+# I-JEPA
+
+## Overview
+
+The I-JEPA model was proposed in [Image-based Joint-Embedding Predictive Architecture](https://arxiv.org/pdf/2301.08243.pdf) by Mahmoud Assran, Quentin Duval, Ishan Misra, Piotr Bojanowski, Pascal Vincent, Michael Rabbat, Yann LeCun, Nicolas Ballas.
+I-JEPA is a self-supervised learning method that predicts the representations of one part of an image based on other parts of the same image. This approach focuses on learning semantic features without relying on pre-defined invariances from hand-crafted data transformations, which can bias specific tasks, or on filling in pixel-level details, which often leads to less meaningful representations.
+
+The abstract from the paper is the following:
+
+This paper demonstrates an approach for learning highly semantic image representations without relying on hand-crafted data-augmentations. We introduce the Image- based Joint-Embedding Predictive Architecture (I-JEPA), a non-generative approach for self-supervised learning from images. The idea behind I-JEPA is simple: from a single context block, predict the representations of various target blocks in the same image. A core design choice to guide I-JEPA towards producing semantic representations is the masking strategy; specifically, it is crucial to (a) sample tar- get blocks with sufficiently large scale (semantic), and to (b) use a sufficiently informative (spatially distributed) context block. Empirically, when combined with Vision Transform- ers, we find I-JEPA to be highly scalable. For instance, we train a ViT-Huge/14 on ImageNet using 16 A100 GPUs in under 72 hours to achieve strong downstream performance across a wide range of tasks, from linear classification to object counting and depth prediction.
+
+This model was contributed by [jmtzt](https://huggingface.co/jmtzt).
+The original code can be found [here](https://github.com/facebookresearch/ijepa).
+
+## How to use
+
+Here is how to use this model for image feature extraction:
+
+```python
+import requests
+import torch
+from PIL import Image
+from torch.nn.functional import cosine_similarity
+
+from transformers import AutoModel, AutoProcessor
+
+url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
+url_2 = "http://images.cocodataset.org/val2017/000000219578.jpg"
+image_1 = Image.open(requests.get(url_1, stream=True).raw)
+image_2 = Image.open(requests.get(url_2, stream=True).raw)
+
+model_id = "jmtzt/ijepa_vith14_1k"
+processor = AutoProcessor.from_pretrained(model_id)
+model = AutoModel.from_pretrained(model_id)
+
+@torch.no_grad()
+def infer(image):
+ inputs = processor(image, return_tensors="pt")
+ outputs = model(**inputs)
+ return outputs.last_hidden_state.mean(dim=1)
+
+
+embed_1 = infer(image_1)
+embed_2 = infer(image_2)
+
+similarity = cosine_similarity(embed_1, embed_2)
+print(similarity)
+```
+
+## IJepaConfig
+
+[[autodoc]] IJepaConfig
+
+## IJepaModel
+
+[[autodoc]] IJepaModel
+ - forward
+
+## IJepaForImageClassification
+
+[[autodoc]] IJepaForImageClassification
+ - forward
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index 12f492ff29a5ee..ec8dea2735b531 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -235,6 +235,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
+* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
@@ -242,7 +243,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model)
* [Idefics3](https://huggingface.co/docs/transformers/model_doc/idefics3#transformers.Idefics3Model)
-* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
+* [I-JEPA](https://huggingface.co/docs/transformers/model_doc/ijepa#transformers.IJepaModel)
* [GraniteMoe](https://huggingface.co/docs/transformers/model_doc/granitemoe#transformers.GraniteMoeModel)
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
diff --git a/setup.py b/setup.py
index 2dd0c50fb8d693..2364e3b677fd6d 100644
--- a/setup.py
+++ b/setup.py
@@ -180,7 +180,7 @@
"timeout-decorator",
"tiktoken",
"timm<=1.0.11",
- "tokenizers>=0.20,<0.21",
+ "tokenizers>=0.21,<0.22",
"torch",
"torchaudio",
"torchvision",
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 970f32b9a88d92..e53d61eb7b2ec5 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -486,6 +486,7 @@
"models.idefics": ["IdeficsConfig"],
"models.idefics2": ["Idefics2Config"],
"models.idefics3": ["Idefics3Config"],
+ "models.ijepa": ["IJepaConfig"],
"models.imagegpt": ["ImageGPTConfig"],
"models.informer": ["InformerConfig"],
"models.instructblip": [
@@ -2471,6 +2472,13 @@
"Idefics3Processor",
]
)
+ _import_structure["models.ijepa"].extend(
+ [
+ "IJepaForImageClassification",
+ "IJepaModel",
+ "IJepaPreTrainedModel",
+ ]
+ )
_import_structure["models.imagegpt"].extend(
[
"ImageGPTForCausalImageModeling",
@@ -5378,6 +5386,7 @@
)
from .models.idefics2 import Idefics2Config
from .models.idefics3 import Idefics3Config
+ from .models.ijepa import IJepaConfig
from .models.imagegpt import ImageGPTConfig
from .models.informer import InformerConfig
from .models.instructblip import (
@@ -7197,6 +7206,11 @@
Idefics3PreTrainedModel,
Idefics3Processor,
)
+ from .models.ijepa import (
+ IJepaForImageClassification,
+ IJepaModel,
+ IJepaPreTrainedModel,
+ )
from .models.imagegpt import (
ImageGPTForCausalImageModeling,
ImageGPTForImageClassification,
diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 9d4d90f11221db..23f2177b25d529 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -12,7 +12,6 @@
from .utils import (
is_hqq_available,
is_optimum_quanto_available,
- is_quanto_available,
is_torchdynamo_compiling,
logging,
)
@@ -790,17 +789,6 @@ def __init__(self, cache_config: CacheConfig) -> None:
f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}."
)
from optimum.quanto import MaxOptimizer, qint2, qint4
- elif is_quanto_available():
- logger.warning_once(
- "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
- )
- quanto_version = version.parse(importlib.metadata.version("quanto"))
- if quanto_version < version.parse("0.2.0"):
- raise ImportError(
- f"You need quanto package version to be greater or equal than 0.2.0 to use `QuantoQuantizedCache`. Detected version {quanto_version}. "
- f"Since quanto will be deprecated, please install optimum-quanto instead with `pip install -U optimum-quanto`"
- )
- from quanto import MaxOptimizer, qint2, qint4
if self.nbits not in [2, 4]:
raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
@@ -824,16 +812,6 @@ def _quantize(self, tensor, axis):
scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
return qtensor
- elif is_quanto_available():
- logger.warning_once(
- "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
- )
- from quanto import AffineQuantizer
-
- scale, zeropoint = self.optimizer(tensor, self.qtype.bits, axis, self.q_group_size)
- qtensor = AffineQuantizer.apply(tensor, self.qtype, axis, self.q_group_size, scale, zeropoint)
-
- return qtensor
def _dequantize(self, qtensor):
return qtensor.dequantize()
diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py
index dcf216783befaa..85345cc8e5889d 100644
--- a/src/transformers/dependency_versions_table.py
+++ b/src/transformers/dependency_versions_table.py
@@ -85,7 +85,7 @@
"timeout-decorator": "timeout-decorator",
"tiktoken": "tiktoken",
"timm": "timm<=1.0.11",
- "tokenizers": "tokenizers>=0.20,<0.21",
+ "tokenizers": "tokenizers>=0.21,<0.22",
"torch": "torch",
"torchaudio": "torchaudio",
"torchvision": "torchvision",
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 5ef0c0eb81c87a..015cbebaa8e5dc 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -45,7 +45,6 @@
is_accelerate_available,
is_hqq_available,
is_optimum_quanto_available,
- is_quanto_available,
is_torchdynamo_compiling,
logging,
)
@@ -1787,7 +1786,7 @@ def _prepare_cache_for_generation(
)
cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
- if cache_config.backend == "quanto" and not (is_optimum_quanto_available() or is_quanto_available()):
+ if cache_config.backend == "quanto" and not is_optimum_quanto_available():
raise ImportError(
"You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. "
"Please install it via with `pip install optimum-quanto`"
diff --git a/src/transformers/integrations/quanto.py b/src/transformers/integrations/quanto.py
index 27b32de63bfe55..1c5702321937da 100644
--- a/src/transformers/integrations/quanto.py
+++ b/src/transformers/integrations/quanto.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ..utils import is_optimum_quanto_available, is_quanto_available, is_torch_available, logging
+from ..utils import is_optimum_quanto_available, is_torch_available, logging
if is_torch_available():
@@ -50,11 +50,6 @@ def replace_with_quanto_layers(
if is_optimum_quanto_available():
from optimum.quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8
- elif is_quanto_available():
- logger.warning_once(
- "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
- )
- from quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8
w_mapping = {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}
a_mapping = {None: None, "float8": qfloat8, "int8": qint8}
diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py
index cca6d548cdf3ac..7562649be753bb 100644
--- a/src/transformers/modeling_gguf_pytorch_utils.py
+++ b/src/transformers/modeling_gguf_pytorch_utils.py
@@ -291,7 +291,6 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
# FIXME: Currnetly this implementation is only for flan-t5 architecture.
# It needs to be developed for supporting legacy t5.
elif "t5" in architecture or "t5encoder" in architecture:
- parsed_parameters["config"]["tie_word_embeddings"] = False
parsed_parameters["config"]["is_gated_act"] = True
updated_architecture = "t5"
else:
@@ -326,6 +325,12 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
if architecture + model_size not in GGUF_SUPPORTED_ARCHITECTURES:
raise ValueError(f"Architecture {architecture + model_size} not supported")
+ # Handle tie_word_embeddings, if lm_head.weight is not present in tensors,
+ # tie_word_embeddings is true otherwise false
+ parsed_parameters["config"]["tie_word_embeddings"] = all(
+ "output.weight" != tensor.name for tensor in reader.tensors
+ )
+
# List all key-value pairs in a columnized format
for gguf_key, field in reader.fields.items():
gguf_key = gguf_key.replace(architecture, updated_architecture)
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 242fe286a0aa48..19c2328d567643 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -118,6 +118,7 @@
idefics,
idefics2,
idefics3,
+ ijepa,
imagegpt,
informer,
instructblip,
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 19a5c5db2cdaee..74bd3977099608 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -136,6 +136,7 @@
("idefics", "IdeficsConfig"),
("idefics2", "Idefics2Config"),
("idefics3", "Idefics3Config"),
+ ("ijepa", "IJepaConfig"),
("imagegpt", "ImageGPTConfig"),
("informer", "InformerConfig"),
("instructblip", "InstructBlipConfig"),
@@ -442,6 +443,7 @@
("idefics", "IDEFICS"),
("idefics2", "Idefics2"),
("idefics3", "Idefics3"),
+ ("ijepa", "I-JEPA"),
("imagegpt", "ImageGPT"),
("informer", "Informer"),
("instructblip", "InstructBLIP"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 11ae15ca461e79..e19c8efd205552 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -90,6 +90,7 @@
("idefics", ("IdeficsImageProcessor",)),
("idefics2", ("Idefics2ImageProcessor",)),
("idefics3", ("Idefics3ImageProcessor",)),
+ ("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
("imagegpt", ("ImageGPTImageProcessor",)),
("instructblip", ("BlipImageProcessor",)),
("instructblipvideo", ("InstructBlipVideoImageProcessor",)),
@@ -433,7 +434,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
if image_processor_class is None and image_processor_auto_map is None:
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+ pretrained_model_name_or_path,
+ trust_remote_code=trust_remote_code,
+ **kwargs,
)
# It could be in `config.image_processor_type``
image_processor_class = getattr(config, "image_processor_type", None)
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index b93f1b6392b1d9..8c44e93cafdfc8 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -133,6 +133,7 @@
("idefics", "IdeficsModel"),
("idefics2", "Idefics2Model"),
("idefics3", "Idefics3Model"),
+ ("ijepa", "IJepaModel"),
("imagegpt", "ImageGPTModel"),
("informer", "InformerModel"),
("jamba", "JambaModel"),
@@ -580,6 +581,7 @@
("focalnet", "FocalNetModel"),
("glpn", "GLPNModel"),
("hiera", "HieraModel"),
+ ("ijepa", "IJepaModel"),
("imagegpt", "ImageGPTModel"),
("levit", "LevitModel"),
("mllama", "MllamaVisionModel"),
@@ -658,6 +660,7 @@
("efficientnet", "EfficientNetForImageClassification"),
("focalnet", "FocalNetForImageClassification"),
("hiera", "HieraForImageClassification"),
+ ("ijepa", "IJepaForImageClassification"),
("imagegpt", "ImageGPTForImageClassification"),
(
"levit",
diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py
index 2e32912421dc5b..ed8ddd3c47dea3 100644
--- a/src/transformers/models/blip_2/modeling_blip_2.py
+++ b/src/transformers/models/blip_2/modeling_blip_2.py
@@ -2311,7 +2311,7 @@ def generate(
if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "image_token_index", None) is not None:
- start_tokens += [self.config.image_token_index] * self.config.num_query_tokens
+ start_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)
diff --git a/src/transformers/models/ijepa/__init__.py b/src/transformers/models/ijepa/__init__.py
new file mode 100644
index 00000000000000..efc8c90b17628d
--- /dev/null
+++ b/src/transformers/models/ijepa/__init__.py
@@ -0,0 +1,55 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ is_torch_available,
+)
+
+
+_import_structure = {"configuration_ijepa": ["IJepaConfig"]}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ pass
+else:
+ _import_structure["modeling_ijepa"] = [
+ "IJepaForImageClassification",
+ "IJepaModel",
+ "IJepaPreTrainedModel",
+ ]
+
+if TYPE_CHECKING:
+ from .configuration_ijepa import IJepaConfig
+
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ pass
+ else:
+ from .modeling_ijepa import (
+ IJepaForImageClassification,
+ IJepaModel,
+ IJepaPreTrainedModel,
+ )
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
diff --git a/src/transformers/models/ijepa/configuration_ijepa.py b/src/transformers/models/ijepa/configuration_ijepa.py
new file mode 100644
index 00000000000000..26378e6e81d9ce
--- /dev/null
+++ b/src/transformers/models/ijepa/configuration_ijepa.py
@@ -0,0 +1,108 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""I-JEPA model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+
+
+class IJepaConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`IJepaModel`]. It is used to instantiate an IJEPA
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the I-JEPA
+ [google/ijepa-base-patch16-224](https://huggingface.co/google/ijepa-base-patch16-224) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+
+ Example:
+
+ ```python
+ >>> from transformers import IJepaConfig, IJepaModel
+
+ >>> # Initializing a IJEPA ijepa-base-patch16-224 style configuration
+ >>> configuration = IJepaConfig()
+
+ >>> # Initializing a model (with random weights) from the ijepa-base-patch16-224 style configuration
+ >>> model = IJepaModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "ijepa"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ image_size=224,
+ patch_size=16,
+ num_channels=3,
+ qkv_bias=True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
diff --git a/src/transformers/models/ijepa/convert_ijepa_to_hf.py b/src/transformers/models/ijepa/convert_ijepa_to_hf.py
new file mode 100644
index 00000000000000..5c15a72ff88847
--- /dev/null
+++ b/src/transformers/models/ijepa/convert_ijepa_to_hf.py
@@ -0,0 +1,267 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert IJEPA checkpoints from the original repository.
+
+URL: https://github.com/facebookresearch/ijepa
+"""
+
+import argparse
+import gc
+import re
+from pathlib import Path
+
+import requests
+import torch
+from PIL import Image
+
+from transformers import (
+ IJepaConfig,
+ IJepaModel,
+ ViTImageProcessor,
+)
+from transformers.utils import logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+# fmt: off
+ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
+ # Projection layer + position embeddings
+ r"pos_embed": r"embeddings.position_embeddings",
+ r"patch_embed.proj.weight": r"embeddings.patch_embeddings.projection.weight",
+ r"patch_embed.proj.bias": r"embeddings.patch_embeddings.projection.bias",
+
+ # Encoder layers: Layernorms, Attention, Feedforward layers
+ r"blocks.(\d+).norm1.weight": r"encoder.layer.\1.layernorm_before.weight",
+ r"blocks.(\d+).norm1.bias": r"encoder.layer.\1.layernorm_before.bias",
+ r"blocks.(\d+).attn.proj.weight": r"encoder.layer.\1.attention.output.dense.weight",
+ r"blocks.(\d+).attn.proj.bias": r"encoder.layer.\1.attention.output.dense.bias",
+ r"blocks.(\d+).norm2.weight": r"encoder.layer.\1.layernorm_after.weight",
+ r"blocks.(\d+).norm2.bias": r"encoder.layer.\1.layernorm_after.bias",
+ r"blocks.(\d+).mlp.fc1.weight": r"encoder.layer.\1.intermediate.dense.weight",
+ r"blocks.(\d+).mlp.fc1.bias": r"encoder.layer.\1.intermediate.dense.bias",
+ r"blocks.(\d+).mlp.fc2.weight": r"encoder.layer.\1.output.dense.weight",
+ r"blocks.(\d+).mlp.fc2.bias": r"encoder.layer.\1.output.dense.bias",
+
+ # Layernorm + pooler
+ r"norm.weight": r"layernorm.weight",
+ r"norm.bias": r"layernorm.bias",
+}
+# fmt: on
+
+
+def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
+ """
+ Converts old keys to new keys using the mapping and dynamically removes the 'ijepa.' prefix if necessary.
+
+ Args:
+ state_dict_keys (dict): The keys from the state_dict to convert.
+
+ Returns:
+ dict: A mapping from old keys to new keys.
+ """
+ output_dict = {}
+ if state_dict_keys is not None:
+ old_text = "\n".join(state_dict_keys)
+ new_text = old_text
+
+ # Apply regex-based mapping
+ for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
+ if replacement is None:
+ new_text = re.sub(pattern, "", new_text) # Skip the key
+ continue
+ new_text = re.sub(pattern, replacement, new_text)
+
+ output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
+
+ return output_dict
+
+
+# we split up the matrix of each encoder layer into queries, keys and values
+def read_in_q_k_v(state_dict, config):
+ for i in range(config.num_hidden_layers):
+ # read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
+ in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
+ in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
+ # next, add query, keys and values (in that order) to the state dict
+ state_dict[f"encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[: config.hidden_size, :]
+ state_dict[f"encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
+ state_dict[f"encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
+ config.hidden_size : config.hidden_size * 2, :
+ ]
+ state_dict[f"encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
+ config.hidden_size : config.hidden_size * 2
+ ]
+ state_dict[f"encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[-config.hidden_size :, :]
+ state_dict[f"encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]
+
+
+def rename_key(dct, old, new):
+ val = dct.pop(old)
+ dct[new] = val
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ im = Image.open(requests.get(url, stream=True).raw)
+ return im
+
+
+def get_ijepa_config(model_name):
+ patch_size = int(model_name.split("_")[1][4:])
+ config = IJepaConfig(patch_size=patch_size)
+ if "vith" in model_name:
+ config.hidden_size = 1280
+ config.num_hidden_layers = 32
+ config.num_attention_heads = 16
+ config.layer_norm_eps = 1e-6
+ config.mlp_ratio = 4
+ config.intermediate_size = 5120
+ if model_name == "ijepa_vith16_1k":
+ config.image_size = 448
+ elif "vitg" in model_name:
+ config.hidden_size = 1408
+ config.num_hidden_layers = 40
+ config.num_attention_heads = 16
+ config.layer_norm_eps = 1e-6
+ config.mlp_ratio = 48 / 11
+ config.intermediate_size = 6144
+ else:
+ raise ValueError("Model not supported, only supports huge and giant models.")
+ return config
+
+
+@torch.no_grad()
+def write_model(model_name, output_dir, safe_serialization, push_to_hub, verify_logits):
+ """
+ Copy/paste/tweak model's weights to our IJEPA structure.
+ """
+
+ # define default IJEPA configuration
+ config = get_ijepa_config(model_name)
+
+ checkpoint_mapping = {
+ "ijepa_vith14_1k": "https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar",
+ "ijepa_vith14_22k": "https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar",
+ "ijepa_vith16_1k": "https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar",
+ "ijepa_vitg16_22k": "https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar",
+ }
+
+ # Load original checkpoint
+ checkpoint_url = checkpoint_mapping[model_name]
+ original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["encoder"]
+ original_state_dict = {k.replace("module.", ""): v for k, v in original_state_dict.items()}
+
+ # Rename keys
+ state_dict = original_state_dict.copy()
+ new_keys = convert_old_keys_to_new_keys(state_dict.keys())
+ for old_key, new_key in new_keys.items():
+ rename_key(state_dict, old_key, new_key)
+ read_in_q_k_v(state_dict, config)
+
+ # load HuggingFace model
+ model = IJepaModel(config, add_pooling_layer=False).eval()
+ model.load_state_dict(state_dict)
+ size = {"height": config.image_size, "width": config.image_size}
+ image_processor = ViTImageProcessor(size=size)
+
+ if verify_logits:
+ # Check outputs on an image, prepared by ViTImageProcessor
+ encoding = image_processor(images=prepare_img(), return_tensors="pt")
+ pixel_values = encoding["pixel_values"]
+ with torch.no_grad():
+ outputs = model(pixel_values)
+
+ expected_slices = {
+ "ijepa_vith14_1k": torch.Tensor(
+ [[-0.0621, -0.0054, -2.7513], [-0.1952, 0.0909, -3.9536], [0.0942, -0.0331, -1.2833]]
+ ),
+ "ijepa_vith14_22k": torch.Tensor(
+ [[0.0358, -0.0045, -0.2154], [0.0418, -0.0246, 0.0108], [0.2529, -0.0345, -0.0246]]
+ ),
+ "ijepa_vith16_1k": torch.Tensor(
+ [[0.5145, -0.1259, 0.0615], [0.1132, 0.0028, -0.0496], [1.1586, -0.0056, -0.0387]]
+ ),
+ "ijepa_vitg16_22k": torch.Tensor(
+ [[0.0512, -0.0510, -0.0649], [0.1972, 0.0380, -0.0790], [0.1667, -0.0834, -0.1240]]
+ ),
+ }
+
+ assert torch.allclose(
+ expected_slices[model_name],
+ outputs.last_hidden_state[0, :3, :3],
+ atol=1e-4,
+ )
+
+ if output_dir:
+ Path(output_dir).mkdir(exist_ok=True)
+ print(f"Saving model {model_name} to {output_dir}")
+ image_processor.save_pretrained(output_dir, safe_serialization=safe_serialization)
+ model.save_pretrained(output_dir, safe_serialization=safe_serialization)
+
+ if push_to_hub:
+ image_processor.push_to_hub(repo_id=f"jmtzt/{model_name}", safe_serialization=safe_serialization)
+ model.push_to_hub(repo_id=f"jmtzt/{model_name}", safe_serialization=safe_serialization)
+
+ if output_dir:
+ del model, state_dict
+ gc.collect()
+ print("Reloading the model to check if it's saved correctly.")
+ IJepaModel.from_pretrained(output_dir, device_map="auto")
+ print("Model reloaded successfully.")
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--model_name",
+ default="ijepa_vith14_1k",
+ type=str,
+ choices=[
+ "ijepa_vith14_1k",
+ "ijepa_vith14_22k",
+ "ijepa_vith16_1k",
+ "ijepa_vitg16_22k",
+ ],
+ help="Name of the model you'd like to convert.",
+ )
+ parser.add_argument(
+ "--output_dir",
+ default=None,
+ type=str,
+ help="Path to the output PyTorch model directory.",
+ )
+ parser.add_argument(
+ "--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`."
+ )
+ parser.add_argument(
+ "--push_to_hub",
+ action="store_true",
+ help="Whether or not to push the model to the 🤗 Hub.",
+ )
+ parser.add_argument(
+ "--verify_logits", action="store_false", help="Whether or not to verify logits after conversion."
+ )
+
+ parser.set_defaults()
+ args = parser.parse_args()
+ write_model(args.model_name, args.output_dir, args.safe_serialization, args.push_to_hub, args.verify_logits)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py
new file mode 100644
index 00000000000000..df254455bad5ab
--- /dev/null
+++ b/src/transformers/models/ijepa/modeling_ijepa.py
@@ -0,0 +1,751 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/ijepa/modular_ijepa.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_ijepa.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+import collections.abc
+import math
+from typing import Dict, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ torch_int,
+)
+from .configuration_ijepa import IJepaConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+_CHECKPOINT_FOR_DOC = "facebook/ijepa_vith14_1k"
+
+# General docstring
+_CONFIG_FOR_DOC = "IJepaConfig"
+
+
+class IJepaPatchEmbeddings(nn.Module):
+ """
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
+ f" Expected {self.num_channels} but got {num_channels}."
+ )
+ if not interpolate_pos_encoding:
+ if height != self.image_size[0] or width != self.image_size[1]:
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
+ )
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return embeddings
+
+
+class IJepaEmbeddings(nn.Module):
+ """
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
+ """
+
+ def __init__(self, config: IJepaConfig, use_mask_token: bool = False) -> None:
+ super().__init__()
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
+ self.patch_embeddings = IJepaPatchEmbeddings(config)
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size))
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.patch_size = config.patch_size
+ self.config = config
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+ images. This method is also adapted to support torch.jit tracing.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1]
+ num_positions = self.position_embeddings.shape[1]
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embeddings
+
+ patch_pos_embed = self.position_embeddings
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return patch_pos_embed
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+ if bool_masked_pos is not None:
+ seq_length = embeddings.shape[1]
+ mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
+ # replace the masked visual tokens by mask_tokens
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+ # add positional encoding to each token
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embeddings
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+class IJepaPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = IJepaConfig
+ base_model_prefix = "ijepa"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"]
+ _supports_sdpa = True
+
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ module.weight.data = nn.init.trunc_normal_(
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+ ).to(module.weight.dtype)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, IJepaEmbeddings):
+ module.position_embeddings.data = nn.init.trunc_normal_(
+ module.position_embeddings.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.position_embeddings.dtype)
+
+
+class IJepaSelfAttention(nn.Module):
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ return outputs
+
+
+class IJepaSdpaSelfAttention(IJepaSelfAttention):
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__(config)
+ self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ if output_attentions or head_mask is not None:
+ logger.warning_once(
+ "`IJepaSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+ "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
+ "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
+ 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ )
+
+ mixed_query_layer = self.query(hidden_states)
+
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ context_layer = torch.nn.functional.scaled_dot_product_attention(
+ query_layer,
+ key_layer,
+ value_layer,
+ head_mask,
+ self.attention_probs_dropout_prob if self.training else 0.0,
+ is_causal=False,
+ scale=None,
+ )
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ return context_layer, None
+
+
+class IJepaSelfOutput(nn.Module):
+ """
+ The residual connection is defined in IJepaLayer instead of here (as is the case with other models), due to the
+ layernorm applied before each block.
+ """
+
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ return hidden_states
+
+
+class IJepaAttention(nn.Module):
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__()
+ self.attention = IJepaSelfAttention(config)
+ self.output = IJepaSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads: Set[int]) -> None:
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.attention.query = prune_linear_layer(self.attention.query, index)
+ self.attention.key = prune_linear_layer(self.attention.key, index)
+ self.attention.value = prune_linear_layer(self.attention.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
+
+ attention_output = self.output(self_outputs[0], hidden_states)
+
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class IJepaSdpaAttention(IJepaAttention):
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__(config)
+ self.attention = IJepaSdpaSelfAttention(config)
+
+
+class IJepaIntermediate(nn.Module):
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+
+ return hidden_states
+
+
+class IJepaOutput(nn.Module):
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ hidden_states = hidden_states + input_tensor
+
+ return hidden_states
+
+
+IJEPA_ATTENTION_CLASSES = {
+ "eager": IJepaAttention,
+ "sdpa": IJepaSdpaAttention,
+}
+
+
+class IJepaLayer(nn.Module):
+ """This corresponds to the Block class in the timm implementation."""
+
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = IJEPA_ATTENTION_CLASSES[config._attn_implementation](config)
+ self.intermediate = IJepaIntermediate(config)
+ self.output = IJepaOutput(config)
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
+ self_attention_outputs = self.attention(
+ self.layernorm_before(hidden_states), # in IJepa, layernorm is applied before self-attention
+ head_mask,
+ output_attentions=output_attentions,
+ )
+ attention_output = self_attention_outputs[0]
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ # first residual connection
+ hidden_states = attention_output + hidden_states
+
+ # in IJepa, layernorm is also applied after self-attention
+ layer_output = self.layernorm_after(hidden_states)
+ layer_output = self.intermediate(layer_output)
+
+ # second residual connection is done here
+ layer_output = self.output(layer_output, hidden_states)
+
+ outputs = (layer_output,) + outputs
+
+ return outputs
+
+
+class IJepaEncoder(nn.Module):
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([IJepaLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ return_dict: bool = True,
+ ) -> Union[tuple, BaseModelOutput]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ layer_module.__call__,
+ hidden_states,
+ layer_head_mask,
+ output_attentions,
+ )
+ else:
+ layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class IJepaPooler(nn.Module):
+ def __init__(self, config: IJepaConfig):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+IJEPA_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`IJepaImageProcessor.__call__`]
+ for details.
+
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ interpolate_pos_encoding (`bool`, *optional*):
+ Whether to interpolate the pre-trained position encodings.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
+
+
+IJEPA_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`IJepaConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare IJepa Model transformer outputting raw hidden-states without any specific head on top.",
+ IJEPA_START_DOCSTRING,
+)
+class IJepaModel(IJepaPreTrainedModel):
+ def __init__(self, config: IJepaConfig, add_pooling_layer: bool = False, use_mask_token: bool = False):
+ super().__init__(config)
+ self.config = config
+ self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token)
+ self.encoder = IJepaEncoder(config)
+
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.pooler = IJepaPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> IJepaPatchEmbeddings:
+ return self.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(IJEPA_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPooling,
+ config_class=_CONFIG_FOR_DOC,
+ modality="vision",
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
+ )
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None:
+ raise ValueError("You have to specify pixel_values")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
+ expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
+ if pixel_values.dtype != expected_dtype:
+ pixel_values = pixel_values.to(expected_dtype)
+
+ embedding_output = self.embeddings(
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ sequence_output = self.layernorm(sequence_output)
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+ return head_outputs + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+# Image classification docstring
+_IMAGE_CLASS_CHECKPOINT = "google/ijepa-base-patch16-224"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
+
+
+@add_start_docstrings(
+ """
+ IJepa Model transformer with an image classification head on top (a linear layer on top of the final hidden states)
+ e.g. for ImageNet.
+
+
+
+ Note that it's possible to fine-tune IJepa on higher resolution images than the ones it has been trained on, by
+ setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
+ position embeddings to the higher resolution.
+
+
+ """,
+ IJEPA_START_DOCSTRING,
+)
+class IJepaForImageClassification(IJepaPreTrainedModel):
+ def __init__(self, config: IJepaConfig) -> None:
+ super().__init__(config)
+
+ self.num_labels = config.num_labels
+ self.ijepa = IJepaModel(config, add_pooling_layer=False)
+
+ # Classifier head
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(IJEPA_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
+ output_type=ImageClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
+ )
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, ImageClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.ijepa(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.classifier(sequence_output.mean(dim=1))
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = ["IJepaPreTrainedModel", "IJepaModel", "IJepaForImageClassification"]
diff --git a/src/transformers/models/ijepa/modular_ijepa.py b/src/transformers/models/ijepa/modular_ijepa.py
new file mode 100644
index 00000000000000..efbd71d91342fd
--- /dev/null
+++ b/src/transformers/models/ijepa/modular_ijepa.py
@@ -0,0 +1,255 @@
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.models.ijepa.configuration_ijepa import IJepaConfig
+
+from ...modeling_outputs import ImageClassifierOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_start_docstrings,
+ torch_int,
+)
+from ..vit.modeling_vit import (
+ ViTEmbeddings,
+ ViTForImageClassification,
+ ViTModel,
+)
+
+
+_CHECKPOINT_FOR_DOC = "facebook/ijepa_vith14_1k"
+
+
+class IJepaEmbeddings(ViTEmbeddings):
+ def __init__(self, config: IJepaConfig, use_mask_token: bool = False) -> None:
+ super().__init__(config, use_mask_token)
+ # Remove cls_token from IJepaEmbeddings, as it is not used in the model
+ del self.cls_token
+ num_patches = self.patch_embeddings.num_patches
+ self.position_embeddings = nn.Parameter(torch.randn(1, num_patches, config.hidden_size))
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+ """
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
+ images. This method is also adapted to support torch.jit tracing.
+
+ Adapted from:
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
+ """
+
+ num_patches = embeddings.shape[1]
+ num_positions = self.position_embeddings.shape[1]
+
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
+ return self.position_embeddings
+
+ patch_pos_embed = self.position_embeddings
+
+ dim = embeddings.shape[-1]
+
+ new_height = height // self.patch_size
+ new_width = width // self.patch_size
+
+ sqrt_num_positions = torch_int(num_positions**0.5)
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(new_height, new_width),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+
+ return patch_pos_embed
+
+ def forward(
+ self,
+ pixel_values: torch.Tensor,
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
+ interpolate_pos_encoding: bool = False,
+ ) -> torch.Tensor:
+ batch_size, _, height, width = pixel_values.shape
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+ if bool_masked_pos is not None:
+ seq_length = embeddings.shape[1]
+ mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
+ # replace the masked visual tokens by mask_tokens
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+ # add positional encoding to each token
+ if interpolate_pos_encoding:
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+ else:
+ embeddings = embeddings + self.position_embeddings
+
+ embeddings = self.dropout(embeddings)
+
+ return embeddings
+
+
+class IJepaPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = IJepaConfig
+ base_model_prefix = "ijepa"
+ main_input_name = "pixel_values"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["IJepaEmbeddings", "IJepaLayer"]
+ _supports_sdpa = True
+
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
+ """Initialize the weights"""
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+ # `trunc_normal_cpu` not implemented in `half` issues
+ module.weight.data = nn.init.trunc_normal_(
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+ ).to(module.weight.dtype)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, IJepaEmbeddings):
+ module.position_embeddings.data = nn.init.trunc_normal_(
+ module.position_embeddings.data.to(torch.float32),
+ mean=0.0,
+ std=self.config.initializer_range,
+ ).to(module.position_embeddings.dtype)
+
+
+_EXPECTED_OUTPUT_SHAPE = [1, 256, 1280]
+
+IJEPA_START_DOCSTRING = r"""
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`IJepaConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare IJepa Model transformer outputting raw hidden-states without any specific head on top.",
+ IJEPA_START_DOCSTRING,
+)
+class IJepaModel(IJepaPreTrainedModel, ViTModel):
+ def __init__(self, config: IJepaConfig, add_pooling_layer: bool = False, use_mask_token: bool = False):
+ super().__init__(config)
+ self.config = config
+ self.embeddings = IJepaEmbeddings(config, use_mask_token=use_mask_token)
+
+
+_IMAGE_CLASS_CHECKPOINT = "jmtzt/ijepa_vith14_1k"
+_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
+
+
+@add_start_docstrings(
+ """
+ IJepa Model transformer with an image classification head on top (a linear layer on top of the final hidden states)
+ e.g. for ImageNet.
+
+
+
+ Note that it's possible to fine-tune IJepa on higher resolution images than the ones it has been trained on, by
+ setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
+ position embeddings to the higher resolution.
+
+
+ """,
+ IJEPA_START_DOCSTRING,
+)
+class IJepaForImageClassification(IJepaPreTrainedModel, ViTForImageClassification):
+ def __init__(self, config: IJepaConfig):
+ super().__init__(config)
+ self.ijepa = IJepaModel(config, add_pooling_layer=False)
+ self.post_init()
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ interpolate_pos_encoding: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[tuple, ImageClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.ijepa(
+ pixel_values,
+ head_mask=head_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ interpolate_pos_encoding=interpolate_pos_encoding,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.classifier(sequence_output.mean(dim=1))
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return ImageClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+__all__ = [
+ "IJepaPreTrainedModel",
+ "IJepaModel",
+ "IJepaForImageClassification",
+]
diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py
index a63393ab1ddcc0..acce24cc42f5d8 100644
--- a/src/transformers/models/instructblip/modeling_instructblip.py
+++ b/src/transformers/models/instructblip/modeling_instructblip.py
@@ -1593,7 +1593,7 @@ def generate(
if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "image_token_index", None) is not None:
- start_tokens += [self.config.image_token_index] * self.config.num_query_tokens
+ start_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)
diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
index e922d1e3f26228..e91b05bc015263 100644
--- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
+++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py
@@ -1628,7 +1628,7 @@ def generate(
if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "video_token_index", None) is not None:
- start_tokens += [self.config.video_token_index] * self.config.num_query_tokens * 4
+ start_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4 + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)
diff --git a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py
index 126d81b6d3dcce..670cc4a086fb1c 100644
--- a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py
+++ b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py
@@ -441,7 +441,11 @@ def generate(
if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "video_token_index", None) is not None:
+<<<<<<< HEAD
start_tokens += [self.config.video_token_index] * self.config.num_query_tokens * 4
+=======
+ start_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4 + start_tokens
+>>>>>>> upstream/main
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)
diff --git a/src/transformers/models/mistral/convert_mistral_weights_to_hf.py b/src/transformers/models/mistral/convert_mistral_weights_to_hf.py
index 266812b3972dff..1a89ade8fa6dbd 100644
--- a/src/transformers/models/mistral/convert_mistral_weights_to_hf.py
+++ b/src/transformers/models/mistral/convert_mistral_weights_to_hf.py
@@ -12,20 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
-import gc
import json
import os
-import shutil
+import re
import warnings
import torch
-from safetensors.torch import load_file as safe_load_file
+from safetensors.torch import load_file
-from transformers import (
- LlamaTokenizer,
- MistralConfig,
- MistralForCausalLM,
-)
+from transformers import LlamaTokenizer, MistralConfig, MistralForCausalLM
try:
@@ -39,32 +34,40 @@
)
tokenizer_class = LlamaTokenizer
-"""
-Sample usage:
+# fmt: off
+STATE_DICT_MAPPING = {
+ # CausalLM keys
+ r"^output.weight": r"lm_head.weight",
-```
-python src/transformers/models/mistral/convert_mistral_weights_to_hf.py \
- --input_dir /path/to/downloaded/mistral/weights --model_size 7B --output_dir /output/path
-```
+ # Model keys
+ r"^norm.weight": r"model.norm.weight",
+ r"^tok_embeddings.weight": r"model.embed_tokens.weight",
-Thereafter, models can be loaded via:
+ # Layers keys
+ r"^layers.(\d+).attention_norm.weight": r"model.layers.\1.input_layernorm.weight",
+ r"^layers.(\d+).ffn_norm.weight": r"model.layers.\1.post_attention_layernorm.weight",
-```py
-from transformers import MistralForCausalLM, LlamaTokenizer
+ # Attention keys
+ r"^layers.(\d+).attention.w(q|k|v|o).weight": r"model.layers.\1.self_attn.\2_proj.weight",
-model = MistralForCausalLM.from_pretrained("/output/path")
-tokenizer = LlamaTokenizer.from_pretrained("/output/path")
-```
-Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
-come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
-"""
+ # MLP keys
+ r"^layers.(\d+).feed_forward.w1.weight": r"model.layers.\1.mlp.gate_proj.weight",
+ r"^layers.(\d+).feed_forward.w2.weight": r"model.layers.\1.mlp.down_proj.weight",
+ r"^layers.(\d+).feed_forward.w3.weight": r"model.layers.\1.mlp.up_proj.weight",
+}
+# fmt: on
-NUM_SHARDS = {"7B": 1}
+def map_old_key_to_new(old_key):
+ """Map of a key of the original state dict to the equivalent key in HF format"""
+ for pattern, replacement in STATE_DICT_MAPPING.items():
+ new_key, n_replace = re.subn(pattern, replacement, old_key)
+ # Early exit of the loop
+ if n_replace > 0:
+ return new_key
-def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
- return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)
+ raise ValueError(f"Key: {old_key} could not be mapped (check the mapping).")
def read_json(path):
@@ -72,218 +75,201 @@ def read_json(path):
return json.load(f)
-def write_json(text, path):
- with open(path, "w") as f:
- json.dump(text, f)
+def permute_for_rope(tensor, n_heads, dim1, dim2):
+ """Permute the weights for the ROPE formulation."""
+ tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
+ tensor = tensor.transpose(1, 2)
+ tensor = tensor.reshape(dim1, dim2)
+ return tensor
-def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True, is_v3=False):
- # for backward compatibility, before you needed the repo to be called `my_repo/model_size`
- if not os.path.isfile(os.path.join(input_base_path, "params.json")):
- input_base_path = os.path.join(input_base_path, model_size)
+def convert_state_dict(original_state_dict: dict, config: MistralConfig):
+ """Convert a state dict file, when a single `nn.Module` is never sharded in different files (usual case)."""
+ new_dict = {}
- os.makedirs(model_path, exist_ok=True)
- tmp_model_path = os.path.join(model_path, "tmp")
- os.makedirs(tmp_model_path, exist_ok=True)
+ n_heads = config.num_attention_heads
+ dim = config.hidden_size
+ dims_per_head = dim // n_heads
+ num_key_value_heads = config.num_key_value_heads
+ key_value_dim = dims_per_head * num_key_value_heads
- params = read_json(os.path.join(input_base_path, "params.json"))
- num_shards = NUM_SHARDS[model_size]
+ for old_key, tensor in original_state_dict.items():
+ new_key = map_old_key_to_new(old_key)
- sliding_window = params.get("sliding_window", None)
+ if "q_proj" in new_key:
+ tensor = tensor.view(n_heads, dims_per_head, dim).reshape(dim, dim)
+ tensor = permute_for_rope(tensor, n_heads, dim, dim)
+ elif "k_proj" in new_key:
+ tensor = tensor.view(num_key_value_heads, dims_per_head, dim).reshape(key_value_dim, dim)
+ tensor = permute_for_rope(tensor, num_key_value_heads, key_value_dim, dim)
+ elif "v_proj" in new_key:
+ tensor = tensor.view(num_key_value_heads, dims_per_head, dim).reshape(key_value_dim, dim)
- # For some reason this is a string in the params.json
- if sliding_window is not None:
- sliding_window = int(sliding_window)
+ new_dict[new_key] = tensor
+ return new_dict
- n_layers = params["n_layers"]
- n_heads = params["n_heads"]
- n_heads_per_shard = n_heads // num_shards
- dim = params["dim"]
+
+def get_concat_dim(key):
+ """Return the dimension to concatenate the weights on."""
+ concat_dim_1 = [
+ r"model.embed_tokens.weight",
+ r"model.layers.(\d+).self_attn.o_proj.weight",
+ r"model.layers.(\d+).mlp.down_proj.weight",
+ ]
+ if any(re.search(pattern, key) for pattern in concat_dim_1):
+ return 1
+ return 0
+
+
+def convert_state_dict_sharded(loaded_shards: list[dict], config: MistralConfig):
+ """Convert the state dict, when a single `nn.Module` is sharded accross different files."""
+ new_dict = {}
+
+ num_shards = len(loaded_shards)
+
+ n_heads = config.num_attention_heads
+ dim = config.hidden_size
dims_per_head = dim // n_heads
- base = params.get("rope_theta", 10000.0)
- inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
- max_position_embeddings = 4096 * 8
-
- if tokenizer_path is not None:
- tokenizer = tokenizer_class(tokenizer_path + ".v3" if is_v3 else "")
- tokenizer.save_pretrained(model_path)
- vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000
-
- if "n_kv_heads" in params:
- num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
- num_local_key_value_heads = num_key_value_heads // num_shards
- key_value_dim = dims_per_head * num_local_key_value_heads
- else: # compatibility with other checkpoints
- num_key_value_heads = n_heads
- num_local_key_value_heads = n_heads_per_shard
- key_value_dim = dim
-
- # permute for sliced rotary
- def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
- return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
-
- print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
-
- # Load weights - for v3 models the consolidated weights are in a single file format in safetensors
- if is_v3:
- loaded = [safe_load_file(os.path.join(input_base_path, "consolidated.safetensors"))]
- else:
- loaded = [
- torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
- for i in range(num_shards)
- ]
- param_count = 0
- index_dict = {"weight_map": {}}
- for layer_i in range(n_layers):
- filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
-
- # Sharded
- # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
- # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
- # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.
-
- state_dict = {
- f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
- f"layers.{layer_i}.attention_norm.weight"
- ].clone(),
- f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
- f"layers.{layer_i}.ffn_norm.weight"
- ].clone(),
- }
- state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
- torch.cat(
- [
- loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
- for i in range(num_shards)
- ],
- dim=0,
+ num_key_value_heads = config.num_key_value_heads
+ n_heads_per_shard = n_heads // num_shards
+ num_local_key_value_heads = num_key_value_heads // num_shards
+ key_value_dim = dim if n_heads == num_key_value_heads else dims_per_head * num_local_key_value_heads
+
+ original_keys = loaded_shards[0].keys()
+ for old_key in original_keys:
+ new_key = map_old_key_to_new(old_key)
+ cat_dim = get_concat_dim(new_key)
+
+ if "q_proj" in new_key:
+ tensor = torch.cat(
+ [shard.pop(old_key).view(n_heads_per_shard, dims_per_head, dim) for shard in loaded_shards],
+ dim=cat_dim,
).reshape(dim, dim)
- )
- state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
- torch.cat(
- [
- loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(
- num_local_key_value_heads, dims_per_head, dim
- )
- for i in range(num_shards)
- ],
- dim=0,
- ).reshape(key_value_dim, dim),
- num_key_value_heads,
- key_value_dim,
- dim,
- )
- state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
- [
- loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(num_local_key_value_heads, dims_per_head, dim)
- for i in range(num_shards)
- ],
- dim=0,
- ).reshape(key_value_dim, dim)
-
- state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
- [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
- )
- state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
- [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
- )
- state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
- [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
- )
- state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
- [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
- )
-
- state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
- for k, v in state_dict.items():
- index_dict["weight_map"][k] = filename
- param_count += v.numel()
- torch.save(state_dict, os.path.join(tmp_model_path, filename))
-
- filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
- state_dict = {
- "model.norm.weight": loaded[0]["norm.weight"],
- "model.embed_tokens.weight": torch.cat([loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1),
- "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
+ tensor = permute_for_rope(tensor, n_heads, dim, dim)
+ elif "k_proj" in new_key:
+ tensor = torch.cat(
+ [shard.pop(old_key).view(num_local_key_value_heads, dims_per_head, dim) for shard in loaded_shards],
+ dim=cat_dim,
+ ).reshape(key_value_dim, dim)
+ tensor = permute_for_rope(tensor, num_key_value_heads, key_value_dim, dim)
+ elif "v_proj" in new_key:
+ tensor = torch.cat(
+ [shard.pop(old_key).view(num_local_key_value_heads, dims_per_head, dim) for shard in loaded_shards],
+ dim=cat_dim,
+ ).reshape(key_value_dim, dim)
+ elif "input_layernorm" in new_key or "post_attention_layernorm" in new_key:
+ tensor = loaded_shards[0][old_key].clone()
+ elif "model.norm.weight" in new_key:
+ tensor = loaded_shards[0][old_key]
+ else:
+ tensor = torch.cat([shard.pop(old_key) for shard in loaded_shards], dim=cat_dim)
+
+ new_dict[new_key] = tensor
+
+ return new_dict
+
+
+def convert_config(original_config: dict, max_position_embeddings: int):
+ key_mapping = {
+ "hidden_size": "dim",
+ "num_hidden_layers": "n_layers",
+ "intermediate_size": "hidden_dim",
+ "num_attention_heads": "n_heads",
+ "rms_norm_eps": "norm_eps",
}
-
- for k, v in state_dict.items():
- index_dict["weight_map"][k] = filename
- param_count += v.numel()
- torch.save(state_dict, os.path.join(tmp_model_path, filename))
-
- # Write configs
- index_dict["metadata"] = {"total_size": param_count * 2}
- write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
- config = MistralConfig(
- hidden_size=dim,
- intermediate_size=params["hidden_dim"],
- num_attention_heads=params["n_heads"],
- num_hidden_layers=params["n_layers"],
- rms_norm_eps=params["norm_eps"],
- num_key_value_heads=num_key_value_heads,
- vocab_size=vocab_size,
- rope_theta=base,
- max_position_embeddings=max_position_embeddings,
- sliding_window=sliding_window,
+ similar_keys_to_keep = [
+ "head_dim",
+ "vocab_size",
+ ]
+
+ new_config_kwargs = {k: original_config[v] for k, v in key_mapping.items()}
+ new_config_kwargs.update({k: v for k, v in original_config.items() if k in similar_keys_to_keep})
+
+ # These are not always defined depending on `params.json`
+ new_config_kwargs["sliding_window"] = original_config.get("sliding_window", None)
+ new_config_kwargs["num_key_value_heads"] = original_config.get(
+ "n_kv_heads", new_config_kwargs["num_attention_heads"]
)
- config.save_pretrained(tmp_model_path)
+ new_config_kwargs["rope_theta"] = original_config.get("rope_theta", 10000.0)
+
+ # This is never provided in `params.json`, we provide it manually
+ new_config_kwargs["max_position_embeddings"] = max_position_embeddings
- # Make space so we can load the model properly now.
- del state_dict
- del loaded
- gc.collect()
+ # This may sometimes be a string in `params.json`
+ if new_config_kwargs["sliding_window"] is not None:
+ new_config_kwargs["sliding_window"] = int(new_config_kwargs["sliding_window"])
- print("Loading the checkpoint in a Mistral model.")
- model = MistralForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
- # Avoid saving this as part of the config.
- del model.config._name_or_path
- model.config.torch_dtype = torch.float16
- print("Saving in the Transformers format.")
+ new_config = MistralConfig(**new_config_kwargs)
+ return new_config
- model.save_pretrained(model_path, safe_serialization=safe_serialization)
- shutil.rmtree(tmp_model_path)
+def convert_and_write_model(input_dir: str, output_dir: str, max_position_embeddings: int, modules_are_split: bool):
+ """Convert the model and save it (this implicitly save the config as well)."""
+ params = read_json(os.path.join(input_dir, "params.json"))
+ config = convert_config(params, max_position_embeddings)
+
+ full_state_dict = {}
+ # The model may be split between different files, but a single nn.Module is always fully present in a single file
+ if not modules_are_split:
+ shards = [file for file in os.listdir(input_dir) if file.endswith(".safetensors")]
+ for shard_file in shards:
+ original_state_dict = load_file(os.path.join(input_dir, shard_file))
+ new_dict = convert_state_dict(original_state_dict, config)
+ full_state_dict.update(new_dict)
+ # A single nn.Module is split between different checkpoint files
+ else:
+ shards = [file for file in os.listdir(input_dir) if re.match(r"consolidated.\d+.pth", file)]
+ shards = sorted(shards, key=lambda x: int(x.split(".")[1]))
+ loaded_shards = [torch.load(os.path.join(input_dir, file), map_location="cpu") for file in shards]
+ full_state_dict = convert_state_dict_sharded(loaded_shards, config)
+
+ # Load weights into model and resave them
+ with torch.device("meta"):
+ model = MistralForCausalLM(config)
+ model.load_state_dict(full_state_dict, strict=True, assign=True)
+ model.save_pretrained(output_dir)
-def write_tokenizer(tokenizer_path, input_tokenizer_path):
- # Initialize the tokenizer based on the `spm` model
- print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
- tokenizer = tokenizer_class(input_tokenizer_path)
- tokenizer.save_pretrained(tokenizer_path)
+
+def convert_and_write_tokenizer(input_dir: str, output_dir: str):
+ """Convert the tokenizer and save it."""
+ # May have .v3 or .v7 at the end
+ tokenizer_file = [file for file in os.listdir(input_dir) if "tokenizer.model" in file][0]
+ tokenizer = tokenizer_class(os.path.join(input_dir, tokenizer_file))
+ tokenizer.save_pretrained(output_dir)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
- "--input_dir",
+ "input_dir",
help="Location of Mistral weights, which contains tokenizer.model and model folders",
)
parser.add_argument(
- "--model_size",
- choices=["7B", "tokenizer_only"],
- help="'f' models correspond to the finetuned versions, and are specific to the Mistral2 official release. For more details on Mistral2, checkout the original repo: https://huggingface.co/meta-mistral",
+ "output_dir",
+ help="Location to write HF model and tokenizer",
)
parser.add_argument(
- "--output_dir",
- help="Location to write HF model and tokenizer",
+ "--max_position_embeddings",
+ type=int,
+ default=32768,
+ help="`max_position_embeddings` field in the config. This needs to be manually passed (not present anywhere otherwise).",
)
- parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.")
parser.add_argument(
- "--is_v3", action="store_true", help="Whether the checkpoints correspond to the 3rd version or not."
+ "--modules_are_split",
+ action="store_true",
+ help="If passed, then the weights of a single `nn.Module` are assumed to be split between different files.",
)
+ parser.add_argument(
+ "--tokenizer_only",
+ action="store_true",
+ help="If passed, will only convert the tokenizer.",
+ )
+
args = parser.parse_args()
- spm_path = os.path.join(args.input_dir, "tokenizer.model")
- if args.model_size != "tokenizer_only":
- write_model(
- model_path=args.output_dir,
- input_base_path=args.input_dir,
- model_size=args.model_size,
- safe_serialization=args.safe_serialization,
- tokenizer_path=spm_path,
- is_v3=args.is_v3,
- )
- else:
- write_tokenizer(args.output_dir, spm_path)
+
+ if not args.tokenizer_only:
+ convert_and_write_model(args.input_dir, args.output_dir, args.max_position_embeddings, args.modules_are_split)
+ convert_and_write_tokenizer(args.input_dir, args.output_dir)
if __name__ == "__main__":
diff --git a/src/transformers/quantizers/quantizer_quanto.py b/src/transformers/quantizers/quantizer_quanto.py
index 4027654dc22162..d91019dea15226 100644
--- a/src/transformers/quantizers/quantizer_quanto.py
+++ b/src/transformers/quantizers/quantizer_quanto.py
@@ -26,7 +26,6 @@
from ..utils import (
is_accelerate_available,
is_optimum_quanto_available,
- is_quanto_available,
is_torch_available,
logging,
)
@@ -63,7 +62,7 @@ def post_init(self):
)
def validate_environment(self, *args, **kwargs):
- if not (is_optimum_quanto_available() or is_quanto_available()):
+ if not is_optimum_quanto_available():
raise ImportError(
"Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"
)
@@ -91,11 +90,6 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
if is_optimum_quanto_available():
from optimum.quanto import QModuleMixin
- elif is_quanto_available():
- logger.warning_once(
- "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
- )
- from quanto import QModuleMixin
not_missing_keys = []
for name, module in model.named_modules():
@@ -122,11 +116,6 @@ def check_quantized_param(
"""
if is_optimum_quanto_available():
from optimum.quanto import QModuleMixin
- elif is_quanto_available():
- logger.warning_once(
- "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
- )
- from quanto import QModuleMixin
device_map = kwargs.get("device_map", None)
param_device = kwargs.get("param_device", None)
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 13d9d45f19a88f..af908e48e4b8c4 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -623,9 +623,7 @@ def __init__(
else unwrapped_model.get_base_model().forward
)
forward_params = inspect.signature(model_forward).parameters
- self.model_accepts_loss_kwargs = (
- "loss_kwargs" in forward_params and forward_params["loss_kwargs"].kind == inspect.Parameter.VAR_KEYWORD
- )
+ self.model_accepts_loss_kwargs = any(k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values())
self.neftune_noise_alpha = args.neftune_noise_alpha
@@ -3651,7 +3649,10 @@ def training_step(
return loss_mb.reduce_mean().detach().to(self.args.device)
with self.compute_loss_context_manager():
- loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+ if self.model_accepts_loss_kwargs:
+ loss = self.compute_loss(model, inputs)
+ else:
+ loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
del inputs
if (
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index f7e962bec346fb..08d23e0e6a5d41 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -175,7 +175,6 @@
is_pytesseract_available,
is_pytest_available,
is_pytorch_quantization_available,
- is_quanto_available,
is_rjieba_available,
is_sacremoses_available,
is_safetensors_available,
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index a59a70b959a56a..506c5641fe43ac 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -5006,6 +5006,27 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+class IJepaForImageClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class IJepaModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class IJepaPreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class ImageGPTForCausalImageModeling(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py
index 3764f1ee4cef76..101b34182a7309 100755
--- a/src/transformers/utils/fx.py
+++ b/src/transformers/utils/fx.py
@@ -140,6 +140,7 @@ def _generate_supported_model_class_names(
"gptj",
"hiera",
"hubert",
+ "ijepa",
"layoutlm",
"llama",
"cohere",
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 2ce4bd7bc778da..ec1dbad698466b 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -997,11 +997,13 @@ def is_auto_awq_available():
return _auto_awq_available
-def is_quanto_available():
- logger.warning_once(
- "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
- )
- return _quanto_available
+def is_optimum_quanto_available():
+ # `importlib.metadata.version` doesn't work with `optimum.quanto`, need to put `optimum_quanto`
+ return _is_optimum_quanto_available
+
+
+def is_compressed_tensors_available():
+ return _compressed_tensors_available
def is_optimum_quanto_available():
diff --git a/tests/models/ijepa/__init__.py b/tests/models/ijepa/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/models/ijepa/test_modeling_ijepa.py b/tests/models/ijepa/test_modeling_ijepa.py
new file mode 100644
index 00000000000000..27a79bc6724285
--- /dev/null
+++ b/tests/models/ijepa/test_modeling_ijepa.py
@@ -0,0 +1,341 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch IJEPA model."""
+
+import unittest
+
+from transformers import IJepaConfig
+from transformers.testing_utils import (
+ require_accelerate,
+ require_torch,
+ require_torch_accelerator,
+ require_torch_fp16,
+ require_vision,
+ slow,
+ torch_device,
+)
+from transformers.utils import (
+ cached_property,
+ is_torch_available,
+ is_vision_available,
+)
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_torch_available():
+ import torch
+ from torch import nn
+
+ from transformers import IJepaForImageClassification, IJepaModel
+
+
+if is_vision_available():
+ from PIL import Image
+
+ from transformers import ViTImageProcessor
+
+
+class IJepaModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=13,
+ image_size=30,
+ patch_size=2,
+ num_channels=3,
+ is_training=True,
+ use_labels=True,
+ hidden_size=32,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ intermediate_size=37,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ type_sequence_label_size=10,
+ initializer_range=0.02,
+ scope=None,
+ encoder_stride=2,
+ mask_ratio=0.5,
+ attn_implementation="eager",
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.is_training = is_training
+ self.use_labels = use_labels
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.scope = scope
+ self.encoder_stride = encoder_stride
+ self.attn_implementation = attn_implementation
+
+ # in IJEPA, the seq length equals the number of patches (we don't add 1 for the [CLS] token)
+ num_patches = (image_size // patch_size) ** 2
+ self.seq_length = num_patches
+ self.mask_ratio = mask_ratio
+ self.num_masks = int(mask_ratio * self.seq_length)
+ self.mask_length = num_patches
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor(
+ [
+ self.batch_size,
+ self.num_channels,
+ self.image_size,
+ self.image_size,
+ ]
+ )
+
+ labels = None
+ if self.use_labels:
+ labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+
+ config = self.get_config()
+
+ return config, pixel_values, labels
+
+ def get_config(self):
+ return IJepaConfig(
+ image_size=self.image_size,
+ patch_size=self.patch_size,
+ num_channels=self.num_channels,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ is_decoder=False,
+ initializer_range=self.initializer_range,
+ encoder_stride=self.encoder_stride,
+ attn_implementation=self.attn_implementation,
+ )
+
+ def create_and_check_model(self, config, pixel_values, labels):
+ model = IJepaModel(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values)
+ self.parent.assertEqual(
+ result.last_hidden_state.shape,
+ (self.batch_size, self.seq_length, self.hidden_size),
+ )
+
+ def create_and_check_for_image_classification(self, config, pixel_values, labels):
+ config.num_labels = self.type_sequence_label_size
+ model = IJepaForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(pixel_values, labels=labels)
+ self.parent.assertEqual(
+ result.logits.shape,
+ (self.batch_size, self.type_sequence_label_size),
+ )
+
+ # test greyscale images
+ config.num_channels = 1
+ model = IJepaForImageClassification(config)
+ model.to(torch_device)
+ model.eval()
+
+ pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+ result = model(pixel_values)
+ self.parent.assertEqual(
+ result.logits.shape,
+ (self.batch_size, self.type_sequence_label_size),
+ )
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ pixel_values,
+ labels,
+ ) = config_and_inputs
+ inputs_dict = {"pixel_values": pixel_values}
+ return config, inputs_dict
+
+
+@require_torch
+class IJepaModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ """
+ Here we also overwrite some of the tests of test_modeling_common.py, as IJEPA does not use input_ids, inputs_embeds,
+ attention_mask and seq_length.
+ """
+
+ all_model_classes = (
+ (
+ IJepaModel,
+ IJepaForImageClassification,
+ )
+ if is_torch_available()
+ else ()
+ )
+ pipeline_model_mapping = (
+ {"image-feature-extraction": IJepaModel, "image-classification": IJepaForImageClassification}
+ if is_torch_available()
+ else {}
+ )
+ fx_compatible = True
+
+ test_pruning = False
+ test_resize_embeddings = False
+ test_head_masking = False
+
+ def setUp(self):
+ self.model_tester = IJepaModelTester(self)
+ self.config_tester = ConfigTester(
+ self,
+ config_class=IJepaConfig,
+ has_text_modality=False,
+ hidden_size=37,
+ )
+
+ @unittest.skip(
+ "Since `torch==2.3+cu121`, although this test passes, many subsequent tests have `CUDA error: misaligned address`."
+ "If `nvidia-xxx-cu118` are also installed, no failure (even with `torch==2.3+cu121`)."
+ )
+ def test_multi_gpu_data_parallel_forward(self):
+ super().test_multi_gpu_data_parallel_forward()
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ @unittest.skip(reason="IJEPA does not use inputs_embeds")
+ def test_inputs_embeds(self):
+ pass
+
+ def test_model_get_set_embeddings(self):
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ model = model_class(config)
+ self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
+ x = model.get_output_embeddings()
+ self.assertTrue(x is None or isinstance(x, nn.Linear))
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_image_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ model_name = "jmtzt/ijepa_vith14_1k"
+ model = IJepaModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+# We will verify our results on an image of cute cats
+def prepare_img():
+ image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
+ return image
+
+
+@require_torch
+@require_vision
+class IJepaModelIntegrationTest(unittest.TestCase):
+ @cached_property
+ def default_image_processor(self):
+ return ViTImageProcessor.from_pretrained("jmtzt/ijepa_vith14_1k") if is_vision_available() else None
+
+ @slow
+ def test_inference_no_head(self):
+ model = IJepaModel.from_pretrained("jmtzt/ijepa_vith14_1k").to(torch_device)
+
+ image_processor = self.default_image_processor
+ image = prepare_img()
+ inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ # verify the last hidden state
+ expected_shape = torch.Size((1, 256, 1280))
+ self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
+
+ expected_slice = torch.Tensor(
+ [[-0.0621, -0.0054, -2.7513], [-0.1952, 0.0909, -3.9536], [0.0942, -0.0331, -1.2833]]
+ ).to(torch_device)
+
+ self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
+
+ @slow
+ @require_accelerate
+ @require_torch_accelerator
+ @require_torch_fp16
+ def test_inference_fp16(self):
+ r"""
+ A small test to make sure that inference work in half precision without any problem.
+ """
+ model = IJepaModel.from_pretrained(
+ "jmtzt/ijepa_vith14_1k",
+ torch_dtype=torch.float16,
+ device_map="auto",
+ )
+ image_processor = self.default_image_processor
+
+ image = prepare_img()
+ inputs = image_processor(images=image, return_tensors="pt")
+ pixel_values = inputs.pixel_values.to(torch_device)
+
+ # forward pass to make sure inference works in fp16
+ with torch.no_grad():
+ _ = model(pixel_values)
+
+ @slow
+ def test_inference_interpolate_pos_encoding(self):
+ # I-JEPA, similar to ViT models have an `interpolate_pos_encoding` argument in their forward method,
+ # allowing to interpolate the pre-trained position embeddings in order to use
+ # the model on higher resolutions. The DINO model by Facebook AI leverages this
+ # to visualize self-attention on higher resolution images.
+ model = IJepaModel.from_pretrained("jmtzt/ijepa_vith14_1k").to(torch_device)
+
+ image_processor = self.default_image_processor
+ image = prepare_img()
+ inputs = image_processor(images=image, return_tensors="pt")
+ pixel_values = inputs.pixel_values.to(torch_device)
+
+ # forward pass
+ with torch.no_grad():
+ outputs = model(pixel_values, interpolate_pos_encoding=True)
+
+ # verify the logits
+ expected_shape = torch.Size((1, 256, 1280))
+ self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
+
+ expected_slice = torch.tensor(
+ [[-0.0621, -0.0054, -2.7513], [-0.1952, 0.0909, -3.9536], [0.0942, -0.0331, -1.2833]]
+ ).to(torch_device)
+
+ self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py
index 0be960f4a33e6d..a2ea05edce8063 100644
--- a/utils/check_docstrings.py
+++ b/utils/check_docstrings.py
@@ -331,6 +331,7 @@
"IBertModel",
"IdeficsConfig",
"IdeficsProcessor",
+ "IJepaModel",
"ImageClassificationPipeline",
"ImageFeatureExtractionPipeline",
"ImageGPTConfig",