Add DAB-DETR Object detection/segmentation model #30803
base: main
Conversation
Hi @conditionedstimulus, thanks for opening a PR! Just skimming over the modeling files, it looks like all of the modules are copied from, or can be copied from, Conditional DETR. Are there any architectural changes this model brings? If not, then all we need to do is convert the checkpoints and upload those to the hub so that they can be loaded in ConditionalDETR directly.
Hi Amy, I attached a photo comparing the cross-attention of the decoder in DETR, Conditional DETR, and DAB-DETR, as this is the main architectural difference. I copied the code from Conditional DETR because this model is an extension/evolved version of Conditional DETR. I believe it would be cool and useful to include this model in the HF object detection collection.
@conditionedstimulus Thanks for sharing! OK, seems useful to have this available as an option as part of the DETR family in the library. Feel free to ping me when the PR is ready for review. cc @qubvel for reference
Exactly. If someone wants to add complexity, they can copy-paste the model or monkey-patch it; we want to reflect the architecture in the PreTrainedModel as much as possible.
For the pretrained checkpoints, if they exist for a certain path we kinda have no choice but to keep it that way! Though what matters is the intention: we try to remove them, then we try to only have things that change the init and not the forward (e.g. create 2 classes instead of 1 with if/else), etc. 🤗
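To illustrate the "two classes instead of an if/else in forward" guideline, here is a minimal, hypothetical sketch (the class and attribute names are made up for illustration, they are not from the PR):

```python
import torch.nn as nn

# Discouraged: one module whose forward branches on a config flag.
class AttentionWithFlag(nn.Module):
    def __init__(self, hidden_size, use_extra_proj):
        super().__init__()
        self.use_extra_proj = use_extra_proj
        self.in_proj = nn.Linear(hidden_size, hidden_size) if use_extra_proj else None
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        if self.use_extra_proj:  # branch inside forward
            x = self.in_proj(x)
        return self.proj(x)

# Preferred: two classes that differ only in __init__; forward stays branch-free.
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.in_proj = nn.Identity()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.proj(self.in_proj(x))

class AttentionWithExtraProj(Attention):
    def __init__(self, hidden_size):
        super().__init__(hidden_size)
        self.in_proj = nn.Linear(hidden_size, hidden_size)  # only the init changes
```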
IMO the only things left to do:
- refactor the attention to make it as close as possible to Llama/Gemma (see the sketch after this list)
- small renamings -> d_model is hidden_size etc. Have a look at modeling_llama for the correct standards!
- move the losses to loss_utils or loss_dab_detr.py

Awesome refactoring otherwise! 🔥
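For reference, a rough sketch of the Llama/Gemma-style attention layout being asked for: separate q/k/v/o projections around a plain scaled-dot-product core. Names and shapes here are assumptions for illustration, not the PR's final code.

```python
import torch
import torch.nn as nn

class DabDetrAttention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scaling = self.head_dim ** -0.5
        # separate projections, as in modeling_llama
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        bsz, seq_len, _ = hidden_states.shape
        shape = (bsz, seq_len, self.num_heads, self.head_dim)
        q = self.q_proj(hidden_states).view(shape).transpose(1, 2)
        k = self.k_proj(hidden_states).view(shape).transpose(1, 2)
        v = self.v_proj(hidden_states).view(shape).transpose(1, 2)
        attn_weights = (q @ k.transpose(2, 3)) * self.scaling
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask  # additive mask
        attn_weights = attn_weights.softmax(dim=-1)
        out = (attn_weights @ v).transpose(1, 2).reshape(bsz, seq_len, -1)
        return self.o_proj(out)
```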
model_type = "dab-detr"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
I mean can we remove the properties? 🤗
❤️ a lot better thanks!
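For context, this is the `attribute_map` pattern that `PretrainedConfig` supports, which replaces hand-written properties. The entries below are a sketch of the idea, not necessarily the PR's final mapping:

```python
from transformers import PretrainedConfig

class DabDetrConfig(PretrainedConfig):
    model_type = "dab-detr"
    keys_to_ignore_at_inference = ["past_key_values"]
    # expose standard names while storing the legacy DETR-style attributes
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "encoder_attention_heads",
    }

    def __init__(self, d_model=256, encoder_attention_heads=8, **kwargs):
        self.d_model = d_model
        self.encoder_attention_heads = encoder_attention_heads
        super().__init__(**kwargs)

config = DabDetrConfig()
# reads through the map, no @property needed
assert config.hidden_size == config.d_model == 256
```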
r"input_proj.weight": r"input_projection.weight", | ||
r"input_proj.bias": r"input_projection.bias", |
r"input_proj.weight": r"input_projection.weight", | |
r"input_proj.bias": r"input_projection.bias", | |
r"input_proj.(bias|weight)": r"input_projection.\1", |
r"class_embed.weight": r"class_embed.weight", | ||
r"class_embed.bias": r"class_embed.bias", |
r"class_embed.weight": r"class_embed.weight", | |
r"class_embed.bias": r"class_embed.bias", | |
r"class_embed.(bias|weight)": r"class_embed.\1", |
r"transformer.encoder.query_scale.layers.(\d+).weight": r"encoder.query_scale.layers.\1.weight", | ||
r"transformer.encoder.query_scale.layers.(\d+).bias": r"encoder.query_scale.layers.\1.bias", | ||
r"transformer.decoder.bbox_embed.layers.(\d+).weight": r"decoder.bbox_embed.layers.\1.weight", | ||
r"transformer.decoder.bbox_embed.layers.(\d+).bias": r"decoder.bbox_embed.layers.\1.bias", | ||
r"transformer.decoder.norm.weight": r"decoder.layernorm.weight", | ||
r"transformer.decoder.norm.bias": r"decoder.layernorm.bias", | ||
r"transformer.decoder.ref_point_head.layers.(\d+).weight": r"decoder.ref_point_head.layers.\1.weight", | ||
r"transformer.decoder.ref_point_head.layers.(\d+).bias": r"decoder.ref_point_head.layers.\1.bias", | ||
r"transformer.decoder.ref_anchor_head.layers.(\d+).weight": r"decoder.ref_anchor_head.layers.\1.weight", | ||
r"transformer.decoder.ref_anchor_head.layers.(\d+).bias": r"decoder.ref_anchor_head.layers.\1.bias", | ||
r"transformer.decoder.query_scale.layers.(\d+).weight": r"decoder.query_scale.layers.\1.weight", | ||
r"transformer.decoder.query_scale.layers.(\d+).bias": r"decoder.query_scale.layers.\1.bias", | ||
r"transformer.decoder.layers.0.ca_qpos_proj.weight": r"decoder.layers.0.layer.1.cross_attn_query_pos_proj.weight", |
Same comment about weight and bias for the rest!
def convert_old_keys_to_new_keys(state_dict_keys: dict = None): |
missing Copied from!
Updated
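For readers following along, a minimal sketch of how a regex mapping like the one above is typically applied by `convert_old_keys_to_new_keys`. The mapping here is a tiny hypothetical subset of the table in the conversion script:

```python
import re

# Hypothetical subset of the mapping above; the real table in the PR is longer.
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    r"input_proj.(bias|weight)": r"input_projection.\1",
    r"transformer.decoder.norm.(bias|weight)": r"decoder.layernorm.\1",
    r"transformer.decoder.bbox_embed.layers.(\d+).(bias|weight)": r"decoder.bbox_embed.layers.\1.\2",
}

def convert_old_keys_to_new_keys(state_dict_keys):
    """Map each original checkpoint key to its new name via the regex table."""
    new_keys = {}
    for key in state_dict_keys:
        new_key = key
        for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
            new_key = re.sub(pattern, replacement, new_key)
        new_keys[key] = new_key
    return new_keys

print(convert_old_keys_to_new_keys(["transformer.decoder.norm.weight"]))
# {'transformer.decoder.norm.weight': 'decoder.layernorm.weight'}
```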
If everything is copied from (tell me if I am wrong!) then we can just directly use image_processing_detr instead!
You're right! I removed DabDetrImageProcessor and used DetrImageProcessor as you suggested, and it works perfectly.
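A short usage sketch of the reuse being discussed; the checkpoint id is illustrative and may not match the final Hub repo:

```python
from transformers import AutoImageProcessor

# Illustrative checkpoint id -- the actual Hub repo may differ.
processor = AutoImageProcessor.from_pretrained("IDEA-Research/dab-detr-resnet-50")
# DAB-DETR reuses DETR's image processor, so this resolves to a DetrImageProcessor.
```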
pos_x = x_embed[:, :, :, None] / dim_tx
# We use float32 to ensure reproducibility of the original implementation
dim_ty = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
no worries!
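For context, the lines above come from a DETR-style sine position embedding. A self-contained sketch of that computation, with `embedding_dim` and `temperature` as assumed illustrative values rather than the PR's exact configuration:

```python
import torch

def sine_position_embedding(mask, embedding_dim=64, temperature=10000.0):
    # mask: (batch, height, width) bool tensor, True on valid pixels
    y_embed = mask.cumsum(1, dtype=torch.float32)
    x_embed = mask.cumsum(2, dtype=torch.float32)
    # float32 throughout, to stay reproducible w.r.t. the original implementation
    dim_t = torch.arange(embedding_dim, dtype=torch.float32, device=mask.device)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / embedding_dim)
    pos_x = x_embed[:, :, :, None] / dim_t
    pos_y = y_embed[:, :, :, None] / dim_t
    # interleave sin/cos over the channel dimension
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
    return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)  # (batch, 2*dim, H, W)

pos = sine_position_embedding(torch.ones(2, 32, 32, dtype=torch.bool))
print(pos.shape)  # torch.Size([2, 128, 32, 32])
```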
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
We should probably create this in the config; we would only have to pass the list of tuples!
I attempted the change, but the config test failed with the error: "Object of type ModuleList is not JSON serializable." So I left it as it was. If you still want the change, should I go ahead and modify the config tests?
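A hedged sketch of one way to reconcile the two comments above: keep only JSON-serializable data (a list of `(in_features, out_features)` pairs) in the config, and build the actual `nn.ModuleList` inside the model, since `nn.Module` objects cannot be written into `config.json`. The helper name is hypothetical:

```python
import torch.nn as nn

def mlp_layer_dims(input_dim, hidden_dim, output_dim, num_layers):
    # plain ints, so the result can live in config.json (tuples become lists)
    h = [hidden_dim] * (num_layers - 1)
    return list(zip([input_dim] + h, h + [output_dim]))

class MLP(nn.Module):
    def __init__(self, layer_dims):
        super().__init__()
        # the ModuleList itself is built here, not stored in the config
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in layer_dims)

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x) if i == len(self.layers) - 1 else nn.functional.relu(layer(x))
        return x

# e.g. config.bbox_embed_dims = mlp_layer_dims(256, 256, 4, 3)
# -> [(256, 256), (256, 256), (256, 4)]
```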
Feel free to ping me when this is ready for another review.
Hi, thanks! Sorry, I got caught up with other errands and couldn’t work on the code for the past two weeks. I’m back at it now and will ping you once I'm done!
…sutils, tried to move the MLP layer generation to config but it failed
Hi @ArthurZucker, I could use your review and help here. One test is failing, but it’s unrelated to my model. (TFResnet assert error)
The main issue is that the model isn't performing well. It barely learns: validation loss hardly decreases, metrics improve very little, and the final results are poor (current fine-tuned model notebook). Could you help identify where the issue might be?
Loss function seems good too.
Hi @ArthurZucker, I tried to pinpoint why the model isn’t learning but haven't identified the issue yet. I’ve tested various model versions, including:
Any idea? :)
Will let @qubvel take the review! 🤗 In general it's super hard for us to jump in and debug training at this stage; we'll see if it helps!
What does this PR do?
Add DAB-DETR Object detection model. Paper: https://arxiv.org/abs/2201.12329
Original code repo: https://github.com/IDEA-Research/DAB-DETR
[WIP] This model is part of the evolution of DETR models, alongside DN-DETR (not part of this PR), paving the way for newer and better object detection models such as DINO and Stable-DINO.
Who can review?
@amyeroberts