Commit

Switch to the Composer integration of LoRA (works with FSDP) (#886)
dakinggg authored and bigning committed Feb 5, 2024
1 parent c1ee2e9 commit 3cf4804
Showing 10 changed files with 624 additions and 356 deletions.
52 changes: 41 additions & 11 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -9,7 +9,7 @@
import re
import tempfile
from pathlib import Path
from typing import Optional, Sequence, Union
from typing import Any, Dict, Optional, Sequence, Union

import torch
from composer.core import Callback, Event, State, Time, TimeUnit
@@ -203,14 +203,17 @@ def _save_checkpoint(self, state: State, logger: Logger):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

if state.is_model_ddp:
composer_model = state.model.module
original_model: PreTrainedModel = state.model.module.model
state_dict_model = state.model.module.model
original_tokenizer = state.model.module.tokenizer
elif isinstance(state.model.model, FSDP):
composer_model = state.model
original_model: PreTrainedModel = state.model.model.module
state_dict_model = state.model.model
original_tokenizer = state.model.tokenizer
else:
composer_model = state.model
original_model: PreTrainedModel = state.model.model
state_dict_model = state.model.model
original_tokenizer = state.model.tokenizer
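
Context for the `using_peft` checks in the hunks below: they rely on Composer's own PEFT support in `HuggingFaceModel`. A minimal sketch, assuming a Composer version that accepts a `peft_config` argument; the model name and LoRA hyperparameters here are placeholders, not values from this commit:

    from composer.models import HuggingFaceModel
    from peft import LoraConfig
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Placeholder base model and tokenizer; any causal LM works the same way.
    base = AutoModelForCausalLM.from_pretrained('gpt2')
    tokenizer = AutoTokenizer.from_pretrained('gpt2')

    # Illustrative LoRA settings only.
    lora = LoraConfig(task_type='CAUSAL_LM', r=8, lora_alpha=16, lora_dropout=0.05)

    # Composer wraps the base model with the adapter; the checkpointer callback
    # above then sees composer_model.using_peft == True and a state dict that
    # holds only the adapter weights.
    composer_model = HuggingFaceModel(base, tokenizer=tokenizer, peft_config=lora)
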
@@ -237,10 +240,25 @@ def _save_checkpoint(self, state: State, logger: Logger):
copied_config.init_device = 'cpu'

log.debug(f'Creating new model instance')
# First create the model instance on meta device to avoid the
# initialization cost.
with init_empty_weights():
new_model_instance = type(original_model)(copied_config)

if composer_model.using_peft:
# We don't use meta here because the state dict does not contain the full
# model, only the adapter weights.
active_adapter = original_model.active_adapter
base_model = original_model.get_base_model()
new_base_model_instance = type(base_model)(copied_config)

new_model_instance = type(original_model)(
new_base_model_instance,
original_model.peft_config[active_adapter])
else:
# First create the model instance on meta device to avoid the
# initialization cost.
with init_empty_weights():
new_model_instance = type(original_model)(copied_config)

new_model_instance.to(dtype=self.dtype)
new_model_instance.load_state_dict(state_dict)

# Then load the state dict in with "assign" so that the state dict
# is loaded properly even though the model is initially on meta device.
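
The PEFT branch above skips `init_empty_weights()` because the training state dict holds only adapter weights, so the base model must be materialized with real tensors. A rough standalone sketch of the same idea, using `get_peft_model` and a placeholder `gpt2` config for illustration rather than the callback's exact code:

    from peft import LoraConfig, get_peft_model
    from transformers import AutoConfig, AutoModelForCausalLM

    # Re-create the base model for real (not on the meta device), since only
    # the LoRA tensors will come from the checkpoint's state dict.
    config = AutoConfig.from_pretrained('gpt2')  # placeholder config
    base = AutoModelForCausalLM.from_config(config)

    peft_model = get_peft_model(base, LoraConfig(task_type='CAUSAL_LM'))

    # Stand-in for the adapter-only state dict produced during training.
    adapter_state = {k: v for k, v in peft_model.state_dict().items() if 'lora_' in k}

    # strict=False because the base-model keys are intentionally absent.
    missing, unexpected = peft_model.load_state_dict(adapter_state, strict=False)
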
@@ -295,12 +313,24 @@ def _save_checkpoint(self, state: State, logger: Logger):
# TODO: Remove after mlflow fixes the bug that makes this necessary
import mlflow
mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
mlflow_logger.save_model(
flavor='transformers',
transformers_model=components,
path=local_save_path,
**self.mlflow_logging_config,
)
model_saving_kwargs: Dict[str, Any] = {
'path': local_save_path
}
if composer_model.using_peft:
model_saving_kwargs['flavor'] = 'peft'
model_saving_kwargs[
'save_pretrained_dir'] = temp_save_dir
model_saving_kwargs[
'metadata'] = self.mlflow_logging_config[
'metadata']
else:
model_saving_kwargs['flavor'] = 'transformers'
model_saving_kwargs[
'transformers_model'] = components
model_saving_kwargs.update(
self.mlflow_logging_config)

mlflow_logger.save_model(**model_saving_kwargs)

license_filename = _maybe_get_license_filename(
local_save_path)
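
For readability, the kwargs assembly above boils down to the following two call shapes. This is only a flattened restatement of the diff; the variables (`composer_model`, `mlflow_logger`, `local_save_path`, `temp_save_dir`, `components`, `self.mlflow_logging_config`) are assumed to be in scope as in the callback:

    if composer_model.using_peft:
        mlflow_logger.save_model(
            flavor='peft',
            path=local_save_path,
            save_pretrained_dir=temp_save_dir,
            metadata=self.mlflow_logging_config['metadata'],
        )
    else:
        # Assumes self.mlflow_logging_config does not itself contain 'path',
        # which would otherwise collide with the explicit keyword argument.
        mlflow_logger.save_model(
            flavor='transformers',
            path=local_save_path,
            transformers_model=components,
            **self.mlflow_logging_config,
        )
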