Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
fix import CompileConfig
  • Loading branch information
alestrami authored Dec 20, 2024
1 parent 893d6a4 commit b71093c
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
from .activations import get_activation
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig, GenerationMixin
from .generation.configuration_utils import CompileConfig, GenerationConfig
from .generation import GenerationMixin
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .integrations.flash_attention import flash_attention_forward
from .integrations.flex_attention import flex_attention_forward
Expand All @@ -54,7 +55,6 @@
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
id_tensor_storage,
is_torch_greater_or_equal_than_1_13,
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
Expand Down Expand Up @@ -476,7 +476,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
error_message += f"\nMissing key(s): {str_unexpected_keys}."
raise RuntimeError(error_message)

weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
weights_only_kwarg = {"weights_only": True}
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)

for shard_file in shard_files:
Expand Down Expand Up @@ -532,7 +532,7 @@ def load_state_dict(
and is_zipfile(checkpoint_file)
):
extra_args = {"mmap": True}
weights_only_kwarg = {"weights_only": weights_only} if is_torch_greater_or_equal_than_1_13 else {}
weights_only_kwarg = {"weights_only": weights_only}
return torch.load(
checkpoint_file,
map_location=map_location,
Expand Down

0 comments on commit b71093c

Please sign in to comment.