diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 49d086c76e8683..50984736b45a19 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -43,7 +43,8 @@
 from .activations import get_activation
 from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
-from .generation import CompileConfig, GenerationConfig, GenerationMixin
+from .generation.configuration_utils import CompileConfig, GenerationConfig
+from .generation import GenerationMixin
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flex_attention import flex_attention_forward
@@ -54,7 +55,6 @@
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
     id_tensor_storage,
-    is_torch_greater_or_equal_than_1_13,
     prune_conv1d_layer,
     prune_layer,
     prune_linear_layer,
@@ -476,7 +476,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
         error_message += f"\nMissing key(s): {str_unexpected_keys}."
         raise RuntimeError(error_message)

-    weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
+    weights_only_kwarg = {"weights_only": True}
     loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)

     for shard_file in shard_files:
@@ -532,7 +532,7 @@ def load_state_dict(
             and is_zipfile(checkpoint_file)
         ):
             extra_args = {"mmap": True}
-        weights_only_kwarg = {"weights_only": weights_only} if is_torch_greater_or_equal_than_1_13 else {}
+        weights_only_kwarg = {"weights_only": weights_only}
         return torch.load(
             checkpoint_file,
             map_location=map_location,
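
For context on the two `weights_only_kwarg` hunks: the removed `is_torch_greater_or_equal_than_1_13` guard only decided whether `weights_only` was forwarded to `torch.load`, and since `torch.load` has accepted `weights_only` from PyTorch 1.13 onward, the kwarg can now be passed unconditionally. Below is a minimal sketch of the resulting loader construction, assuming torch >= 1.13; `make_loader` is a hypothetical helper for illustration, not part of the patch.

from functools import partial

import torch
from safetensors.torch import load_file as safe_load_file


def make_loader(load_safe: bool):
    # With torch >= 1.13 assumed, ``weights_only=True`` can always be passed;
    # it restricts unpickling to tensors and plain containers, so the version
    # check removed in this diff is no longer needed.
    weights_only_kwarg = {"weights_only": True}
    return safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)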