From 0e5bfcec99bf4f347c8536193800d7de123e28c8 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sat, 16 Mar 2024 08:34:55 -0400
Subject: [PATCH] make sure to capture non-null defaults from config
 validation

---
 src/axolotl/utils/config/__init__.py        |  4 +--
 .../config/models/input/v0_4_1/__init__.py  | 28 ++++++++-----------
 tests/test_validation.py                    | 12 ++++++++
 3 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index 9151f288a8..f7c269bea3 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -199,11 +199,11 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
             dict(
                 AxolotlConfigWCapabilities(
                     **cfg.to_dict(), capabilities=capabilities
-                ).model_dump(exclude_unset=True)
+                ).model_dump(exclude_none=True)
             )
         )
     return DictDefault(
-        dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_unset=True))
+        dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
     )
 
 
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index dfe9a9be96..00c95e207b 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -148,12 +148,6 @@ class PeftConfig(BaseModel):
     loftq_config: Optional[LoftQConfig] = None
 
 
-class AutoType(str, Enum):
-    """auto type string configuration subset - used for bf16"""
-
-    AUTO = "auto"
-
-
 class SpecialTokensConfig(BaseModel):
     """Special tokens configuration subset"""
 
@@ -304,14 +298,16 @@ class HyperparametersConfig(BaseModel):
         },
     )
 
-    train_on_inputs: Optional[bool] = None
+    train_on_inputs: Optional[bool] = False
     group_by_length: Optional[bool] = None
 
     learning_rate: Union[str, float]
-    weight_decay: Optional[float] = None
-    optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
+    weight_decay: Optional[float] = 0.0
+    optimizer: Optional[
+        Union[OptimizerNames, Literal["lion_pytorch"]]
+    ] = OptimizerNames.ADAMW_HF.value
     torchdistx_path: Optional[str] = None
-    lr_scheduler: Optional[SchedulerType] = None
+    lr_scheduler: Optional[SchedulerType] = "cosine"
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
     lr_quadratic_warmup: Optional[bool] = None
     cosine_min_lr_ratio: Optional[float] = None
@@ -458,7 +454,7 @@ class Config:
     loss_watchdog_threshold: Optional[float] = None
     loss_watchdog_patience: Optional[int] = None
 
-    bf16: Optional[Union[AutoType, bool]] = AutoType.AUTO
+    bf16: Optional[Union[Literal["auto"], bool]] = "auto"
     fp16: Optional[bool] = None
     bfloat16: Optional[bool] = None  # for non-AMP cases
     float16: Optional[bool] = None  # for non-AMP cases
@@ -472,7 +468,7 @@ class Config:
 
     unfrozen_parameters: Optional[List[str]] = None
 
-    sequence_len: int = Field(default=1024)
+    sequence_len: int = Field(default=512)
     sample_packing: Optional[bool] = None
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None
@@ -531,10 +527,10 @@ class Config:
     sample_packing_eff_est: Optional[float] = None
     axolotl_config_path: Optional[str] = None
 
-    is_falcon_derived_model: Optional[bool] = Field(default=False)
-    is_llama_derived_model: Optional[bool] = Field(default=False)
-    is_mistral_derived_model: Optional[bool] = Field(default=False)
-    is_qwen_derived_model: Optional[bool] = Field(default=False)
+    is_falcon_derived_model: Optional[bool] = Field(default=None)
+    is_llama_derived_model: Optional[bool] = Field(default=None)
+    is_mistral_derived_model: Optional[bool] = Field(default=None)
+    is_qwen_derived_model: Optional[bool] = Field(default=None)
 
     @field_validator("datasets", mode="before")
     @classmethod
diff --git a/tests/test_validation.py b/tests/test_validation.py
index 70dbc750e6..7a8d80cb75 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -54,6 +54,18 @@ class TestValidation(BaseValidation):
     Test the validation module
     """
 
+    def test_defaults(self, minimal_cfg):
+        test_cfg = DictDefault(
+            {
+                "weight_decay": None,
+            }
+            | minimal_cfg
+        )
+        cfg = validate_config(test_cfg)
+
+        assert cfg.train_on_inputs is False
+        assert cfg.weight_decay is None
+
     def test_datasets_min_length(self):
         cfg = DictDefault(
             {
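
The behavioral pivot in this patch is the switch from model_dump(exclude_unset=True) to model_dump(exclude_none=True): in Pydantic v2, exclude_unset drops every field the user never explicitly set, including fields with non-null defaults such as train_on_inputs=False, while exclude_none keeps those defaults and drops only fields whose value is None. That is also why the patch moves defaults like weight_decay from None to concrete values, and conversely changes the is_*_derived_model flags to default=None so they only survive the dump when actually set. A minimal sketch of the difference, using a hypothetical stand-in model rather than the actual axolotl config classes:

    from typing import Optional

    from pydantic import BaseModel


    class Hyperparameters(BaseModel):
        """Stand-in for HyperparametersConfig, reduced to two fields."""

        train_on_inputs: Optional[bool] = False
        weight_decay: Optional[float] = 0.0


    # User sets nothing: exclude_unset loses the non-null defaults entirely,
    # while exclude_none preserves them.
    cfg = Hyperparameters()
    print(cfg.model_dump(exclude_unset=True))  # {}
    print(cfg.model_dump(exclude_none=True))   # {'train_on_inputs': False, 'weight_decay': 0.0}

    # User explicitly nulls a field: exclude_none drops it from the dump.
    cfg = Hyperparameters(weight_decay=None)
    print(cfg.model_dump(exclude_none=True))   # {'train_on_inputs': False}

This matches the new test_defaults test: weight_decay=None is excluded from the dumped config, and since axolotl's DictDefault returns None for missing keys, cfg.weight_decay still reads back as None while the non-null default train_on_inputs=False is now captured.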