Commit

apply modular
yonigozlan committed Dec 5, 2024
1 parent 0d22212 commit df94db3
Showing 2 changed files with 15 additions and 15 deletions.
18 changes: 9 additions & 9 deletions src/transformers/models/got_ocr2/modeling_got_ocr2.py
@@ -805,9 +805,9 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)

-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

         if position_embeddings is None:
             logger.warning_once(
@@ -890,9 +890,9 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)

-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

         if position_embeddings is None:
             logger.warning_once(
@@ -1014,9 +1014,9 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)

-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

         if position_embeddings is None:
             logger.warning_once(
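Each of the three touched forward methods replaces the explicit self.num_heads / self.num_key_value_heads argument with -1, so torch.Tensor.view infers the head count from the width of the projection output. A minimal standalone sketch (values are illustrative, not taken from the model) showing the two forms produce the same tensor:

import torch

bsz, q_len, num_heads, head_dim = 2, 5, 8, 64
hidden_states = torch.randn(bsz, q_len, num_heads * head_dim)  # stand-in for a q_proj output

# Explicit head count (old form) vs. inferred with -1 (new form).
explicit = hidden_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
inferred = hidden_states.view(bsz, q_len, -1, head_dim).transpose(1, 2)

assert torch.equal(explicit, inferred)  # both are (bsz, num_heads, q_len, head_dim)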
12 changes: 6 additions & 6 deletions src/transformers/models/got_ocr2/processing_got_ocr2.py
@@ -219,14 +219,14 @@ def __call__(
             tokenizer_init_kwargs=self.tokenizer.init_kwargs,
             **kwargs,
         )
-        format = output_kwargs["text_kwargs"].pop("format", False)
-        num_image_tokens = output_kwargs["images_kwargs"].pop("num_image_tokens", 256)
+        format = output_kwargs["text_kwargs"].pop("format")
+        num_image_tokens = output_kwargs["images_kwargs"].pop("num_image_tokens")
         box = output_kwargs["images_kwargs"].pop("box", [None])
         color = output_kwargs["images_kwargs"].pop("color", None)
-        multi_page = output_kwargs["images_kwargs"].pop("multi_page", False)
-        crop_to_patches = output_kwargs["images_kwargs"].pop("crop_to_patches", False)
-        min_patches = output_kwargs["images_kwargs"].pop("min_patches", 1)
-        max_patches = output_kwargs["images_kwargs"].pop("max_patches", 6)
+        multi_page = output_kwargs["images_kwargs"].pop("multi_page")
+        crop_to_patches = output_kwargs["images_kwargs"].pop("crop_to_patches")
+        min_patches = output_kwargs["images_kwargs"].pop("min_patches")
+        max_patches = output_kwargs["images_kwargs"].pop("max_patches")

         self._check_call_arguments(images, box, color, multi_page, crop_to_patches)
         images, text, box, color = self._make_list_of_inputs(images, text, box, color, multi_page)
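The processor change drops the inline fallbacks from these pop calls, presumably because every one of these keys is already populated when the kwargs are merged with the processor's defaults earlier in __call__. A hedged sketch of that pattern with hypothetical names (not the transformers API), using the default values visible in the removed lines:

# Hypothetical stand-ins; the real processor merges defaults via its kwargs class.
DEFAULTS = {
    "text_kwargs": {"format": False},
    "images_kwargs": {
        "num_image_tokens": 256,
        "multi_page": False,
        "crop_to_patches": False,
        "min_patches": 1,
        "max_patches": 6,
    },
}

def merge_kwargs(defaults, **overrides):
    # Start from the defaults, then overlay any caller-supplied values.
    merged = {group: dict(values) for group, values in defaults.items()}
    for values in merged.values():
        for key in list(values):
            if key in overrides:
                values[key] = overrides[key]
    return merged

output_kwargs = merge_kwargs(DEFAULTS, max_patches=12)
# Because the merge guarantees the key exists, pop no longer needs a fallback.
num_image_tokens = output_kwargs["images_kwargs"].pop("num_image_tokens")  # 256
max_patches = output_kwargs["images_kwargs"].pop("max_patches")            # 12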
