diff --git a/docs/config.qmd b/docs/config.qmd
index f01a2ce267..dd2da41b9a 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -239,6 +239,9 @@ sample_packing_group_size: 100000
 # The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
 sample_packing_bin_size: 200
 
+# Use batch flattening for speedups when not using sample_packing
+batch_flattening:
+
 # Passed through to transformers when loading the model when launched without accelerate
 # Use `sequential` when training w/ model parallelism to limit memory
 device_map:
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 8860e06404..3862351522 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -696,7 +696,7 @@ class Config:
 
     curriculum_sampling: Optional[bool] = None
     multipack_real_batches: Optional[bool] = None
-    batch_flattening: Optional[bool] = None
+    batch_flattening: Optional[Union[Literal["auto"], bool]] = None
 
     # for PoSE context length extension
     use_pose: Optional[bool] = None
@@ -929,13 +929,24 @@ def check_sample_packing_wo_flash(cls, data):
     @classmethod
     def check_batch_flattening_fa(cls, data):
         if data.get("batch_flattening"):
-            if not data.get("flash_attention"):
+            batch_flattening_auto = data.get("batch_flattening") == "auto"
+            if not data.get("flash_attention") and not batch_flattening_auto:
                 raise ValueError("batch_flattening requires flash attention")
-            if data.get("sample_packing"):
+            if data.get("sample_packing") and not batch_flattening_auto:
                 raise ValueError("batch_flattening not compatible with sample_packing")
-            if data.get("micro_batch_size") == 1:
+            if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
                 LOG.warning("batch_flattening has no effect with micro_batch_size == 1")
 
+            if (
+                batch_flattening_auto
+                and data.get("flash_attention")
+                and not data.get("sample_packing")
+                and data.get("micro_batch_size") > 1
+            ):
+                data["batch_flattening"] = True
+            elif batch_flattening_auto:
+                data["batch_flattening"] = False
+
         return data
 
     @model_validator(mode="before")
diff --git a/tests/patched/test_validation.py b/tests/patched/test_validation.py
index 2e6fbab101..7745fbfeae 100644
--- a/tests/patched/test_validation.py
+++ b/tests/patched/test_validation.py
@@ -1196,6 +1196,76 @@ def test_torch_version_adopt_req(self, minimal_cfg):
         )
 
 
+class TestSampleOptimConfigValidation(BaseValidation):
+    """
+    test configurations for sample optimizations like batch flattening
+    """
+
+    def test_batch_flattening_auto_enables(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "flash_attention": True,
+                    "sample_packing": None,
+                    "micro_batch_size": 2,
+                    "batch_flattening": "auto",
+                }
+            )
+            | minimal_cfg
+        )
+
+        new_cfg = validate_config(cfg)
+        assert new_cfg["batch_flattening"] is True
+
+    def test_batch_flattening_auto_no_fa(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "flash_attention": False,
+                    "sample_packing": None,
+                    "micro_batch_size": 2,
+                    "batch_flattening": "auto",
+                }
+            )
+            | minimal_cfg
+        )
+
+        new_cfg = validate_config(cfg)
+        assert new_cfg["batch_flattening"] is False
+
+    def test_batch_flattening_auto_mbsz_1(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "flash_attention": True,
+                    "sample_packing": None,
+                    "micro_batch_size": 1,
+                    "batch_flattening": "auto",
+                }
+            )
+            | minimal_cfg
+        )
+
+        new_cfg = validate_config(cfg)
+        assert new_cfg["batch_flattening"] is False
+
+    def test_batch_flattening_auto_packing(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "flash_attention": True,
+                    "sample_packing": True,
+                    "micro_batch_size": 2,
+                    "batch_flattening": "auto",
+                }
+            )
+            | minimal_cfg
+        )
+
+        new_cfg = validate_config(cfg)
+        assert new_cfg["batch_flattening"] is False
+
+
 class TestValidationCheckModelConfig(BaseValidation):
     """
     Test the validation for the config when the model config is available
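
A minimal standalone sketch of the "auto" resolution rule introduced above (illustrative only: `resolve_batch_flattening` is a hypothetical helper, not part of axolotl; the real logic lives in the `check_batch_flattening_fa` validator shown in the diff):

    # Mirrors the "auto" handling added in check_batch_flattening_fa: "auto" resolves
    # to True only when flash attention is enabled, sample packing is disabled, and
    # micro_batch_size is greater than 1; otherwise it resolves to False.
    # Note: this sketch defaults a missing micro_batch_size to 1.
    def resolve_batch_flattening(cfg: dict) -> dict:
        if cfg.get("batch_flattening") == "auto":
            cfg["batch_flattening"] = bool(
                cfg.get("flash_attention")
                and not cfg.get("sample_packing")
                and (cfg.get("micro_batch_size") or 1) > 1
            )
        return cfg

    # e.g. {"flash_attention": True, "micro_batch_size": 2, "batch_flattening": "auto"}
    # resolves to batch_flattening == True, matching test_batch_flattening_auto_enables.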