diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 56307da0b8..ad332da2da 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -1,6 +1,7 @@
 """
 Module for pydantic models for configuration
 """
+
 # pylint: disable=too-many-lines
 
 import logging
@@ -655,6 +656,26 @@ def check_sample_packing_w_xformers(cls, data):
 
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_sample_packing_wo_flash(cls, data):
+        """
+        Require flash_attention or sdp_attention whenever sample_packing is enabled.
+
+        Raises ValueError if `sample_packing` is truthy while both `flash_attention`
+        and `sdp_attention` are falsy or absent; otherwise returns `data` unchanged.
+        """
+        if (
+            data.get("sample_packing")
+            and not data.get("flash_attention")
+            and not data.get("sdp_attention")
+        ):
+            raise ValueError(
+                "sample_packing requires flash_attention or sdp_attention to be set to true"
+            )
+
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_sample_packing_w_rl(cls, data):
diff --git a/tests/test_validation.py b/tests/test_validation.py
index 7a8d80cb75..4865712c47 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -600,6 +600,7 @@ def test_packing(self, minimal_cfg):
                 {
                     "sample_packing": True,
                     "pad_to_sequence_len": None,
+                    "flash_attention": True,
                 }
             )
             | minimal_cfg
@@ -901,6 +902,7 @@ def test_eval_table_size_conflict_eval_packing(self, minimal_cfg):
                 {
                     "sample_packing": True,
                     "eval_table_size": 100,
+                    "flash_attention": True,
                 }
             )
             | minimal_cfg
@@ -916,6 +918,7 @@ def test_eval_table_size_conflict_eval_packing(self, minimal_cfg):
                 {
                     "sample_packing": True,
                     "eval_sample_packing": False,
+                    "flash_attention": True,
                 }
             )
             | minimal_cfg
@@ -928,6 +931,7 @@ def test_eval_table_size_conflict_eval_packing(self, minimal_cfg):
                 {
                     "sample_packing": False,
                     "eval_table_size": 100,
+                    "flash_attention": True,
                 }
             )
             | minimal_cfg
@@ -941,6 +945,7 @@ def test_eval_table_size_conflict_eval_packing(self, minimal_cfg):
                     "sample_packing": True,
                     "eval_table_size": 100,
                     "eval_sample_packing": False,
+                    "flash_attention": True,
                 }
             )
             | minimal_cfg