
Commit

Formatting
aymeric-roucher committed Nov 7, 2024
1 parent 3650cdf commit 2363b99
Showing 7 changed files with 58 additions and 44 deletions.
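Editor's note: every change in this commit is a mechanical style fix of the kind produced by automated formatting tools (presumably the repository's `make style` target, which applies black/ruff-style rules): imports sorted, spaces normalized around `=`, trailing commas and blank lines added, and long literals exploded one element per line. The Aria import reordering in the first two files, for example, matches a plain case-sensitive alphabetical sort, as this runnable sketch shows:

# Illustrative sketch: the Aria import blocks below are reordered to match
# a case-sensitive alphabetical sort, as isort/ruff would produce.
names = ["AriaTextForCausalLM", "AriaForConditionalGeneration", "AriaPreTrainedModel", "AriaTextModel"]
print(sorted(names))
# ['AriaForConditionalGeneration', 'AriaPreTrainedModel', 'AriaTextForCausalLM', 'AriaTextModel']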
2 changes: 1 addition & 1 deletion src/transformers/__init__.py
@@ -6309,9 +6309,9 @@
        AltCLIPVisionModel,
    )
    from .models.aria import (
-        AriaTextForCausalLM,
        AriaForConditionalGeneration,
        AriaPreTrainedModel,
+        AriaTextForCausalLM,
        AriaTextModel,
    )
    from .models.audio_spectrogram_transformer import (
2 changes: 1 addition & 1 deletion src/transformers/models/aria/__init__.py
@@ -54,9 +54,9 @@
    pass
else:
    from .modeling_aria import (
-        AriaTextForCausalLM,
        AriaForConditionalGeneration,
        AriaPreTrainedModel,
+        AriaTextForCausalLM,
        AriaTextModel,
    )
    from .processing_aria import AriaProcessor
4 changes: 2 additions & 2 deletions src/transformers/models/aria/convert_aria_weights_to_hf.py
@@ -97,12 +97,12 @@ def convert_aria_llama_to_hf(text_model_id, vision_model_id, output_hub_path, ol

    config = AutoConfig.from_pretrained(text_model_id)
    config.vision_config.hidden_size = 1152
-    config.vision_config.attention_heads=16
+    config.vision_config.attention_heads = 16
    config.pad_token_id = 2
    config.image_token_index = 9
    config.auto_map = {
        "AutoConfig": "modeling_aria.AriaConfig",
-        "AutoModelForCausalLM": "modeling_aria.AriaForConditionalGeneration"
+        "AutoModelForCausalLM": "modeling_aria.AriaForConditionalGeneration",
    }

    config.pad_token_id = 32001
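Editor's note: the auto_map entries above are what let a converted checkpoint resolve its custom classes when loaded with remote code enabled. A hedged sketch of loading such a checkpoint (the repo path is hypothetical, not from this diff):

from transformers import AutoModelForCausalLM

# Hypothetical checkpoint path; auto_map tells AutoModelForCausalLM to load
# modeling_aria.AriaForConditionalGeneration from the checkpoint repository.
model = AutoModelForCausalLM.from_pretrained("your-org/aria-converted", trust_remote_code=True)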
4 changes: 2 additions & 2 deletions src/transformers/models/aria/image_processing_aria.py
@@ -26,11 +26,11 @@
from ...tokenization_utils import (
    TensorType,
)
-from ...utils.import_utils import is_vision_available, is_torch_available
+from ...utils.import_utils import is_torch_available, is_vision_available


if is_vision_available():
-    from PIL import Image, ImageOps
+    from PIL import Image

if is_torch_available():
    import torch
69 changes: 42 additions & 27 deletions src/transformers/models/aria/modular_aria.py
@@ -11,9 +11,9 @@
from ...image_processing_utils import BaseImageProcessor, select_best_resolution
from ...image_transforms import (
    convert_to_rgb,
+    pad,
    resize,
    to_channel_dimension_format,
-    pad,
)
from ...image_utils import (
    ChannelDimension,
@@ -23,18 +23,16 @@
    to_numpy_array,
)
from ...modeling_utils import PreTrainedModel
-from ...processing_utils import ProcessorMixin, ProcessingKwargs, Unpack
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils import (
-    PaddingStrategy,
    PreTokenizedInput,
    TensorType,
    TextInput,
-    TruncationStrategy,
)
from ...utils import (
    logging,
)
-from ...utils.import_utils import is_vision_available, is_torch_available
+from ...utils.import_utils import is_torch_available, is_vision_available
from ..auto import CONFIG_MAPPING, AutoModel, AutoModelForCausalLM, AutoTokenizer
from ..llama.configuration_llama import LlamaConfig
from ..llama.modeling_llama import (
@@ -89,6 +87,7 @@ def sequential_gemm(token_states, expert_weights, tokens_per_expert):
        output[start:end] = out
    return output

+
if os.environ.get("USE_GROUPED_GEMM", "1") == "0":
    logger.warning("environment variable USE_GROUPED_GEMM is set to 0, using sequential GEMM")
    experts_gemm = sequential_gemm
@@ -99,6 +98,7 @@ def sequential_gemm(token_states, expert_weights, tokens_per_expert):
else:
    from grouped_gemm.ops import gmm as experts_gemm

+
class AriaTextConfig(LlamaConfig):
    """
    Configuration class for Aria language model.
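Editor's note: the dispatch shown in the two hunks above chooses the experts' GEMM implementation once, at import time. A minimal sketch of forcing the pure-PyTorch fallback (useful when the grouped_gemm package is unavailable):

import os

# Must be set before the Aria modeling code is imported, since the
# USE_GROUPED_GEMM check above runs at module import time.
os.environ["USE_GROUPED_GEMM"] = "0"  # selects sequential_gemm instead of grouped_gemm

from transformers import AriaForConditionalGeneration  # noqa: E402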
@@ -321,7 +321,7 @@ def __init__(
        self.layer_norm = nn.LayerNorm(self.in_features)
        self.feed_forward = AriaProjectorMLP(self.in_features, self.hidden_features, self.output_dim)

-    def forward(self, key_value_states: torch.Tensor, attn_mask: Optional[torch.Tensor]=None):
+    def forward(self, key_value_states: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """
        Forward pass of the Projector module.
@@ -335,7 +335,9 @@ def forward(self, key_value_states: torch.Tensor, attn_mask: Optional[torch.Tens
        batch_size, num_patches = key_value_states.shape[0], key_value_states.shape[1]

        if num_patches not in self.patch_to_query_dict.keys():
-            raise KeyError(f"Number of patches {num_patches} not found in patch_to_query_dict amongst possible values {self.patch_to_query_dict.keys()}.")
+            raise KeyError(
+                f"Number of patches {num_patches} not found in patch_to_query_dict amongst possible values {self.patch_to_query_dict.keys()}."
+            )
        query_num = self.patch_to_query_dict[num_patches]

        queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)
@@ -350,6 +352,7 @@ def forward(self, key_value_states: torch.Tensor, attn_mask: Optional[torch.Tens

        return out

+
# Copied from models.llava_next.image_processing_llava_next.py
def divide_to_patches(image: np.array, patch_size: int, input_data_format) -> List[np.array]:
    """
@@ -454,12 +457,27 @@ def __init__(
        self.image_mean = image_mean
        self.image_std = image_std
        if split_ratio is None:
-            self.split_ratio = [
-                (1, 2), (1, 3), (1, 4), (1, 5), (1, 6),
-                (1, 7), (1, 8), (2, 4), (2, 3), (2, 2),
-                (2, 1), (3, 1), (3, 2), (4, 1), (4, 2),
-                (5, 1), (6, 1), (7, 1), (8, 1),
-            ]
+            self.split_ratio = [
+                (1, 2),
+                (1, 3),
+                (1, 4),
+                (1, 5),
+                (1, 6),
+                (1, 7),
+                (1, 8),
+                (2, 4),
+                (2, 3),
+                (2, 2),
+                (2, 1),
+                (3, 1),
+                (3, 2),
+                (4, 1),
+                (4, 2),
+                (5, 1),
+                (6, 1),
+                (7, 1),
+                (8, 1),
+            ]
        else:
            self.split_ratio = split_ratio

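Editor's note on the defaults above, hedged because it follows from the LLaVA-NeXT helpers this file borrows (select_best_resolution, divide_to_patches) rather than from this diff alone: each (rows, cols) ratio scales a base crop size into a candidate canvas, and the best-fitting candidate decides how an oversized image is split into crops. A small sketch under that assumption:

# Sketch: enumerate candidate canvases from the split ratios, using a
# hypothetical base size of 490 pixels per crop.
base = 490
split_ratio = [(1, 2), (2, 2), (2, 1)]
candidates = [(rows * base, cols * base) for rows, cols in split_ratio]
print(candidates)  # [(490, 980), (980, 980), (980, 490)]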
@@ -542,7 +560,7 @@ def preprocess(

            if do_normalize:
                crop_image_padded = self.normalize(crop_image_padded, self.image_mean, self.image_std)
-
+
            # Switch to rgb channel first
            crop_image_padded = np.transpose(crop_image_padded, (2, 0, 1))
            pixel_values.append(crop_image_padded)
@@ -606,20 +624,22 @@ def get_image_patches(
    ]
    return patches

+
class AriaProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
-            "padding": False,
-            "truncation": None,
-            "max_length": None,
+            "padding": False,
+            "truncation": None,
+            "max_length": None,
        },
        "images_kwargs": {
-            "max_image_size": 980,
-            "split_image": False,
+            "max_image_size": 980,
+            "split_image": False,
        },
-        "return_tensors": TensorType.PYTORCH,
+        "return_tensors": TensorType.PYTORCH,
    }

+
class AriaProcessor(ProcessorMixin):
    """
    AriaProcessor is a processor for the Aria model which wraps the Aria image preprocessor and the LLama slow tokenizer.
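Editor's note: since AriaProcessorKwargs only defines defaults, any of these keys can be overridden per call. A hedged usage sketch (the prompt string is illustrative, not Aria's actual chat format):

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("rhymes-ai/Aria")
image = Image.new("RGB", (980, 980))  # placeholder image
# Unspecified kwargs fall back to _defaults above: padding=False,
# split_image=False, return_tensors=TensorType.PYTORCH.
inputs = processor(text="Describe this image.", images=image, max_image_size=980)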
@@ -1151,9 +1171,7 @@ def __init__(self, config: AriaConfig):
        self.vision_tower = AutoModel.from_config(
            config.vision_config, attn_implementation=config._attn_implementation
        )
-        self.multi_modal_projector = AriaProjector(
-            config
-        )
+        self.multi_modal_projector = AriaProjector(config)
        self.vocab_size = config.text_config.vocab_size
        self.language_model = AutoModelForCausalLM.from_config(
            config.text_config, attn_implementation=config._attn_implementation
@@ -1326,10 +1344,7 @@ def forward(
        loss = None
        if labels is not None:
            loss = self.loss_function(
-                logits=logits,
-                labels=labels,
-                vocab_size=self.config.text_config.vocab_size,
-                **loss_kwargs
+                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **loss_kwargs
            )

        if not return_dict:
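Editor's note: for context on the collapsed loss_function call above, it presumably reduces to the standard shifted cross-entropy used for causal LMs. A sketch of that computation (not the actual transformers implementation):

import torch.nn.functional as F

def causal_lm_loss(logits, labels, vocab_size):
    # Shift so that tokens < n predict token n, then flatten for cross-entropy;
    # positions labeled -100 (e.g. image tokens) are ignored.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(shift_logits.view(-1, vocab_size), shift_labels.view(-1), ignore_index=-100)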
17 changes: 9 additions & 8 deletions src/transformers/models/aria/processing_aria.py
@@ -4,14 +4,13 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_aria.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-import inspect
from typing import Dict, List, Optional, Union

from ...feature_extraction_utils import BatchFeature
from ...image_utils import (
    ImageInput,
)
-from ...processing_utils import ProcessorMixin, ProcessingKwargs, Unpack
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils import (
    PreTokenizedInput,
    TensorType,
@@ -24,20 +23,22 @@

logger = logging.get_logger(__name__)

+
class AriaProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
-            "padding": False,
-            "truncation": None,
-            "max_length": None,
+            "padding": False,
+            "truncation": None,
+            "max_length": None,
        },
        "images_kwargs": {
-            "max_image_size": 980,
-            "split_image": False,
+            "max_image_size": 980,
+            "split_image": False,
        },
-        "return_tensors": TensorType.PYTORCH,
+        "return_tensors": TensorType.PYTORCH,
    }

+
class AriaProcessor(ProcessorMixin):
    """
    AriaProcessor is a processor for the Aria model which wraps the Aria image preprocessor and the LLama slow tokenizer.
4 changes: 1 addition & 3 deletions tests/models/aria/test_modeling_aria.py
@@ -403,9 +403,7 @@ def test_small_model_integration_test_llama_batched_regression(self):
        model_id = "rhymes-ai/Aria"

        # Multi-image & multi-prompt (e.g. 3 images and 2 prompts now fails with SDPA, this tests if "eager" works as before)
-        model = AriaForConditionalGeneration.from_pretrained(
-            model_id, load_in_4bit=True, attn_implementation="eager"
-        )
+        model = AriaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True, attn_implementation="eager")
        processor = AutoProcessor.from_pretrained(model_id, pad_token="<pad>")

        prompts = [
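Editor's note: one caveat on the test above is that load_in_4bit=True requires the bitsandbytes package and a CUDA device. An equivalent, more explicit spelling (a sketch, not taken from this diff):

from transformers import AriaForConditionalGeneration, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)  # needs bitsandbytes
model = AriaForConditionalGeneration.from_pretrained(
    "rhymes-ai/Aria",
    quantization_config=quantization_config,
    attn_implementation="eager",  # SDPA reportedly fails for the multi-image batch
)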
