Add Molmo (7B-D, 7B-O, 70B) #33962

Open · molbap wants to merge 145 commits into main

Conversation

@molbap (Contributor) commented Oct 4, 2024

What does this PR do?

As mentioned in issue #33710, this is a draft to add support for Molmo natively in transformers.
It also uses the new modular framework introduced in #33248.

Molmo has several existing variants:

  • MolmoE, a mixture-of-experts multimodal model, which is not covered in this PR but will be added in a follow-up one.
  • Molmo-7B-D, based on Qwen2 + CLIP.
  • Molmo-7B-O, based on a yet-to-be-released Olmo model, plus CLIP.
  • Molmo-70B, a scaled-up version.

The last three models share the same modeling code and will thus be covered by this PR.

Regarding the modular framework:

Choose a base model that's as close as possible to the one you're porting.

In my case, I'm using Llava as a reference. The differences I identify at a glance include the 2D pooling.

Figure out the differences.

Some differences amount to a complete modification of the original module; in that case, the whole module has to be redefined.

class MolmoMultiModalProjector(LlavaMultiModalProjector):
    def __init__(self, config: MolmoConfig):
        super().__init__()
        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size,
            config.text_config.intermediate_size // 2,
            bias=False,
        )
        self.linear_2 = nn.Linear(
            config.text_config.intermediate_size // 2,
            config.text_config.hidden_size,
            bias=False,
        )
        self.linear_3 = nn.Linear(
            config.vision_config.hidden_size,
            config.text_config.intermediate_size // 2,
            bias=False,
        )

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        intermediate_states = self.linear_3(image_features)
        # gate the two projections (SwiGLU-style) before the output projection
        hidden_states = self.linear_2(hidden_states * intermediate_states)
        return hidden_states

Some differences are very small: some layers might be the same, but initialized with a different configuration key.
For instance, the position embeddings differ slightly.

class MolmoVisionEmbeddings(CLIPVisionEmbeddings):
    def __init__(self, config):
        super().__init__()
        self.position_embedding = nn.Embedding(config.num_image_positions, config.hidden_size)

Preserving inheritance across model component renames.

For instance, the code above will trigger

python utils/modular_model_converter.py --files_to_parse src/transformers/models/molmo/modular_molmo.py  --old_model_name="Llava" --new_model_name="Molmo"

> ValueError: Unable to find dependencies for CLIPVisionEmbeddings in transformers.models.clip.modeling_clip. Here are the dependencies found: {'molmo_loss': {'contrastive_loss'}, 'MOLMOVisionModelOutput': {'ModelOutput'}, 'MOLMOTextModelOutput': {'ModelOutput'}, 'MOLMOOutput': {'ModelOutput'}, 'MOLMOVisionEmbeddings': {'nn.Module'},

This happens because the supported pattern currently searches for an all-caps model name. Still, using modular is very promising and makes for a much smaller modeling file to review.

I'll write down hurdles encountered here for future reference so that adding multimodal models to transformers ends up being a breeze.

@ArthurZucker (Collaborator) left a comment

Wow looks super nice! Will finish #33859 asap to let you continue!

@molbap (Contributor, Author) commented Oct 8, 2024

Still seeing some duplicate imports in the modeling code:

from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from .configuration_molmo import MolmoConfig


if is_flash_attn_2_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward


from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
from ...utils import (
    ModelOutput,
    is_flash_attn_2_available,
    torch_int,
)
from .configuration_molmo import MOLMOConfig, MOLMOVisionConfig

One quick-and-dirty solution would be to do a pass on the imports once the modular transform has finished, so that imports from the various modules get merged and normalized to the most likely form. Strangely, some wrongly capitalized model names also remain, like MOLMOEncoder where we should get MolmoEncoder (a rough sketch of such a cleanup pass follows the snippet below):

class MolmoVisionTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = MolmoVisionEmbeddings(config)
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = MOLMOEncoder(config)  #  wut 
        self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=True)
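Here is a rough sketch of what such a cleanup pass could look like. It is only an illustration of the idea, not code from the converter: the helper name, the line-level import deduplication, and the hard-coded rename are all assumptions (real import merging would need to parse the multi-line import blocks, e.g. with ast).

import re
from pathlib import Path


def cleanup_generated_modeling(path: str, wrong_name: str = "MOLMO", right_name: str = "Molmo") -> None:
    """Hypothetical post-processing pass: drop exact duplicate import lines and fix mis-capitalized class names."""
    lines = Path(path).read_text().splitlines()
    seen_imports = set()
    cleaned = []
    for line in lines:
        if line.startswith(("from ", "import ")):
            if line in seen_imports:
                continue  # skip exact duplicate import lines
            seen_imports.add(line)
        # MOLMOEncoder -> MolmoEncoder, MOLMOVisionConfig -> MolmoVisionConfig, ...
        cleaned.append(re.sub(rf"{wrong_name}(?=[A-Z][a-z])", right_name, line))
    Path(path).write_text("\n".join(cleaned) + "\n")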

getting there however!

@ArthurZucker (Collaborator)
Do you need a review? 🤗

@d-rau commented Oct 15, 2024

Maybe a bit premature, but when using the script to convert the model to HF I got mismatch issues here:

q_proj, k_proj, v_proj = torch.split(fused_qkv, fused_dims, 0)
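For context, here is a minimal sketch of what that split expects; a mismatch at this line usually means the checkpoint's fused dimension does not equal sum(fused_dims). The head counts and dimensions below are illustrative, not the actual Molmo values.

import torch

# illustrative shapes only, not the real checkpoint's
hidden_size, num_heads, num_key_value_heads, head_dim = 1024, 16, 4, 64
fused_dims = (num_heads * head_dim, num_key_value_heads * head_dim, num_key_value_heads * head_dim)

# fused QKV projection weight: rows are the concatenated q/k/v output dimensions
fused_qkv = torch.randn(sum(fused_dims), hidden_size)

# split along dim 0 into separate projection weights
q_proj, k_proj, v_proj = torch.split(fused_qkv, fused_dims, 0)
assert q_proj.shape[0] == num_heads * head_dim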

@moha23 commented Jan 7, 2025

Hi @molbap, I am using the vision_backbone from this model. The processor creates a varying number of crops for different images, so the processed 'images' tensor has shape (batch_size, num_crops, num_patch, n_pixels) with varying num_crops, even for images of the same resolution (after resizing). As a result, anything with batch_size > 1 does not work. Is there a way to fix the number of crops? Passing the self.max_crops value doesn't seem to help.

@molbap (Contributor, Author) commented Jan 7, 2025

@moha23, that's surprising - I resumed working on this after the new year's break, but what you describe should not happen: we pad with _pad_for_batching, and the filtering in processing should then get rid of it. Are you sure you are using this branch's version?
Either way, I'll get this merged as soon as possible. If you still see this problem, please do flag it, but I'm surprised, as I'm testing with batches of variable sizes and it works well - well, it did work well; now I am figuring out why #35235 broke generation for this specific model 🤔
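For readers not following the branch, here is a minimal sketch of the idea behind padding for batching. This is a hypothetical helper, not the branch's actual _pad_for_batching:

import torch


def pad_crops_for_batching(images: list[torch.Tensor], pad_value: float = 0.0) -> torch.Tensor:
    """Pad a list of (num_crops, num_patches, n_pixels) tensors to a common num_crops and stack them."""
    max_crops = max(image.shape[0] for image in images)
    padded = []
    for image in images:
        missing = max_crops - image.shape[0]
        if missing:
            padding = image.new_full((missing, *image.shape[1:]), pad_value)
            image = torch.cat([image, padding], dim=0)
        padded.append(image)
    return torch.stack(padded)  # (batch_size, max_crops, num_patches, n_pixels)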

@moha23 commented Jan 8, 2025


Have not used this branch, but the latest stable version. Thanks! Will check what changes need to be made.

@molbap (Contributor, Author) commented Jan 8, 2025

The Cohere interface had changed in the attention refactor; I included it by cherry-picking from #35359 and it now seems to run fine :) @yonigozlan, can you take a look at the fast image processor? It's not in line with #35069, but it's torch/torchvision and indeed faster. @ArthurZucker, the model should be ready now - pinging so it goes back up on your radar!

@yonigozlan (Member) left a comment

Looks good to me for the fast image processors! I just left two questions below.
And no need to be in line with #35069 for now, as it is not set in stone. I can go over this processor again once #35069 is merged and make the necessary modifications.

@stevhliu (Member) left a comment

Thanks :)

Comment on lines +21 to +22
The Molmo model was proposed in [Molmo and PixMo: Open Weights and Open Data for State-of-the-Art Multimodal Models
]([https://arxiv.org/abs/2409.17146]) by Matt Deitke, Christopher Clark, Sangho Lee, Rohun Tripathi, Yue Yang, Jae Sung Park, Mohammadreza Salehi, Niklas Muennighoff, Kyle Lo, Luca Soldaini, Jiasen Lu, Taira Anderson, Erin Bransom, Kiana Ehsani, Huong Ngo, YenSung Chen, Ajay Patel, Mark Yatskar, Chris Callison-Burch, Andrew Head, Rose Hendrix, Favyen Bastani, Eli VanderBilt, Nathan Lambert, Yvonne Chou, Arnavi Chheda, Jenna Sparks, Sam Skjonsberg, Michael Schmitz, Aaron Sarnat, Byron Bischoff, Pete Walsh, Chris Newell, Piper Wolters, Tanmay Gupta, Kuo-Hao Zeng, Jon Borchardt, Dirk Groeneveld, Jen Dumas, Crystal Nam, Sophie Lebrecht, Caitlin Wittlif, Carissa Schoenick, Oscar Michel, Ranjay Krishna, Luca Weihs, Noah A. Smith, Hannaneh Hajishirzi, Ross Girshick, Ali Farhadi, Aniruddha Kembhavi.

Suggested change
The Molmo model was proposed in [Molmo and PixMo: Open Weights and Open Data for State-of-the-Art Multimodal Models
]([https://arxiv.org/abs/2409.17146]) by Matt Deitke, Christopher Clark, Sangho Lee, Rohun Tripathi, Yue Yang, Jae Sung Park, Mohammadreza Salehi, Niklas Muennighoff, Kyle Lo, Luca Soldaini, Jiasen Lu, Taira Anderson, Erin Bransom, Kiana Ehsani, Huong Ngo, YenSung Chen, Ajay Patel, Mark Yatskar, Chris Callison-Burch, Andrew Head, Rose Hendrix, Favyen Bastani, Eli VanderBilt, Nathan Lambert, Yvonne Chou, Arnavi Chheda, Jenna Sparks, Sam Skjonsberg, Michael Schmitz, Aaron Sarnat, Byron Bischoff, Pete Walsh, Chris Newell, Piper Wolters, Tanmay Gupta, Kuo-Hao Zeng, Jon Borchardt, Dirk Groeneveld, Jen Dumas, Crystal Nam, Sophie Lebrecht, Caitlin Wittlif, Carissa Schoenick, Oscar Michel, Ranjay Krishna, Luca Weihs, Noah A. Smith, Hannaneh Hajishirzi, Ross Girshick, Ali Farhadi, Aniruddha Kembhavi.
The Molmo model was proposed in [Molmo and PixMo: Open Weights and Open Data for State-of-the-Art Multimodal Models]([https://arxiv.org/abs/2409.17146]) by Matt Deitke, Christopher Clark, Sangho Lee, Rohun Tripathi, Yue Yang, Jae Sung Park, Mohammadreza Salehi, Niklas Muennighoff, Kyle Lo, Luca Soldaini, Jiasen Lu, Taira Anderson, Erin Bransom, Kiana Ehsani, Huong Ngo, YenSung Chen, Ajay Patel, Mark Yatskar, Chris Callison-Burch, Andrew Head, Rose Hendrix, Favyen Bastani, Eli VanderBilt, Nathan Lambert, Yvonne Chou, Arnavi Chheda, Jenna Sparks, Sam Skjonsberg, Michael Schmitz, Aaron Sarnat, Byron Bischoff, Pete Walsh, Chris Newell, Piper Wolters, Tanmay Gupta, Kuo-Hao Zeng, Jon Borchardt, Dirk Groeneveld, Jen Dumas, Crystal Nam, Sophie Lebrecht, Caitlin Wittlif, Carissa Schoenick, Oscar Michel, Ranjay Krishna, Luca Weihs, Noah A. Smith, Hannaneh Hajishirzi, Ross Girshick, Ali Farhadi, Aniruddha Kembhavi.

Comment on lines +28 to +29
*Today's most advanced multimodal models remain proprietary. The strongest open-weight models rely heavily on synthetic data from proprietary VLMs to achieve good performance, effectively distilling these closed models into open ones. As a result, the community is still missing foundational knowledge about how to build performant VLMs from scratch. We present Molmo, a new family of VLMs that are state-of-the-art in their class of openness. Our key innovation is a novel, highly detailed image caption dataset collected entirely from human annotators using speech-based descriptions. To enable a wide array of user interactions, we also introduce a diverse dataset mixture for fine-tuning that includes in-the-wild Q&A and innovative 2D pointing data. The success of our approach relies on careful choices for the model architecture details, a well-tuned training pipeline, and, most critically, the quality of our newly collected datasets, all of which will be released. The best-in-class 72B model within the Molmo family not only outperforms others in the class of open weight and data models but also compares favorably against proprietary systems like GPT-4o, Claude 3.5, and Gemini 1.5 on both academic benchmarks and human evaluation.
*

Suggested change
*Today's most advanced multimodal models remain proprietary. The strongest open-weight models rely heavily on synthetic data from proprietary VLMs to achieve good performance, effectively distilling these closed models into open ones. As a result, the community is still missing foundational knowledge about how to build performant VLMs from scratch. We present Molmo, a new family of VLMs that are state-of-the-art in their class of openness. Our key innovation is a novel, highly detailed image caption dataset collected entirely from human annotators using speech-based descriptions. To enable a wide array of user interactions, we also introduce a diverse dataset mixture for fine-tuning that includes in-the-wild Q&A and innovative 2D pointing data. The success of our approach relies on careful choices for the model architecture details, a well-tuned training pipeline, and, most critically, the quality of our newly collected datasets, all of which will be released. The best-in-class 72B model within the Molmo family not only outperforms others in the class of open weight and data models but also compares favorably against proprietary systems like GPT-4o, Claude 3.5, and Gemini 1.5 on both academic benchmarks and human evaluation.
*
*Today's most advanced multimodal models remain proprietary. The strongest open-weight models rely heavily on synthetic data from proprietary VLMs to achieve good performance, effectively distilling these closed models into open ones. As a result, the community is still missing foundational knowledge about how to build performant VLMs from scratch. We present Molmo, a new family of VLMs that are state-of-the-art in their class of openness. Our key innovation is a novel, highly detailed image caption dataset collected entirely from human annotators using speech-based descriptions. To enable a wide array of user interactions, we also introduce a diverse dataset mixture for fine-tuning that includes in-the-wild Q&A and innovative 2D pointing data. The success of our approach relies on careful choices for the model architecture details, a well-tuned training pipeline, and, most critically, the quality of our newly collected datasets, all of which will be released. The best-in-class 72B model within the Molmo family not only outperforms others in the class of open weight and data models but also compares favorably against proprietary systems like GPT-4o, Claude 3.5, and Gemini 1.5 on both academic benchmarks and human evaluation.*


Tips:

- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating.

Suggested change
- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating.
- We recommend calling `processor.tokenizer.padding_side = "left"` for batched generation because it leads to more accurate results.
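For reference, a minimal usage sketch (the checkpoint id is just an example):

from transformers import AutoProcessor

# example checkpoint id; resolving the Molmo processor requires this PR's branch
processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-D-0924")
processor.tokenizer.padding_side = "left"  # left padding gives more accurate batched generation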

@ArthurZucker (Collaborator) left a comment

Super cool! A few comments left and good to go!

@@ -334,6 +335,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
    elif type(config) in PROCESSOR_MAPPING:
        return PROCESSOR_MAPPING[type(config)].from_pretrained(pretrained_model_name_or_path, **kwargs)

    print("BUT WHY", processor_class)
Collaborator:

😉 to remove!

Contributor (Author):

lol, some scars left from the debugging struggles



# swiglu activation
class MolmoSwiGLU(nn.Module):
Collaborator:

we can also just put this in the text MLP 😉
and let's not use single letter variables
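For illustration, folding SwiGLU directly into the text MLP could look roughly like the Llama-style MLP below; the class and attribute names are assumptions, not the PR's final code.

import torch
from torch import nn


class MolmoTextMLP(nn.Module):
    """Sketch: SwiGLU expressed inside the MLP instead of a separate activation module."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # SwiGLU: silu(gate) * up, then project back to the hidden size
        return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))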

Comment on lines +610 to +630
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

# Self Attention
hidden_states, self_attn_weights = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,
    output_attentions=output_attentions,
    use_cache=use_cache,
    cache_position=cache_position,
    position_embeddings=position_embeddings,
)
hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
Collaborator:

Looks standard! What's the diff?

return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def pooling_eager_attention_forward(
Collaborator:

This one is exactly the same as Bamba, Gemma, etc.; might need to import xxxx as pooling_eager_attention_forward!

Contributor (Author):

Ah right, I moved the q/kv states to the forward, so yes, we can uniformize this!
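For example, the alias could be created at import time; the exact source module is an assumption here and would need checking:

# assuming the shared eager attention helper lives in the Gemma modeling file
from transformers.models.gemma.modeling_gemma import (
    eager_attention_forward as pooling_eager_attention_forward,
)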

Comment on lines +900 to +906
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.attention_dropout = config.attention_dropout
self.scaling = self.head_dim**0.5
self.is_causal = True

Collaborator:

don't think we need all this in the forward!

Comment on lines +856 to +865
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
Collaborator:

should be imported!
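For instance (the source module is an assumption; any modeling file that defines repeat_kv would do):

from transformers.models.llama.modeling_llama import repeat_kv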

Comment on lines +1080 to +1081
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
    pass
Collaborator:

I don't remember, why is this explicit if it is not changed?

Contributor (Author):

ah, that's weird indeed 🤔 maybe I wanted to override it so it does nothing? will remove

logger = logging.get_logger(__name__)


def get_resize_output_image_size(
Collaborator:

copied from? (since this file is not modular generated!)
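For reference, the usual convention here would be a marker such as `# Copied from transformers.models.<source_model>.image_processing_<source_model>.get_resize_output_image_size` directly above the function; the source path shown is a placeholder, not a verified origin.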
