From 151abb7a672a63a29e42fccd7ed57f4abd8b3337 Mon Sep 17 00:00:00 2001 From: Sunny Liu <22844540+bursteratom@users.noreply.github.com> Date: Thu, 21 Nov 2024 13:36:51 -0500 Subject: [PATCH] =?UTF-8?q?fix=20None-type=20not=20iterable=20error=20when?= =?UTF-8?q?=20deepspeed=20is=20left=20blank=20w/=20use=5F=E2=80=A6=20(#208?= =?UTF-8?q?7)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix None-type not iterable error when deepspeed is left blank w/ use_reentrant: false and qlora * added unit test[skip e2e] * corrected test case[skip e2e] * assert warning message [skip e2e] * assert warning message [skip e2e] * corrected test cases [skip e2e] * lint --- .../config/models/input/v0_4_1/__init__.py | 1 + tests/test_validation.py | 47 +++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 42cbe52c14..cdbe47b8f1 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -1314,6 +1314,7 @@ def warn_qlora_zero3_w_use_reentrant(cls, data): and data.get("gradient_checkpointing_kwargs", {}) and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant") is False + and data.get("deepspeed", "") is not None and "zero3" in data.get("deepspeed", "") ): # may result in: diff --git a/tests/test_validation.py b/tests/test_validation.py index f3f4d18ab8..491f230c33 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -68,6 +68,53 @@ def test_defaults(self, minimal_cfg): assert cfg.train_on_inputs is False assert cfg.weight_decay is None + def test_zero3_qlora_use_reentrant_false(self, minimal_cfg): + test_cfg = DictDefault( + { + "deepspeed": "deepspeed_configs/zero3_bf16.json", + "gradient_checkpointing": True, + "gradient_checkpointing_kwargs": {"use_reentrant": False}, + "load_in_4bit": True, + "adapter": "qlora", + } + | minimal_cfg + ) + + with self._caplog.at_level(logging.WARNING): + validate_config(test_cfg) + assert ( + "qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values" + in self._caplog.records[0].message + ) + + def test_deepspeed_empty(self, minimal_cfg): + test_cfg = DictDefault( + { + "deepspeed": "", + "gradient_checkpointing": True, + "gradient_checkpointing_kwargs": {"use_reentrant": False}, + "load_in_4bit": True, + "adapter": "qlora", + } + | minimal_cfg + ) + + _ = validate_config(test_cfg) + + def test_deepspeed_not_set(self, minimal_cfg): + test_cfg = DictDefault( + { + "deepspeed": None, + "gradient_checkpointing": True, + "gradient_checkpointing_kwargs": {"use_reentrant": False}, + "load_in_4bit": True, + "adapter": "qlora", + } + | minimal_cfg + ) + + _ = validate_config(test_cfg) + def test_datasets_min_length(self): cfg = DictDefault( {