
Add DAB-DETR Object detection/segmentation model #30803

Open
wants to merge 94 commits into base: main

Conversation

conditionedstimulus (Contributor)

What does this PR do?

Adds the DAB-DETR object detection model. Paper: https://arxiv.org/abs/2201.12329
Original code repo: https://github.com/IDEA-Research/DAB-DETR

[WIP] This model is part of the evolution of the DETR family, alongside DN-DETR (not part of this PR), paving the way for newer and stronger object detection models such as DINO and Stable-DINO.
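
For context, once merged the model should be usable through the standard object detection API in Transformers. A hedged sketch of the intended usage (the checkpoint id below is an assumption for illustration, not one published by this PR):

```python
# Hedged usage sketch; the checkpoint id is assumed, not part of this PR.
import torch
import requests
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection

checkpoint = "IDEA-Research/dab-detr-resnet-50"  # assumed checkpoint id
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = AutoModelForObjectDetection.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# convert raw logits/boxes into COCO-style detections
results = processor.post_process_object_detection(
    outputs, target_sizes=torch.tensor([image.size[::-1]]), threshold=0.5
)[0]
```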

Who can review?

@amyeroberts

@amyeroberts (Collaborator)

Hi @conditionedstimulus, thanks for opening a PR!

Just skimming over the modeling files, it looks like all of the modules are copied from, or can be copied from, Conditional DETR. Are there any architectural changes this model brings? If not, then all we need to do is convert the checkpoints and upload them to the Hub so that they can be loaded in Conditional DETR directly.

@conditionedstimulus (Contributor, Author)

> Hi @conditionedstimulus, thanks for opening a PR!
>
> Just skimming over the modeling files, it looks like all of the modules are copied from, or can be copied from, Conditional DETR. Are there any architectural changes this model brings? If not, then all we need to do is convert the checkpoints and upload them to the Hub so that they can be loaded in Conditional DETR directly.

Hi Amy,

I attached a screenshot comparing the decoder cross-attention in DETR, Conditional DETR, and DAB-DETR, as this is the main architectural difference; a rough sketch of the idea follows the screenshot. I copied the code from Conditional DETR because this model is an extension/evolved version of Conditional DETR. I believe it would be cool and useful to include this model in the HF object detection collection.
[Screenshot: comparison of decoder cross-attention in DETR, Conditional DETR, and DAB-DETR]
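
For readers without the image, a hedged, simplified sketch of the difference (illustrative names, not the PR's exact code): DETR's decoder queries are opaque learned embeddings, Conditional DETR splits a spatial query from the content query, and DAB-DETR makes each query an explicit 4D anchor box (x, y, w, h) whose sinusoidal embedding forms the positional part of the cross-attention query, with width/height modulating the attention and the boxes refined layer by layer.

```python
import torch
import torch.nn as nn

num_queries, hidden_size = 300, 256
anchor_boxes = nn.Embedding(num_queries, 4)  # learned (x, y, w, h) anchors

def sine_embed(coords, dim, temperature=10000):
    # sinusoidal embedding of each scalar coordinate (float32 for reproducibility)
    dim_t = torch.arange(dim, dtype=torch.float32)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)
    pos = coords[..., None] / dim_t
    return torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=-1).flatten(-2)

boxes = anchor_boxes.weight.sigmoid()                              # (300, 4), normalized to [0, 1]
pos_query = sine_embed(boxes[:, :2], hidden_size // 2).flatten(1)  # (300, 256) positional query from (x, y)
width_height = boxes[:, 2:]                                        # (w, h), used to modulate the positional
                                                                   # cross-attention and refine boxes per layer
print(pos_query.shape)  # torch.Size([300, 256])
```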

@amyeroberts (Collaborator)

@conditionedstimulus Thanks for sharing! OK, seems useful to have this available as an option as part of the DETR family in the library. Feel free to ping me when the PR is ready for review.

cc @qubvel for reference

@ArthurZucker (Collaborator)

> Are you suggesting that I remove any code paths, functions, or configuration variables not related to the pre-trained version? It seems that most users will likely utilize the pre-trained model rather than training one from scratch.

Exactly. If someone wants to add complexity, they can copy-paste the model or monkey-patch it; we want to reflect the architecture of the PreTrainedModel as much as possible.

@ArthurZucker (Collaborator)

For the pretrained checkpoints, if they exist for a certain path we kind of have no choice but to keep it that way! Though what matters is the intention: we try to remove them, and then we try to only have things that change the init and not the forward (e.g. create 2 classes instead of 1 with if/else), etc. 🤗

@ArthurZucker (Collaborator) left a comment

IMO only thing left to do:

  • refactor the attention to make it as close as possible to Llama/Gemma
  • small renaming -> d_model becomes hidden_size, etc. Have a look at modeling_llama for the correct standards!
  • move the losses to loss_utils or loss_dab_detr.py

Awesome refactoring otherwise! 🔥


model_type = "dab-detr"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
Collaborator

I mean can we remove the properties? 🤗

Collaborator

❤️ a lot better thanks!
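
For readers outside the thread: attribute_map in a Transformers config redirects attribute access, letting a model keep legacy names internally (like d_model) while exposing the standard ones (like hidden_size), which is the renaming point raised in the review above. A hedged illustration of the mechanism, not the actual contents of the DAB-DETR config:

```python
from transformers import PretrainedConfig

class ToyDetrLikeConfig(PretrainedConfig):
    model_type = "toy-detr"  # illustrative name, not the PR's
    # accesses to `hidden_size` are transparently redirected to `d_model`
    attribute_map = {"hidden_size": "d_model", "num_attention_heads": "encoder_attention_heads"}

    def __init__(self, d_model=256, encoder_attention_heads=8, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.encoder_attention_heads = encoder_attention_heads

config = ToyDetrLikeConfig()
print(config.hidden_size)  # 256, resolved through attribute_map to d_model
```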

Comment on lines 36 to 37
r"input_proj.weight": r"input_projection.weight",
r"input_proj.bias": r"input_projection.bias",
Collaborator

Suggested change
r"input_proj.weight": r"input_projection.weight",
r"input_proj.bias": r"input_projection.bias",
r"input_proj.(bias|weight)": r"input_projection.\1",

Comment on lines 39 to 40
r"class_embed.weight": r"class_embed.weight",
r"class_embed.bias": r"class_embed.bias",
Collaborator

Suggested change
r"class_embed.weight": r"class_embed.weight",
r"class_embed.bias": r"class_embed.bias",
r"class_embed.(bias|weight)": r"class_embed.\1",

Comment on lines 44 to 56
r"transformer.encoder.query_scale.layers.(\d+).weight": r"encoder.query_scale.layers.\1.weight",
r"transformer.encoder.query_scale.layers.(\d+).bias": r"encoder.query_scale.layers.\1.bias",
r"transformer.decoder.bbox_embed.layers.(\d+).weight": r"decoder.bbox_embed.layers.\1.weight",
r"transformer.decoder.bbox_embed.layers.(\d+).bias": r"decoder.bbox_embed.layers.\1.bias",
r"transformer.decoder.norm.weight": r"decoder.layernorm.weight",
r"transformer.decoder.norm.bias": r"decoder.layernorm.bias",
r"transformer.decoder.ref_point_head.layers.(\d+).weight": r"decoder.ref_point_head.layers.\1.weight",
r"transformer.decoder.ref_point_head.layers.(\d+).bias": r"decoder.ref_point_head.layers.\1.bias",
r"transformer.decoder.ref_anchor_head.layers.(\d+).weight": r"decoder.ref_anchor_head.layers.\1.weight",
r"transformer.decoder.ref_anchor_head.layers.(\d+).bias": r"decoder.ref_anchor_head.layers.\1.bias",
r"transformer.decoder.query_scale.layers.(\d+).weight": r"decoder.query_scale.layers.\1.weight",
r"transformer.decoder.query_scale.layers.(\d+).bias": r"decoder.query_scale.layers.\1.bias",
r"transformer.decoder.layers.0.ca_qpos_proj.weight": r"decoder.layers.0.layer.1.cross_attn_query_pos_proj.weight",
Collaborator

Same comment about weight and bias for the rest!
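
A minimal sketch of why the merged patterns work, assuming a plain dict of pattern -> replacement (this is not the PR's actual conversion script): one regex with a (bias|weight) group and a \1 backreference renames both keys at once.

```python
import re

# hypothetical subset of the mapping, with literal dots escaped
ORIGINAL_TO_CONVERTED = {
    r"input_proj\.(bias|weight)": r"input_projection.\1",
    r"class_embed\.(bias|weight)": r"class_embed.\1",
}

def rename_key(old_key: str) -> str:
    for pattern, replacement in ORIGINAL_TO_CONVERTED.items():
        new_key, n_subs = re.subn(pattern, replacement, old_key)
        if n_subs:
            return new_key
    return old_key

print(rename_key("input_proj.weight"))  # input_projection.weight
print(rename_key("input_proj.bias"))    # input_projection.bias
```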

Comment on lines 127 to 128

def convert_old_keys_to_new_keys(state_dict_keys: dict = None):
Collaborator

Missing `# Copied from` comment!

Contributor Author

Updated

Collaborator

If everything is copied from (tell me if I am wrong!) then we can just directly use image_processing_detr instead!

Contributor Author

You're right! I removed DabDetrImageProcessor and used DetrImageProcessor as you suggested, and it works perfectly.
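
Since the preprocessing is the same as DETR's, reusing the existing processor is enough; a small sketch of what that looks like in practice (defaults shown are the usual DETR ones):

```python
from transformers import DetrImageProcessor

# DAB-DETR keeps DETR-style preprocessing, so the existing processor is used
# directly instead of a duplicated DabDetrImageProcessor.
processor = DetrImageProcessor()  # default COCO-style resizing and normalization
print(processor.size)             # typically {'shortest_edge': 800, 'longest_edge': 1333}
```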

src/transformers/models/dab_detr/modeling_dab_detr.py (outdated; resolved)
pos_x = x_embed[:, :, :, None] / dim_tx

# We use float32 to ensure reproducibility of the original implementation
dim_ty = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
Collaborator

no worries!

Comment on lines +955 to +956
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
Collaborator

We should probably create this in the config; we would only have to pass the list of tuples!

Contributor Author

I attempted the change, but the config test failed with the error: "Object of type ModuleList is not JSON serializable." So I left it as it was. If you still want the change, should I go ahead and modify the config tests?
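
For reference, the constraint hit here is that configs are serialized to JSON, so they can only carry plain Python values. A hedged sketch of the trade-off being discussed (illustrative names, not the PR's API): store the layer dimensions in the config and build the ModuleList inside the model.

```python
import json
import torch.nn as nn

input_dim, hidden_dim, output_dim, num_layers = 256, 256, 4, 3
h = [hidden_dim] * (num_layers - 1)
layer_dims = [(n, k) for n, k in zip([input_dim] + h, h + [output_dim])]  # [(256, 256), (256, 256), (256, 4)]

json.dumps({"mlp_layer_dims": layer_dims})   # fine: ints and lists are JSON serializable
# json.dumps({"layers": nn.ModuleList()})    # would raise: Object of type ModuleList is not JSON serializable

# the modules themselves are built in the model from the serializable dims
layers = nn.ModuleList(nn.Linear(n, k) for n, k in layer_dims)
```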

@ArthurZucker (Collaborator)

Feel free to ping me when this is ready for another review.

@conditionedstimulus (Contributor, Author)

> Feel free to ping me when this is ready for another review.

Hi, thanks! Sorry, I got caught up with other errands and couldn’t work on the code for the past two weeks. I’m back at it now and will ping you once I'm done!

@conditionedstimulus (Contributor, Author)

conditionedstimulus commented Nov 3, 2024

> IMO only thing left to do:
>
>   • refactor the attention to make it as close as possible to Llama/Gemma
>   • small renaming -> d_model becomes hidden_size, etc. Have a look at modeling_llama for the correct standards!
>   • move the losses to loss_utils or loss_dab_detr.py
>
> Awesome refactoring otherwise! 🔥

Hi @ArthurZucker,

I could use your review and help here.

One test is failing, but it's unrelated to my model (a TFResNet assertion error).

  • I refactored the attention mechanism, and it looks solid; I think it's much closer to mllama now.
  • I also did some renaming; do you suggest any further changes?
  • Additionally, I moved the loss function as you recommended.

The main issue is that the model isn't performing well. It barely learns, validation loss hardly decreases, metrics improve very little, and the final results are poor.
I suspect the problem might lie in the loss function or the attention, with the loss function being my main guess.
While I implemented it following the approach in Conditional DETR, I’m not entirely sure it’s correct.

Current fine-tuned model notebook
Previous model notebook when the results were good

Could you help identify where the issue might be?
Thank you!

@conditionedstimulus (Contributor, Author)

Loss function seems good too.

@conditionedstimulus (Contributor, Author)

conditionedstimulus commented Nov 6, 2024

Hi @ArthurZucker,

I tried to pinpoint why the model isn't learning but haven't identified the issue yet. I've tested various model versions.

Any idea? :)
Thanks for your help in finding the issue!

@ArthurZucker (Collaborator)

Will let @qubvel take the review! 🤗 In general it's super hard for us to jump in and debug training at this stage; we'll see if that helps!

@ArthurZucker removed their request for review on November 19, 2024 at 15:18.