From 97b45291fbe828ea0544b2ccef22153d1d5566d5 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 19 Nov 2024 00:45:00 -0500
Subject: [PATCH] monkeypatch for zero3 w 8bit lora

---
 .../monkeypatch/modeling_zero3_int8_lora.py | 84 +++++++++++++++++++
 src/axolotl/utils/trainer.py                |  9 ++
 2 files changed, 93 insertions(+)
 create mode 100644 src/axolotl/monkeypatch/modeling_zero3_int8_lora.py

diff --git a/src/axolotl/monkeypatch/modeling_zero3_int8_lora.py b/src/axolotl/monkeypatch/modeling_zero3_int8_lora.py
new file mode 100644
index 0000000000..29a09e3472
--- /dev/null
+++ b/src/axolotl/monkeypatch/modeling_zero3_int8_lora.py
@@ -0,0 +1,84 @@
+"""
+Fix for DeepSpeed ZeRO-3 with an 8-bit (bitsandbytes) LoRA adapter.
+See https://github.com/huggingface/transformers/pull/32943/files
+"""
+import inspect
+
+import transformers
+import transformers.modeling_utils
+from accelerate.logging import get_logger
+
+LOG = get_logger("axolotl.monkeypatch.modeling_zero3_int8_lora")
+
+# The block inside transformers' _load_state_dict_into_meta_model that we
+# expect to find and replace.
+ORIGINAL_LOAD_CODE = """
+            if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
+                module, tensor_name = get_module_from_name(model, param_name)
+                value = getattr(module, tensor_name)
+                param_to = "cpu"
+                if is_fsdp_enabled() and not is_local_dist_rank_0():
+                    param_to = "meta"
+                value = type(value)(value.data.to(param_to), **value.__dict__)
+                setattr(module, tensor_name, value)
+"""
+
+# The same block with the upstream fix: a bitsandbytes Int8Params weight must
+# be recreated with requires_grad=False, since int8 tensors cannot require
+# gradients.
+PATCHED_LOAD_CODE = """
+            if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
+                module, tensor_name = get_module_from_name(model, param_name)
+                value = getattr(module, tensor_name)
+                param_to = "cpu"
+                if is_fsdp_enabled() and not is_local_dist_rank_0():
+                    param_to = "meta"
+                val_kwargs = {}
+                if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
+                    val_kwargs["requires_grad"] = False
+                value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
+                setattr(module, tensor_name, value)
"""
+
+
+def get_modeling_state_dict_code() -> str:
+    load_code = inspect.getsource(
+        transformers.modeling_utils._load_state_dict_into_meta_model  # pylint: disable=protected-access
+    )
+    return load_code
+
+
+def check_modeling_state_dict_code_is_patchable() -> bool:
+    load_code = get_modeling_state_dict_code()
+    return ORIGINAL_LOAD_CODE in load_code
+
+
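+# The fix lands in the middle of the loader's body, so it cannot be applied
+# by wrapping the function. Instead, patch_modeling_state_dict_code() below
+# takes the loader's source, substring-replaces ORIGINAL_LOAD_CODE with
+# PATCHED_LOAD_CODE, renames the function, and exec()s the result into this
+# module's globals before reassigning it on transformers.modeling_utils. If
+# transformers rewrites this block upstream, the assert below trips and the
+# caller can fall back to the unpatched loader.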
+def patch_modeling_state_dict_code():
+    """
+    monkeypatch for fixing the meta model loader for zero3 8-bit lora
+    """
+
+    load_code = get_modeling_state_dict_code()
+    # stash the original source so the pre-patch state remains inspectable
+    transformers.modeling_utils._original_load_state_dict_into_meta_model = (  # pylint: disable=protected-access
+        load_code
+    )
+    assert (
+        ORIGINAL_LOAD_CODE in load_code
+    ), "Original _load_state_dict_into_meta_model code not found"
+
+    load_code = load_code.replace(ORIGINAL_LOAD_CODE, PATCHED_LOAD_CODE)
+    load_code = load_code.replace(
+        "def _load_state_dict_into_meta_model(",
+        "def _fixed_load_state_dict_into_meta_model(",
+        1,
+    )
+
+    # collect every modeling_utils name referenced by the source so the
+    # exec'd copy can resolve its globals
+    items_to_import = []
+    for item in dir(transformers.modeling_utils):
+        if item in load_code:
+            items_to_import.append(item)
+
+    exec(  # pylint: disable=exec-used  # nosec B102
+        "from transformers.modeling_utils import ("
+        + ", ".join(items_to_import)
+        + ")",
+        globals(),
+    )
+    exec(load_code, globals())  # pylint: disable=exec-used  # nosec B102
+    LOG.info("patching _load_state_dict_into_meta_model", main_process_only=True)
+    transformers.modeling_utils._load_state_dict_into_meta_model = _fixed_load_state_dict_into_meta_model  # pylint: disable=protected-access,undefined-variable  # noqa: F821
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 5c9bfd6635..e99c47492e 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -437,6 +437,15 @@ def setup_deepspeed_env(cfg, stage=None):
     os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
     if stage == 3:
         os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
+        if cfg.adapter and cfg.load_in_8bit:
+            from axolotl.monkeypatch.modeling_zero3_int8_lora import (
+                patch_modeling_state_dict_code,
+            )
+
+            try:
+                patch_modeling_state_dict_code()
+            except AssertionError:
+                LOG.warning("Failed to patch the meta model loading code")
     # If we don't assign this, it doesn't actually get set in the accelerate weakref
     _ = HfTrainerDeepSpeedConfig(cfg.deepspeed)
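
A minimal usage sketch (illustrative only, not part of the patch; it uses just
the two helpers the new module defines): check_modeling_state_dict_code_is_patchable()
probes for the expected source block without raising, while
patch_modeling_state_dict_code() asserts on a mismatch, which is why
setup_deepspeed_env() wraps it in try/except above.

    from axolotl.monkeypatch.modeling_zero3_int8_lora import (
        check_modeling_state_dict_code_is_patchable,
        patch_modeling_state_dict_code,
    )

    # Only exec-patch when the installed transformers still contains the
    # exact block ORIGINAL_LOAD_CODE expects; otherwise leave the loader
    # untouched rather than exec'ing mismatched source.
    if check_modeling_state_dict_code_is_patchable():
        patch_modeling_state_dict_code()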