Skip to content

Commit

Permalink
Move transform_model_pre_registration in hf_checkpointer (#1664)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
irenedea and dakinggg authored Nov 18, 2024
1 parent e2cc41b commit 8a1e55e
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 8 deletions.
11 changes: 4 additions & 7 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,10 @@ def tensor_hook(

if dist.get_global_rank() == 0:
if register_to_mlflow:
assert new_model_instance is not None
new_model_instance = self.transform_model_pre_registration(
new_model_instance,
)
if self.using_peft:

# Save and register peft model to mlflow, this code path uses our older two step logic
Expand All @@ -798,10 +802,6 @@ def tensor_hook(
temp_save_dir,
'register_save',
)
assert new_model_instance is not None
new_model_instance = self.transform_model_pre_registration(
new_model_instance,
)
new_model_instance.save_pretrained(
register_save_dir,
max_shard_size='1GB',
Expand Down Expand Up @@ -860,9 +860,6 @@ def _save_and_register_peft_model(
original_tokenizer: Optional[Any],
save_dir: str,
):
new_model_instance = self.transform_model_pre_registration(
new_model_instance,
)
components = {'model': new_model_instance}
if original_tokenizer is not None:
components['tokenizer'] = original_tokenizer
Expand Down
72 changes: 71 additions & 1 deletion tests/a_scripts/inference/test_convert_composer_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def test_huggingface_conversion_callback_interval(
def _get_model_and_tokenizer(
model: str,
max_seq_len: int,
tie_word_embeddings: bool,
tie_word_embeddings: Optional[bool],
precision: str,
):
if model == 'mpt':
Expand Down Expand Up @@ -1110,6 +1110,76 @@ def test_huggingface_conversion_callback(
delete_transformers_cache()


@patch('os.cpu_count', MagicMock(return_value=1))
@patch(
'llmfoundry.callbacks.hf_checkpointer.SpawnProcess',
new=MockSpawnProcess,
)
def test_transform_model_pre_registration():
"""Test `transform_model_pre_registration` method is called."""

class ExtendedHuggingFaceCheckpointer(HuggingFaceCheckpointer):
"""Set PEFT to false before registering for testing."""

def transform_model_pre_registration(self, model: PreTrainedModel):
self.using_peft = False
return super().transform_model_pre_registration(model)

model_cfg, tokenizer_name = _get_model_and_tokenizer(
model='neo',
max_seq_len=10,
tie_word_embeddings=None,
precision='bfloat16',
)
model_cfg['peft_config'] = {
'peft_type': 'LORA',
'task_type': 'CAUSAL_LM',
'lora_alpha': 32,
'lora_dropout': 0.05,
'r': 16,
'target_modules': 'all-linear',
}
tokenizer = build_tokenizer(
tokenizer_name=tokenizer_name,
tokenizer_kwargs={},
)

original_model = build_composer_model(
model_cfg.pop('name'),
tokenizer=tokenizer,
cfg=model_cfg,
)

logger = MagicMock()
state = MagicMock()
state.timestamp.batch = 1
state.is_model_ddp = False
state.model = original_model
state.model.tokenizer = tokenizer

checkpointer = ExtendedHuggingFaceCheckpointer(
save_folder='test',
save_interval='1ba',
)
mlflow_logger_mock = _create_mlflow_logger_mock()
checkpointer.mlflow_loggers = [mlflow_logger_mock] # type: ignore

assert model_cfg is not None
assert tokenizer_name is not None

checkpointer._save_and_register_peft_model = MagicMock()
checkpointer.using_peft = True
checkpointer._save_checkpoint(
state=state,
logger=logger,
upload_to_save_folder=True,
register_to_mlflow=True,
)

checkpointer._save_and_register_peft_model.assert_not_called()
assert mlflow_logger_mock.log_model.call_count == 1


# TODO(GRT-2431): Refactor as enums
@pytest.mark.parametrize(
'model,tie_word_embeddings',
Expand Down

0 comments on commit 8a1e55e

Please sign in to comment.