Skip to content

Commit

Permalink
Fix loading with only state dict and low_cpu_mem_usage = True (#35217)
Browse files Browse the repository at this point in the history
* fix loading with only state dict and config

* style

* add tests

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
SunMarc and sayakpaul authored Dec 18, 2024
1 parent 0531d75 commit 1eee1ce
Showing 2 changed files with 26 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
@@ -4022,8 +4022,11 @@ def from_pretrained(
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
loaded_state_dict_keys = list(state_dict.keys())

if gguf_path is None and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available())):
if (
gguf_path is None
and (low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available()))
and pretrained_model_name_or_path is not None
):
# In case some weights need to be kept in float32 and accelerate is not installed,
# we later on want to take the path where state_dict is not None, that is the one
# that does not require accelerate.
@@ -4679,7 +4682,7 @@ def _find_mismatched_keys(
)

# For GGUF models `state_dict` is never set to None as the state dict is always small
if gguf_path:
if gguf_path or low_cpu_mem_usage:
fixed_state_dict = cls._fix_state_dict_keys_on_load(state_dict)
error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
20 changes: 20 additions & 0 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
@@ -1750,6 +1750,26 @@ def test_save_and_load_config_with_custom_generation(self):
new_model.generate(random_ids, max_new_tokens=3)
self.assertTrue(len(w) == 0)

def test_load_model_with_state_dict_only(self):
    """Load a model from an in-memory state dict and config, without a checkpoint name or path.

    A reference model is fetched first; its weights and config are then passed
    back to `from_pretrained` with `pretrained_model_name_or_path=None`, and the
    test asserts that `check_models_equal` reports the two models as equal.
    """
    reference = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
    weights = reference.state_dict()

    # No name/path is supplied: only `config` and `state_dict` describe the model.
    reloaded = BertModel.from_pretrained(
        pretrained_model_name_or_path=None, config=reference.config, state_dict=weights
    )

    self.assertTrue(check_models_equal(reference, reloaded))

def test_load_model_with_state_dict_only_low_cpu_mem_usage(self):
    """Same as the state-dict-only load, but with `low_cpu_mem_usage=True`.

    A reference model is fetched first; its weights and config are then passed
    back to `from_pretrained` with `pretrained_model_name_or_path=None` and
    `low_cpu_mem_usage=True`, and the test asserts that `check_models_equal`
    reports the two models as equal.
    """
    reference = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
    weights = reference.state_dict()

    # No name/path is supplied: only `config` and `state_dict` describe the model.
    reloaded = BertModel.from_pretrained(
        pretrained_model_name_or_path=None,
        config=reference.config,
        state_dict=weights,
        low_cpu_mem_usage=True,
    )

    self.assertTrue(check_models_equal(reference, reloaded))


@slow
@require_torch

0 comments on commit 1eee1ce

Please sign in to comment.