
Allow for transforms on the model before MLFlow registration #1372

Merged
17 commits merged on Jul 20, 2024
173 changes: 93 additions & 80 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -18,7 +18,6 @@
import torch
import torch.nn as nn
from composer.core import Callback, Event, Precision, State, Time, TimeUnit
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger, MLFlowLogger
from composer.models import HuggingFaceModel
from composer.utils import (
@@ -29,7 +28,12 @@
)
from composer.utils.misc import create_interval_scheduler
from mlflow.transformers import _fetch_model_card, _write_license_information
from packaging import version
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import (
PretrainedConfig,
PreTrainedModel,
@@ -179,6 +183,7 @@ def __init__(
'bfloat16': torch.bfloat16,
}[precision]
self.flatten_imports = flatten_imports
self.using_peft = False

# mlflow config setup
self.mlflow_registered_model_name = mlflow_registered_model_name
Expand Down Expand Up @@ -274,6 +279,15 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
'1GB',
)

# Check if the model is using PEFT
if state.is_model_ddp:
composer_model = state.model.module
elif isinstance(state.model.model, FSDP):
composer_model = state.model
else:
composer_model = state.model
self.using_peft = composer_model.using_peft
elif event == Event.FIT_END:
# Wait for all child processes spawned by the callback to finish.
timeout = 3600
@@ -362,6 +376,23 @@ def transform_config(
copied_config.ffn_config['moe_world_size'] = 1
return copied_config

def transform_model_pre_registration(
self,
model: PreTrainedModel,
) -> PreTrainedModel:
"""Transform the model before registering with MLFlow.

This allows a subclass to modify the model before registering with MLFlow. The base class implementation will
make no modifications.

Args:
model (PreTrainedModel): The model to be transformed.

Returns:
PreTrainedModel: The transformed model.
"""
return model
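This method is the extension point the PR adds: a subclass overrides it to adjust the model immediately before MLflow registration, while the checkpoints written to disk keep the untransformed weights. A minimal sketch of such a subclass (hypothetical; it assumes the HuggingFaceCheckpointer class defined in this file and peft's merge_and_unload API):

# Hypothetical subclass: merge LoRA adapters into the base model before
# the model is registered with MLflow.
from transformers import PreTrainedModel

from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer


class MergedAdapterCheckpointer(HuggingFaceCheckpointer):

    def transform_model_pre_registration(
        self,
        model: PreTrainedModel,
    ) -> PreTrainedModel:
        # peft.PeftModel exposes merge_and_unload(); guard so plain
        # (non-PEFT) models pass through unchanged.
        if hasattr(model, 'merge_and_unload'):
            model = model.merge_and_unload()
        return model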

def _save_checkpoint(self, state: State, logger: Logger):
del logger # unused

@@ -388,82 +419,62 @@ def _save_checkpoint(self, state: State, logger: Logger):
temp_save_dir = tempfile.mkdtemp() if use_temp_dir else save_dir

log.debug('Gathering state dict')
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

if state.is_model_ddp:
composer_model = state.model.module
original_model: PreTrainedModel = state.model.module.model
state_dict_model = state.model.module.model
original_tokenizer = state.model.module.tokenizer
elif isinstance(state.model.model, FSDP):
composer_model = state.model
original_model: PreTrainedModel = state.model.model.module
state_dict_model = state.model.model
original_tokenizer = state.model.tokenizer
else:
composer_model = state.model
original_model: PreTrainedModel = state.model.model
state_dict_model = state.model.model
original_tokenizer = state.model.tokenizer

if version.parse(torch.__version__) > version.parse('2.2.9'):
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
)
cpu_offload = True

# Add a dtensor->cpu tensor hook to avoid CUDA OOM
def dtensor_to_tensor_hook(
module: nn.Module,
state_dict: Dict[str, Any],
prefix: str,
*args: Any,
) -> Dict[str, Any]:
dtensor_fqns = []
for fqn in state_dict.keys():
tensor = state_dict[fqn]
if isinstance(tensor, DTensor):
dtensor_fqns.append(fqn)
tensor = tensor.full_tensor() # type: ignore
if dist.get_global_rank() == 0:
if cpu_offload:
tensor = tensor.cpu()
state_dict[fqn] = tensor
if dist.get_global_rank() != 0:
for fqn in dtensor_fqns:
del state_dict[fqn]
return state_dict

hooks = []
for _, module in state_dict_model.named_modules():
if isinstance(module, FSDP):
hooks.append(
module.
_register_state_dict_hook(dtensor_to_tensor_hook),
)
cpu_offload = True

# Add a dtensor->cpu tensor hook to avoid CUDA OOM
def dtensor_to_tensor_hook(
module: nn.Module,
state_dict: Dict[str, Any],
prefix: str,
*args: Any,
) -> Dict[str, Any]:
dtensor_fqns = []
for fqn in state_dict.keys():
tensor = state_dict[fqn]
if isinstance(tensor, DTensor):
dtensor_fqns.append(fqn)
tensor = tensor.full_tensor() # type: ignore
if dist.get_global_rank() == 0:
if cpu_offload:
tensor = tensor.cpu()
state_dict[fqn] = tensor
if dist.get_global_rank() != 0:
for fqn in dtensor_fqns:
del state_dict[fqn]
return state_dict

hooks = []
for _, module in state_dict_model.named_modules():
if isinstance(module, FSDP):
hooks.append(
module._register_state_dict_hook(dtensor_to_tensor_hook),
)

state_dict = get_model_state_dict(
state_dict_model,
options=StateDictOptions(
full_state_dict=True,
cpu_offload=cpu_offload,
),
)
for hook in hooks:
hook.remove()
else:
state_dict_context = fsdp_state_dict_type_context(
original_model,
state_dict_type='full',
) if ((not state.is_model_ddp) and
isinstance(state_dict_model,
FSDP)) else contextlib.nullcontext()
with state_dict_context:
state_dict = state_dict_model.state_dict()

# Convert the state dict to the requested precision
state_dict = get_model_state_dict(
state_dict_model,
options=StateDictOptions(
full_state_dict=True,
cpu_offload=cpu_offload,
),
)
for hook in hooks:
hook.remove()

# Convert the state dict to the requested precision
for k, v in state_dict.items():
if isinstance(v, torch.Tensor):
state_dict[k] = v.to(dtype=self.dtype)
@@ -480,22 +491,19 @@ def dtensor_to_tensor_hook(

log.debug(f'Creating new model instance')

if composer_model.using_peft:
# We don't use meta here because the state dict does not contain the full
# model, only the adapter weights.
active_adapter = original_model.active_adapter
base_model = original_model.get_base_model()
new_base_model_instance = type(base_model)(new_config)

new_model_instance = type(original_model)(
new_base_model_instance,
original_model.peft_config[active_adapter],
)
new_model_instance.to(dtype=self.dtype)
else:
# First create the model instance on meta device to avoid the
# initialization cost.
with init_empty_weights():
# First create the model instance on meta device to avoid the
# initialization cost.
with init_empty_weights():
if self.using_peft:
active_adapter = original_model.active_adapter
base_model = original_model.get_base_model()
new_base_model_instance = type(base_model)(new_config)

new_model_instance = type(original_model)(
new_base_model_instance,
original_model.peft_config[active_adapter],
)
else:
new_model_instance = type(original_model)(new_config)
new_model_instance.generation_config.update(
**original_model.generation_config.to_dict(),
@@ -556,6 +564,11 @@

if dist.get_global_rank() == 0:
if self.mlflow_registered_model_name and self._is_last_batch(state):

new_model_instance = self.transform_model_pre_registration(
new_model_instance,
)

components = {'model': new_model_instance}
if original_tokenizer is not None:
components['tokenizer'] = original_tokenizer
@@ -575,7 +588,7 @@ def dtensor_to_tensor_hook(
model_saving_kwargs: Dict[str, Any] = {
'path': local_save_path,
}
if composer_model.using_peft:
if self.using_peft:
model_saving_kwargs['flavor'] = 'peft'
model_saving_kwargs['save_pretrained_dir'
] = temp_save_dir
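For context on when the hook fires: _save_checkpoint calls transform_model_pre_registration only on global rank 0, only when mlflow_registered_model_name is set, and only for the last batch, immediately before save_model and register_model_with_run_id. A usage sketch (hypothetical argument values; save_folder, save_interval, and mlflow_registered_model_name are assumed to be the callback's usual constructor arguments):

# Usage sketch: wire the subclass in like any other Composer callback; the
# transform then runs once, on rank 0, right before MLflow registration.
from composer import Trainer

trainer = Trainer(
    model=composer_model,  # a Composer HuggingFaceModel, defined elsewhere
    train_dataloader=train_dataloader,  # defined elsewhere
    max_duration='1ep',
    callbacks=[
        MergedAdapterCheckpointer(
            save_folder='checkpoints',
            save_interval='1ba',
            mlflow_registered_model_name='my-registered-model',
        ),
    ],
)
trainer.fit()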
5 changes: 5 additions & 0 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -383,6 +383,9 @@ def test_huggingface_conversion_callback_interval(
mlflow_logger_mock.model_registry_prefix = ''
mlflow_logger_mock._experiment_id = 'mlflow-experiment-id'
mlflow_logger_mock._run_id = 'mlflow-run-id'
checkpointer_callback.transform_model_pre_registration = MagicMock(
wraps=checkpointer_callback.transform_model_pre_registration,
)
trainer = Trainer(
model=original_model,
device='gpu',
@@ -407,8 +410,10 @@
input_example=ANY,
metadata={},
)
assert checkpointer_callback.transform_model_pre_registration.call_count == 1
assert mlflow_logger_mock.register_model_with_run_id.call_count == 1
else:
assert checkpointer_callback.transform_model_pre_registration.call_count == 0
assert mlflow_logger_mock.save_model.call_count == 0
assert mlflow_logger_mock.register_model_with_run_id.call_count == 0
