Skip to content

Commit

Permalink
update to new init
Browse files Browse the repository at this point in the history
  • Loading branch information
yonigozlan committed Nov 25, 2024
1 parent b36ccb3 commit 4007fb2
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 71 deletions.
63 changes: 8 additions & 55 deletions src/transformers/models/got_ocr2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,66 +13,19 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tokenizers_available,
is_torch_available,
is_vision_available,
)


_import_structure = {
"configuration_got_ocr2": ["GotOcr2Config", "GotOcr2VisionConfig"],
"processing_got_ocr2": ["GotOcr2Processor"],
}
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_got_ocr2"] = ["GotOcr2ImageProcessor"]


try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_got_ocr2"] = [
"GotOcr2ForConditionalGeneration",
"GotOcr2Model",
"GotOcr2PreTrainedModel",
]
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_got_ocr2 import GotOcr2Config, GotOcr2VisionConfig
from .processing_got_ocr2 import GotOcr2Processor

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_got_ocr2 import GotOcr2ImageProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_got_ocr2 import (
GotOcr2ForConditionalGeneration,
GotOcr2Model,
GotOcr2PreTrainedModel,
)
from .configuration_got_ocr2 import *
from .image_processing_got_ocr2 import *
from .modeling_got_ocr2 import *
from .processing_got_ocr2 import *


else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
3 changes: 3 additions & 0 deletions src/transformers/models/got_ocr2/configuration_got_ocr2.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,6 @@ def __init__(
rope_config_validation(self, ignore_keys={"mrope_section"})

super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)


__all__ = ["GotOcr2VisionConfig", "GotOcr2Config"]
11 changes: 3 additions & 8 deletions src/transformers/models/got_ocr2/image_processing_got_ocr2.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,6 @@
)


if is_vision_available():
import PIL


if is_vision_available():
import PIL


if is_vision_available():
import PIL

Expand Down Expand Up @@ -413,3 +405,6 @@ def crop_image_to_patches(
processed_images = processed_images_numpy

return processed_images


__all__ = ["GotOcr2ImageProcessor"]
8 changes: 8 additions & 0 deletions src/transformers/models/got_ocr2/modeling_got_ocr2.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_got_ocr2 import GotOcr2Config, GotOcr2VisionConfig

Expand All @@ -54,6 +55,8 @@

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "GotOcr2Config"


class GotOcr2MLPBlock(nn.Module):
def __init__(self, config):
Expand Down Expand Up @@ -1585,6 +1588,8 @@ def _update_model_kwargs_for_generation(

return model_kwargs

@add_start_docstrings_to_model_forward(GOT_OCR2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
Expand Down Expand Up @@ -1772,3 +1777,6 @@ def prepare_inputs_for_generation(
}
)
return model_inputs


__all__ = ["GotOcr2PreTrainedModel", "GotOcr2Model", "GotOcr2ForConditionalGeneration"]
11 changes: 11 additions & 0 deletions src/transformers/models/got_ocr2/modular_got_ocr2.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,3 +898,14 @@ def prepare_inputs_for_generation(
}
)
return model_inputs


__all__ = [
"GotOcr2VisionConfig",
"GotOcr2Config",
"GotOcr2Processor",
"GotOcr2PreTrainedModel",
"GotOcr2Model",
"GotOcr2ForConditionalGeneration",
"GotOcr2ImageProcessor",
]
23 changes: 15 additions & 8 deletions src/transformers/models/got_ocr2/processing_got_ocr2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@
from ...utils import is_vision_available, logging


if is_vision_available():
from ...image_utils import load_images


if is_vision_available():
from ...image_utils import load_images

Expand Down Expand Up @@ -147,6 +143,12 @@ def __call__(
If set, will enable multi-page inference. The model will return the OCR result across multiple pages.
crop_to_patches (`bool`, *optional*):
            If set, will crop the image into patches. The model will return the OCR result based on the cropped patches.
min_patches (`int`, *optional*):
The minimum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to
`True`.
max_patches (`int`, *optional*):
The maximum number of patches to be cropped from the image. Only used when `crop_to_patches` is set to
`True`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
Expand Down Expand Up @@ -178,6 +180,8 @@ def __call__(
color = output_kwargs["images_kwargs"].pop("color", None)
multi_page = output_kwargs["images_kwargs"].pop("multi_page", False)
crop_to_patches = output_kwargs["images_kwargs"].pop("crop_to_patches", False)
min_patches = output_kwargs["images_kwargs"].pop("min_patches")
max_patches = output_kwargs["images_kwargs"].pop("max_patches")

if not isinstance(box, (list, tuple)):
raise ValueError("`box` must be a list or tuple in the form [x1, y1, x2, y2].")
Expand Down Expand Up @@ -219,8 +223,8 @@ def __call__(
image_group = self.image_processor.crop_image_to_patches(
image_group,
size=output_kwargs["images_kwargs"].get("size"),
min_num=output_kwargs["images_kwargs"].get("min_patches"),
max_num=output_kwargs["images_kwargs"].get("max_patches"),
min_num=min_patches,
max_num=max_patches,
)
images[index] = image_group
num_images = len(image_group) if (multi_page or crop_to_patches) else 1
Expand Down Expand Up @@ -259,14 +263,14 @@ def __call__(

def batch_decode(self, *args, **kwargs):
"""
        This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)

def decode(self, *args, **kwargs):
"""
        This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
Expand All @@ -276,3 +280,6 @@ def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))


__all__ = ["GotOcr2Processor"]

0 comments on commit 4007fb2

Please sign in to comment.