[Init refactor] Modular changes (#35240)
* Modular changes

* Gemma

* Gemma
LysandreJik authored Dec 12, 2024
1 parent a691ccb commit 11ba1d4
Showing 19 changed files with 129 additions and 253 deletions.
113 changes: 10 additions & 103 deletions src/transformers/models/gemma/__init__.py
@@ -1,4 +1,4 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,111 +13,18 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_sentencepiece_available,
is_tokenizers_available,
is_torch_available,
)


_import_structure = {
"configuration_gemma": ["GemmaConfig"],
}

try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_gemma"] = ["GemmaTokenizer"]

try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_gemma_fast"] = ["GemmaTokenizerFast"]


try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_gemma"] = [
"GemmaForCausalLM",
"GemmaModel",
"GemmaPreTrainedModel",
"GemmaForSequenceClassification",
"GemmaForTokenClassification",
]

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_gemma"] = [
"FlaxGemmaForCausalLM",
"FlaxGemmaModel",
"FlaxGemmaPreTrainedModel",
]
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_gemma import GemmaConfig

try:
if not is_sentencepiece_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_gemma import GemmaTokenizer

try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_gemma_fast import GemmaTokenizerFast

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_gemma import (
GemmaForCausalLM,
GemmaForSequenceClassification,
GemmaForTokenClassification,
GemmaModel,
GemmaPreTrainedModel,
)

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_gemma import (
FlaxGemmaForCausalLM,
FlaxGemmaModel,
FlaxGemmaPreTrainedModel,
)


from .configuration_gemma import *
from .modeling_flax_gemma import *
from .modeling_gemma import *
from .tokenization_gemma import *
from .tokenization_gemma_fast import *
else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
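
The pattern above is what every refactored `__init__.py` in this commit converges on: the hand-maintained `_import_structure` dict and the per-backend `try`/`except OptionalDependencyNotAvailable` blocks are dropped, and `_LazyModule` is instead fed by `define_import_structure(_file)`, which derives the import map from the submodules' `__all__` lists. The sketch below is only an illustration of that idea, assuming a plain "read each module's `__all__`" strategy; the real helper in `transformers/utils/import_utils.py` is more involved (backend gating, caching, nested packages), so this is not its actual API.

```python
# Illustrative sketch only: build a {module: [public names]} map for a package
# by reading each submodule's top-level `__all__`. Not the real
# `define_import_structure` from transformers/utils/import_utils.py.
import ast
import os


def sketch_import_structure(init_file: str) -> dict[str, list[str]]:
    package_dir = os.path.dirname(os.path.abspath(init_file))
    structure: dict[str, list[str]] = {}
    for filename in sorted(os.listdir(package_dir)):
        if not filename.endswith(".py") or filename == "__init__.py":
            continue
        with open(os.path.join(package_dir, filename), encoding="utf-8") as handle:
            tree = ast.parse(handle.read())
        for node in tree.body:
            # Only consider top-level `__all__ = [...]` assignments.
            if not isinstance(node, ast.Assign):
                continue
            if not any(isinstance(t, ast.Name) and t.id == "__all__" for t in node.targets):
                continue
            if isinstance(node.value, (ast.List, ast.Tuple)):
                structure[filename[:-3]] = [
                    element.value
                    for element in node.value.elts
                    if isinstance(element, ast.Constant)
                ]
    return structure
```

The payoff is that exposing a new class now only means appending it to that module's `__all__`; the package `__init__.py` no longer needs to be edited by hand.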
3 changes: 3 additions & 0 deletions src/transformers/models/gemma/modeling_flax_gemma.py
@@ -772,3 +772,6 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs):
_CONFIG_FOR_DOC,
real_checkpoint=_REAL_CHECKPOINT_FOR_DOC,
)


__all__ = ["FlaxGemmaForCausalLM", "FlaxGemmaModel", "FlaxGemmaPreTrainedModel"]
8 changes: 7 additions & 1 deletion src/transformers/models/gemma/modeling_gemma.py
@@ -1295,4 +1295,10 @@ def forward(
)


__all__ = ["GemmaModel", "GemmaForCausalLM", "GemmaForSequenceClassification", "GemmaForTokenClassification"]
__all__ = [
"GemmaModel",
"GemmaForCausalLM",
"GemmaForSequenceClassification",
"GemmaForTokenClassification",
"GemmaPreTrainedModel",
]
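
Adding `GemmaPreTrainedModel` to `__all__` here is what lets the new `__init__.py` expose it without any explicit bookkeeping. Assuming a `transformers` build that includes this commit, the class resolves through the lazy module like the other Gemma classes:

```python
# Assumes a transformers version that includes this refactor; the submodule is
# only imported when the attribute is first accessed, courtesy of _LazyModule.
from transformers.models.gemma import GemmaModel, GemmaPreTrainedModel

print(issubclass(GemmaModel, GemmaPreTrainedModel))  # expected: True
```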
6 changes: 6 additions & 0 deletions src/transformers/models/gemma/modular_gemma.py
@@ -36,6 +36,7 @@
LlamaForSequenceClassification,
LlamaForTokenClassification,
LlamaModel,
LlamaPreTrainedModel,
apply_rotary_pos_emb,
repeat_kv,
)
@@ -803,6 +804,10 @@ def forward(
return outputs


class GemmaPreTrainedModel(LlamaPreTrainedModel):
pass


class GemmaModel(LlamaModel):
def __init__(self, config: GemmaConfig):
super().__init__(config)
@@ -1040,4 +1045,5 @@ def __init__(self, config):
"GemmaForCausalLM",
"GemmaForSequenceClassification",
"GemmaForTokenClassification",
"GemmaPreTrainedModel",
]
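
The `GemmaPreTrainedModel(LlamaPreTrainedModel)` pass-through in the modular file is the modular-transformers way of saying "same behaviour, new name": the converter script regenerates `modeling_gemma.py` with a standalone, renamed copy, so the generated file keeps no runtime dependency on the Llama module. A toy, self-contained illustration of the pass-through idea, with invented class names:

```python
# Toy illustration of the pass-through pattern used in modular_gemma.py above.
# Class names are invented for the example; they are not transformers APIs.
class ParentPreTrainedBase:
    """Stands in for LlamaPreTrainedModel: holds the shared base behaviour."""

    base_model_prefix = "model"
    supports_gradient_checkpointing = True


class ChildPreTrainedBase(ParentPreTrainedBase):
    # An empty subclass is enough to re-export the parent's behaviour under a
    # new name; the modular converter can later inline the parent's body so
    # the generated modeling file stands alone.
    pass


assert ChildPreTrainedBase.base_model_prefix == "model"
assert ChildPreTrainedBase.supports_gradient_checkpointing is True
```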
3 changes: 3 additions & 0 deletions src/transformers/models/gemma/tokenization_gemma_fast.py
@@ -197,3 +197,6 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
output = output + bos_token_id + token_ids_1 + eos_token_id

return output


__all__ = ["GemmaTokenizerFast"]
48 changes: 7 additions & 41 deletions src/transformers/models/gemma2/__init__.py
@@ -1,4 +1,4 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,49 +13,15 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


_import_structure = {
"configuration_gemma2": ["Gemma2Config"],
}

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_gemma2"] = [
"Gemma2ForCausalLM",
"Gemma2Model",
"Gemma2PreTrainedModel",
"Gemma2ForSequenceClassification",
"Gemma2ForTokenClassification",
]

if TYPE_CHECKING:
from .configuration_gemma2 import Gemma2Config

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_gemma2 import (
Gemma2ForCausalLM,
Gemma2ForSequenceClassification,
Gemma2ForTokenClassification,
Gemma2Model,
Gemma2PreTrainedModel,
)

from .configuration_gemma2 import *
from .modeling_gemma2 import *
else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
3 changes: 3 additions & 0 deletions src/transformers/models/gemma2/configuration_gemma2.py
@@ -153,3 +153,6 @@ def __init__(
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.cache_implementation = cache_implementation


__all__ = ["Gemma2Config"]
9 changes: 9 additions & 0 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -1280,3 +1280,12 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


__all__ = [
"Gemma2ForCausalLM",
"Gemma2Model",
"Gemma2PreTrainedModel",
"Gemma2ForSequenceClassification",
"Gemma2ForTokenClassification",
]
10 changes: 10 additions & 0 deletions src/transformers/models/gemma2/modular_gemma2.py
@@ -903,3 +903,13 @@ def __init__(self, config):
super().__init__(config)
self.model = Gemma2Model(config)
self.post_init()


__all__ = [
"Gemma2Config",
"Gemma2ForCausalLM",
"Gemma2Model",
"Gemma2PreTrainedModel",
"Gemma2ForSequenceClassification",
"Gemma2ForTokenClassification",
]
57 changes: 8 additions & 49 deletions src/transformers/models/llava_next_video/__init__.py
@@ -13,58 +13,17 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


_import_structure = {
"configuration_llava_next_video": ["LlavaNextVideoConfig"],
"processing_llava_next_video": ["LlavaNextVideoProcessor"],
}


try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_llava_next_video"] = ["LlavaNextVideoImageProcessor"]

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_llava_next_video"] = [
"LlavaNextVideoForConditionalGeneration",
"LlavaNextVideoPreTrainedModel",
]

if TYPE_CHECKING:
from .configuration_llava_next_video import LlavaNextVideoConfig
from .processing_llava_next_video import LlavaNextVideoProcessor

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_llava_next_video import LlavaNextVideoImageProcessor

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_llava_next_video import (
LlavaNextVideoForConditionalGeneration,
LlavaNextVideoPreTrainedModel,
)

from .configuration_llava_next_video import *
from .image_processing_llava_next_video import *
from .modeling_llava_next_video import *
from .processing_llava_next_video import *
else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
src/transformers/models/llava_next_video/configuration_llava_next_video.py
@@ -158,3 +158,6 @@ def __init__(
self.text_config = text_config

super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)


__all__ = ["LlavaNextVideoConfig"]
src/transformers/models/llava_next_video/image_processing_llava_next_video.py
@@ -414,3 +414,6 @@ def preprocess(

data = {"pixel_values_videos": pixel_values}
return BatchFeature(data=data, tensor_type=return_tensors)


__all__ = ["LlavaNextVideoImageProcessor"]