From 3188d55d1fc45a0d4014ef4107d63e16c630c211 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sun, 31 Mar 2024 02:00:29 +0900
Subject: [PATCH 1/3] feat: validate sample packing requires flash_attention

---
 .../utils/config/models/input/v0_4_1/__init__.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 56307da0b8..e76b36c4fe 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -1,6 +1,7 @@
 """
 Module for pydantic models for configuration
 """
+
 # pylint: disable=too-many-lines
 
 import logging
@@ -655,6 +656,16 @@ def check_sample_packing_w_xformers(cls, data):
 
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_sample_packing_wo_flash(cls, data):
+        if data.get("sample_packing") and not data.get("flash_attention"):
+            raise ValueError(
+                "sample_packing requires flash_attention to be set to true"
+            )
+
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_sample_packing_w_rl(cls, data):

From 572c9e49543260c9dd5a0f5121c10a99c5e1efae Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Sun, 31 Mar 2024 04:00:00 +0900
Subject: [PATCH 2/3] fix: check for sdp_attn per suggestion

---
 src/axolotl/utils/config/models/input/v0_4_1/__init__.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index e76b36c4fe..ad332da2da 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -659,9 +659,13 @@ def check_sample_packing_w_xformers(cls, data):
     @model_validator(mode="before")
     @classmethod
     def check_sample_packing_wo_flash(cls, data):
-        if data.get("sample_packing") and not data.get("flash_attention"):
+        if (
+            data.get("sample_packing")
+            and not data.get("flash_attention")
+            and not data.get("sdp_attention")
+        ):
             raise ValueError(
-                "sample_packing requires flash_attention to be set to true"
+                "sample_packing requires flash_attention or sdp_attention to be set to true"
             )
 
         return data

From df936a95b5a25a2655c70066f525fced25625786 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Thu, 4 Apr 2024 13:51:03 +0900
Subject: [PATCH 3/3] feat: add FA to tests

---
 tests/test_validation.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tests/test_validation.py b/tests/test_validation.py
index 7a8d80cb75..4865712c47 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -600,6 +600,7 @@ def test_packing(self, minimal_cfg):
                 {
                     "sample_packing": True,
                     "pad_to_sequence_len": None,
+                    "flash_attention": True,
                 }
             )
             | minimal_cfg
@@ -901,6 +902,7 @@ def test_eval_table_size_conflict_eval_packing(self, minimal_cfg):
                 {
                     "sample_packing": True,
                     "eval_table_size": 100,
+                    "flash_attention": True,
                 }
             )
             | minimal_cfg
@@ -916,6 +918,7 @@ def test_eval_table_size_conflict_eval_packing(self, minimal_cfg):
                 {
                     "sample_packing": True,
                     "eval_sample_packing": False,
+                    "flash_attention": True,
                 }
             )
             | minimal_cfg
@@ -928,6 +931,7 @@ def test_eval_table_size_conflict_eval_packing(self, minimal_cfg):
                 {
                     "sample_packing": False,
                     "eval_table_size": 100,
+                    "flash_attention": True,
                 }
             )
             | minimal_cfg
@@ -941,6 +945,7 @@ def test_eval_table_size_conflict_eval_packing(self, minimal_cfg):
                     "sample_packing": True,
                     "eval_table_size": 100,
                     "eval_sample_packing": False,
+                    "flash_attention": True,
                 }
             )
             | minimal_cfg