From 0258544aa191ce36ec9f9688b4ad7c1ec451bbc8 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Thu, 26 Sep 2024 17:23:53 +0000
Subject: [PATCH] fix low_cpu_mem intro by 4.45

Signed-off-by: Yu Chin Fabian Lim
---
 .../framework_plugin_bnb.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
index 1fd6afa8..5c4624c0 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
@@ -25,6 +25,7 @@
 from fms_acceleration import AccelerationPlugin
 from peft import LoraConfig, get_peft_model
 from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
+from transformers.utils.import_utils import _is_package_available
 import torch
 
@@ -120,6 +121,24 @@ def model_loader(self, model_name: str, **kwargs):
             and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
         ):
             config_kwargs["bnb_4bit_quant_storage"] = torch_dtype
+
+            _, _transformers_version = _is_package_available("transformers", return_version=True)
+
+            if _transformers_version >= "4.45":
+                from fms_acceleration.model_patcher import patch_target_module
+
+                def _truthy():
+                    return True
+
+                patch_target_module(
+                    "transformers.modeling_utils.is_local_dist_rank_0",
+                    _truthy,
+                )
+                warnings.warn(
+                    "Disabling low_cpu_mem_mode as this will cause problems with "
+                    "the fused-ops-and-kernels package"
+                )
+
         elif world_size > 1:
             warnings.warn(
                 "Running in distributed mode but bnb_4bit_quant_storage is not set. "