Add adapter_only option to save_fsdp_model and load_fsdp_model to only save/load PEFT weights (#2321)

* Add adapter_only option to save_fsdp_model and load_fsdp_model

* Gate with adapter_only

* Black format

* Change unwrapping behavior

* Use extract_model_from_parallel for model unwrapping

* Fix quality

* Move functions to utils files

* Fix quality
AjayP13 authored Jan 26, 2024
1 parent e909eb3 commit 581fabb
Showing 4 changed files with 42 additions and 8 deletions.
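
In effect, the new adapter_only flag lets a training script checkpoint only the PEFT adapter weights of an FSDP-wrapped model instead of the full state dict. A minimal usage sketch, not taken from this commit: it assumes an FSDP-enabled Accelerate config and a model that is a peft PeftModel already passed through accelerator.prepare(); the helper names checkpoint_adapter and restore_adapter are hypothetical.

# Hedged sketch: model is assumed to be a LoRA/PEFT model wrapped by FSDP
# via accelerator.prepare() under an FSDP-enabled Accelerate config.
from accelerate import Accelerator
from accelerate.utils import load_fsdp_model, save_fsdp_model


def checkpoint_adapter(accelerator: Accelerator, model, output_dir: str):
    fsdp_plugin = accelerator.state.fsdp_plugin
    # Persist only the PEFT adapter weights instead of the full model state dict.
    save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, adapter_only=True)


def restore_adapter(accelerator: Accelerator, model, input_dir: str):
    fsdp_plugin = accelerator.state.fsdp_plugin
    # Load only the adapter weights back into the FSDP-wrapped PEFT model.
    load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, adapter_only=True)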
2 changes: 2 additions & 0 deletions src/accelerate/utils/__init__.py
@@ -69,6 +69,7 @@
is_msamp_available,
is_npu_available,
is_pandas_available,
is_peft_available,
is_rich_available,
is_sagemaker_available,
is_tensorboard_available,
@@ -94,6 +95,7 @@
get_mixed_precision_context_manager,
id_tensor_storage,
infer_auto_device_map,
is_peft_model,
load_checkpoint_in_model,
load_offloaded_weights,
load_state_dict,
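With these exports in place, both helpers become importable from the public utils namespace:

# The two names added above are re-exported by accelerate.utils.
from accelerate.utils import is_peft_available, is_peft_model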
33 changes: 26 additions & 7 deletions src/accelerate/utils/fsdp_utils.py
@@ -18,6 +18,7 @@
from ..logging import get_logger
from .constants import FSDP_MODEL_NAME, FSDP_PYTORCH_VERSION, OPTIMIZER_NAME
from .imports import is_torch_distributed_available
from .modeling import is_peft_model
from .versions import is_torch_version


@@ -32,7 +33,25 @@
logger = get_logger(__name__)


def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0):
def _get_model_state_dict(model, adapter_only=False):
if adapter_only and is_peft_model(model):
from peft import get_peft_model_state_dict

return get_peft_model_state_dict(model, adapter_name=model.active_adapter)
else:
return model.state_dict()


def _set_model_state_dict(model, state_dict, adapter_only=False):
if adapter_only and is_peft_model(model):
from peft import set_peft_model_state_dict

return set_peft_model_state_dict(model, state_dict, adapter_name=model.active_adapter)
else:
return model.load_state_dict(state_dict)


def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False):
os.makedirs(output_dir, exist_ok=True)

if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
@@ -45,7 +64,7 @@ def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0):
with FSDP.state_dict_type(
model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
):
state_dict = model.state_dict()
state_dict = _get_model_state_dict(model, adapter_only=adapter_only)
if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
weights_name = f"{FSDP_MODEL_NAME}.bin" if model_index == 0 else f"{FSDP_MODEL_NAME}_{model_index}.bin"
output_model_file = os.path.join(output_dir, weights_name)
@@ -77,7 +96,7 @@ def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0):
logger.info(f"Model saved to {ckpt_dir}")


def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0):
def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False):
accelerator.wait_for_everyone()
if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
# FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT
@@ -118,15 +137,15 @@ def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0):
else input_dir
)
logger.info(f"Loading model from {ckpt_dir}")
state_dict = {"model": model.state_dict()}
state_dict = {"model": _get_model_state_dict(model, adapter_only=adapter_only)}
dist_cp.load_state_dict(
state_dict=state_dict,
storage_reader=dist_cp.FileSystemReader(ckpt_dir),
planner=DefaultLoadPlanner(),
)
state_dict = state_dict["model"]
logger.info(f"Model loaded from {ckpt_dir}")
load_result = model.load_state_dict(state_dict)
load_result = _set_model_state_dict(model, state_dict, adapter_only=adapter_only)
return load_result


@@ -157,7 +176,7 @@ def save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, output_dir,
logger.info(f"Optimizer state saved in {ckpt_dir}")


def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, optimizer_index=0):
def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, optimizer_index=0, adapter_only=False):
accelerator.wait_for_everyone()
with FSDP.state_dict_type(
model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
@@ -180,7 +199,7 @@ def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, o
)
logger.info(f"Loading Optimizer from {ckpt_dir}")
optim_state = load_sharded_optimizer_state_dict(
model_state_dict=model.state_dict(),
model_state_dict=_get_model_state_dict(model, adapter_only=adapter_only),
optimizer_key="optimizer",
storage_reader=dist_cp.FileSystemReader(ckpt_dir),
)
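The two private helpers introduced above carry the adapter_only branch: when the flag is set and the model is a PEFT model, peft's get_peft_model_state_dict / set_peft_model_state_dict are used in place of model.state_dict() / model.load_state_dict(), so only the adapter tensors pass through the FSDP checkpointing path. A standalone sketch of that round trip outside FSDP, with an illustrative toy model and LoRA settings (following peft's custom-model usage, not code from this commit):

import torch.nn as nn
from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict

base = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
peft_model = get_peft_model(base, LoraConfig(r=4, target_modules=["0", "2"]))

# adapter_only=True path: only the LoRA tensors are extracted, not the frozen base weights.
adapter_state = get_peft_model_state_dict(peft_model, adapter_name=peft_model.active_adapter)
print(list(adapter_state))  # keys for lora_A / lora_B tensors only

# Load side: restores the adapter tensors in place of a full load_state_dict().
set_peft_model_state_dict(peft_model, adapter_state, adapter_name=peft_model.active_adapter)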
4 changes: 4 additions & 0 deletions src/accelerate/utils/imports.py
@@ -175,6 +175,10 @@ def is_datasets_available():
return _is_package_available("datasets")


def is_peft_available():
return _is_package_available("peft")


def is_timm_available():
return _is_package_available("timm")

11 changes: 10 additions & 1 deletion src/accelerate/utils/modeling.py
@@ -30,7 +30,7 @@
from ..state import AcceleratorState
from .constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from .dataclasses import AutocastKwargs, CustomDtype, DistributedType
from .imports import is_mps_available, is_npu_available, is_xpu_available
from .imports import is_mps_available, is_npu_available, is_peft_available, is_xpu_available
from .offload import load_offloaded_weight, offload_weight, save_offload_index
from .tqdm import is_tqdm_available, tqdm

@@ -47,6 +47,15 @@
logger = logging.getLogger(__name__)


def is_peft_model(model):
from .other import extract_model_from_parallel

if is_peft_available():
from peft import PeftModel

return is_peft_available() and isinstance(extract_model_from_parallel(model), PeftModel)


def check_device_same(first_device, second_device):
"""
Utility method to check if two `torch` devices are similar. When dealing with CUDA devices, torch throws `False`
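Because is_peft_model calls extract_model_from_parallel before the isinstance check (the "Change unwrapping behavior" and "Use extract_model_from_parallel for model unwrapping" steps in the commit message), a PEFT model hidden behind a parallel wrapper such as nn.DataParallel or DistributedDataParallel is still detected. An illustrative single-process sketch; the toy modules are assumptions, not part of the commit:

import torch.nn as nn
from accelerate.utils import is_peft_model
from peft import LoraConfig, get_peft_model

peft_model = get_peft_model(nn.Sequential(nn.Linear(8, 8)), LoraConfig(r=2, target_modules=["0"]))

print(is_peft_model(peft_model))                   # True: direct PeftModel
print(is_peft_model(nn.DataParallel(peft_model)))  # True: the wrapper is stripped before the check
print(is_peft_model(nn.Linear(8, 8)))              # False: plain torch module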
