Skip to content

Commit

Permalink
additional changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jadechoghari committed Dec 27, 2024
1 parent d762677 commit 03bb2cc
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 42 deletions.
10 changes: 2 additions & 8 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5624,9 +5624,7 @@
RTDetrConfig,
RTDetrResNetConfig,
)
from .models.rt_detr_v2 import (
RtDetrV2Config
)
from .models.rt_detr_v2 import RtDetrV2Config
from .models.rwkv import RwkvConfig
from .models.sam import (
SamConfig,
Expand Down Expand Up @@ -7809,11 +7807,7 @@
RTDetrResNetBackbone,
RTDetrResNetPreTrainedModel,
)
from .models.rt_detr_v2 import (
RtDetrV2ForObjectDetection,
RtDetrV2Model,
RtDetrV2PreTrainedModel
)
from .models.rt_detr_v2 import RtDetrV2ForObjectDetection, RtDetrV2Model, RtDetrV2PreTrainedModel
from .models.rwkv import (
RwkvForCausalLM,
RwkvModel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from PIL import Image
from torchvision import transforms

from transformers import RTDetrConfig, RTDetrForObjectDetection, RTDetrImageProcessor
from transformers import RTDetrImageProcessor
from transformers.models.rt_detr_v2.modular_rt_detr_v2 import RtDetrV2Config, RtDetrV2ForObjectDetection
from transformers.utils import logging

Expand Down Expand Up @@ -578,7 +578,6 @@ def convert_rt_detr_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub

# finally, create HuggingFace model and load state dict
model = RtDetrV2ForObjectDetection(config)
breakpoint()
model.load_state_dict(state_dict)
model.eval()

Expand Down Expand Up @@ -610,7 +609,6 @@ def convert_rt_detr_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub
# Pass image by the model
outputs = model(pixel_values)

breakpoint()
if model_name == "rtdetr_v2_r18vd":
expected_slice_logits = torch.tensor(
[[-3.7045, -5.1913, -6.1787], [-4.0106, -9.3450, -5.2043], [-4.1287, -4.7463, -5.8634]]
Expand Down
52 changes: 22 additions & 30 deletions src/transformers/models/rt_detr_v2/modular_rt_detr_v2.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,32 @@
from ..rt_detr.modeling_rt_detr import (
RTDetrDecoderLayer, RTDetrModelOutput, RTDetrObjectDetectionOutput, RTDetrHybridEncoder,
RTDetrEncoderLayer, RTDetrConvEncoder, RTDetrMLPPredictionHead, RTDetrDecoderOutput,
RTDetrMultiscaleDeformableAttention, MultiScaleDeformableAttentionFunction, get_contrastive_denoising_training_group,
inverse_sigmoid, RTDetrDecoder, RTDetrModel, RTDetrForObjectDetection
)
from ...modeling_utils import PreTrainedModel
from ..rt_detr.configuration_rt_detr import RTDetrConfig
from ...activations import ACT2FN
import math
import os
import warnings
from dataclasses import dataclass
from functools import lru_cache, partial, wraps
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from ...utils import is_torch_cuda_available
import torch.nn as nn
from ...modeling_utils import PreTrainedModel
from ..rt_detr.configuration_rt_detr import RTDetrConfig
from ..rt_detr.modeling_rt_detr import (
MultiScaleDeformableAttentionFunction,
RTDetrDecoder,
RTDetrDecoderLayer,
RTDetrForObjectDetection,
RTDetrModel,
RTDetrMultiscaleDeformableAttention,
)


class RtDetrV2Config(RTDetrConfig):
model_type = "rt_detr_v2"

def __init__(self,
def __init__(self,
decoder_n_levels=3, # default value
decoder_offset_scale=0.5, # default value
**super_kwargs):
# init the base RTDetrConfig class with additional parameters
super().__init__(**super_kwargs)

# add the new attributes with the given values or defaults
self.decoder_n_levels = decoder_n_levels
self.decoder_offset_scale = decoder_offset_scale
Expand Down Expand Up @@ -129,7 +122,7 @@ def __init__(self, config: RtDetrV2Config):
num_heads=config.decoder_attention_heads,
n_points=config.decoder_n_points
)

# V2-specific attributes
self.offset_scale = config.decoder_offset_scale
# Initialize n_points list and scale
Expand All @@ -152,10 +145,10 @@ def _reset_parameters(self):
)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1

with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))

nn.init.constant_(self.attention_weights.weight.data, 0.0)
nn.init.constant_(self.attention_weights.bias.data, 0.0)
nn.init.xavier_uniform_(self.value_proj.weight.data)
Expand Down Expand Up @@ -191,12 +184,12 @@ def forward(
if attention_mask is not None:
value = value.masked_fill(~attention_mask[..., None], float(0))
value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)

# V2-specific sampling offsets shape
sampling_offsets = self.sampling_offsets(hidden_states).view(
batch_size, num_queries, self.n_heads, self.n_levels * self.n_points, 2
)

attention_weights = self.attention_weights(hidden_states).view(
batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
)
Expand Down Expand Up @@ -253,7 +246,6 @@ def __init__(self, config: RtDetrV2Config):
self.layers = nn.ModuleList([RtDetrV2DecoderLayer(config) for _ in range(config.decoder_layers)])


# could make better use of inheritence
class RtDetrV2Model(RTDetrModel):
def __init__(self, config: RtDetrV2Config):
super().__init__(config)
Expand All @@ -277,7 +269,7 @@ class RtDetrV2PreTrainedModel(PreTrainedModel):
def __init__(self, RTDetrV2Config):
super().__init__(config)
self.model = RTDetrV2Model(config)

def _init_weights(self, module):
# this could be simplified like calling the base class function first then adding new stuff
"""Initalize the weights"""
Expand All @@ -292,10 +284,10 @@ def _init_weights(self, module):
)
for i in range(self.n_points):
grid_init[:, :, i, :] *= i + 1

with torch.no_grad():
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))

nn.init.constant_(self.attention_weights.weight.data, 0.0)
nn.init.constant_(self.attention_weights.bias.data, 0.0)
nn.init.xavier_uniform_(self.value_proj.weight.data)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/rtdetrv2/test_modeling_rt_detr_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from parameterized import parameterized

from transformers import (
RtDetrV2Config,
RTDetrImageProcessor,
RtDetrV2Config,
is_torch_available,
is_vision_available,
)
Expand Down

0 comments on commit 03bb2cc

Please sign in to comment.