Allow for transforms on the model before MLFlow registration #1372

Merged
merged 17 commits on Jul 20, 2024
llmfoundry/callbacks/hf_checkpointer.py (59 changes: 41 additions & 18 deletions)
@@ -179,6 +179,7 @@ def __init__(
'bfloat16': torch.bfloat16,
}[precision]
self.flatten_imports = flatten_imports
self.using_peft = False

# mlflow config setup
self.mlflow_registered_model_name = mlflow_registered_model_name
@@ -362,6 +363,23 @@ def transform_config(
copied_config.ffn_config['moe_world_size'] = 1
return copied_config

def transform_model_pre_registration(
self,
model: PreTrainedModel,
) -> PreTrainedModel:
"""Transform the model before registering with MLFlow.

This allows a subclass to modify the model before registering with MLFlow. The base class implementation will
make no modifications.

Args:
model (PreTrainedModel): The model to be transformed.

Returns:
PreTrainedModel: The transformed model.
"""
return model
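
Aside (example, not part of the diff): the new hook is deliberately a no-op in the base class, and, as a later hunk shows, it is only invoked on the MLFlow registration path, so the Hugging Face checkpoint written to disk is unaffected. A minimal sketch of an override follows; the subclass name and the adapter-merging behavior are illustrative assumptions, not something this PR ships:

from transformers import PreTrainedModel

from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer


class MergeAdapterCheckpointer(HuggingFaceCheckpointer):
    """Hypothetical checkpointer that edits the model only for MLFlow registration."""

    def transform_model_pre_registration(
        self,
        model: PreTrainedModel,
    ) -> PreTrainedModel:
        # If the model is a PeftModel, fold the LoRA adapters into the base
        # weights so MLFlow registers a plain transformers model.
        if hasattr(model, 'merge_and_unload'):
            model = model.merge_and_unload()
        return model
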

def _save_checkpoint(self, state: State, logger: Logger):
del logger # unused

Expand Down Expand Up @@ -406,6 +424,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
state_dict_model = state.model.model
original_tokenizer = state.model.tokenizer

self.using_peft = composer_model.using_peft

if version.parse(torch.__version__) > version.parse('2.2.9'):
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
Expand Down Expand Up @@ -463,7 +483,7 @@ def dtensor_to_tensor_hook(
with state_dict_context:
state_dict = state_dict_model.state_dict()

# Convert the state dict to the requested precis
# Convert the state dict to the requested precision
for k, v in state_dict.items():
if isinstance(v, torch.Tensor):
state_dict[k] = v.to(dtype=self.dtype)
@@ -480,22 +500,20 @@ def dtensor_to_tensor_hook(

log.debug(f'Creating new model instance')

if composer_model.using_peft:
# We don't use meta here because the state dict does not contain the full
# model, only the adapter weights.
active_adapter = original_model.active_adapter
base_model = original_model.get_base_model()
new_base_model_instance = type(base_model)(new_config)

new_model_instance = type(original_model)(
new_base_model_instance,
original_model.peft_config[active_adapter],
)
new_model_instance.to(dtype=self.dtype)
else:
# First create the model instance on meta device to avoid the
# initialization cost.
with init_empty_weights():
# First create the model instance on meta device to avoid the
# initialization cost.
with init_empty_weights():
if self.using_peft:
active_adapter = original_model.active_adapter
base_model = original_model.get_base_model()
new_base_model_instance = type(base_model)(new_config)

new_model_instance = type(original_model)(
new_base_model_instance,
original_model.peft_config[active_adapter],
)
new_model_instance.to(dtype=self.dtype)
else:
new_model_instance = type(original_model)(new_config)
new_model_instance.generation_config.update(
**original_model.generation_config.to_dict(),
Expand Down Expand Up @@ -556,6 +574,11 @@ def dtensor_to_tensor_hook(

if dist.get_global_rank() == 0:
if self.mlflow_registered_model_name and self._is_last_batch(state):

new_model_instance = self.transform_model_pre_registration(
new_model_instance,
)

components = {'model': new_model_instance}
if original_tokenizer is not None:
components['tokenizer'] = original_tokenizer
@@ -575,7 +598,7 @@ def dtensor_to_tensor_hook(
model_saving_kwargs: Dict[str, Any] = {
'path': local_save_path,
}
if composer_model.using_peft:
if self.using_peft:
model_saving_kwargs['flavor'] = 'peft'
model_saving_kwargs['save_pretrained_dir'
] = temp_save_dir
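
A closing note on the restructured instantiation above: the PR moves the PEFT branch inside init_empty_weights, so both branches now build the new model on the meta device and skip per-parameter random initialization. The sketch below shows that behavior in isolation; the model choice and the load_state_dict suggestion are assumptions for illustration, not code from this PR:

import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained('gpt2')  # any config works; gpt2 is just small

with init_empty_weights():
    # Parameters are created on the meta device: no memory is allocated and
    # no random weight initialization runs, so construction is nearly free.
    model = AutoModelForCausalLM.from_config(config)

assert all(p.device.type == 'meta' for p in model.parameters())

# Real weights can then be attached from a state dict gathered separately,
# e.g. with model.load_state_dict(state_dict, assign=True) on torch >= 2.1.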