Pydantic 2.x cfg #1239
Conversation
Force-pushed from 076d50a to d464d1b
Force-pushed from 3b032b8 to 50199ea
This is my first review. I'll take a look later on the rest.
src/axolotl/cli/__init__.py (Outdated)

```python
capabilities = GPUCapabilities(
    bf16=is_torch_bf16_gpu_available(), n_gpu=os.environ.get("WORLD_SIZE", 1)
)
```
Should this be moved into validate or normalize config?
I'm avoiding setting this in the validation because there are downstream use cases where a user might want to make sure their configuration works on a configured GPU cluster before firing off the training.
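A minimal sketch of what such a capabilities model could look like, based on the call site quoted below (field names and defaults are assumed, not the merged implementation):

```python
import os

from pydantic import BaseModel


class GPUCapabilities(BaseModel):
    # Hypothetical shape inferred from the snippet in this thread;
    # the real model may carry more fields.
    bf16: bool = False
    n_gpu: int = 1


# Constructed off-cluster, this lets a user describe target hardware and
# validate a config before launching training. Pydantic coerces the string
# value of WORLD_SIZE to an int for the n_gpu field.
caps = GPUCapabilities(bf16=True, n_gpu=os.environ.get("WORLD_SIZE", 1))
```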
```python
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_xformers(cls, root):
    if root.get("sample_packing") and root.get("xformers_attention"):
```
Can we have a few methods that check for when sample_packing is on but not actually active due to an unsupported model type?
Pseudo code:

```python
def check_sample_packing_active(cls, root):
    # pseudo: "any(llama/mistral/qwen)" meaning any supported model-type flag
    if root.get("sample_packing") and not any(
        root.get(f"is_{m}_derived_model") for m in ("llama", "mistral", "qwen")
    ):
        raise ValueError("sample_packing not compatible with current model type")
    return root
```
I don't know that we can reliably detect the model type at this step of the validation in order to raise a ValueError.
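One way to sidestep model detection, sketched here as an assumption rather than the merged code, is to key the check off the user-set `is_*_derived_model` flags that are already part of the config schema:

```python
from pydantic import BaseModel, model_validator


class SamplePackingCheck(BaseModel):
    # Hypothetical subset of the config; flag names taken from the YAML
    # options quoted elsewhere in this thread.
    sample_packing: bool = False
    is_llama_derived_model: bool = False
    is_mistral_derived_model: bool = False
    is_qwen_derived_model: bool = False

    @model_validator(mode="after")
    def check_sample_packing_supported(self):
        supported = (
            self.is_llama_derived_model
            or self.is_mistral_derived_model
            or self.is_qwen_derived_model
        )
        if self.sample_packing and not supported:
            raise ValueError("sample_packing not compatible with current model type")
        return self
```

This only works when the user has set the flags truthfully, which matches the caveat above about not being able to detect the model reliably at validation time.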
Force-pushed from ea381f0 to 032eced
```diff
@@ -543,7 +543,7 @@ is_mistral_derived_model:
 is_qwen_derived_model:

 # optional overrides to the base model configuration
-model_config:
+model_config_overrides:
```
I just noticed this, but we would need to deprecate the old name (raise a ValueError for it).
We can't actually deprecate it with Pydantic, because `model_config` is an internal variable name for Pydantic models.
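For context, Pydantic 2 reserves `model_config` for the model's own `ConfigDict`, so it cannot be declared as a field. One possible workaround (a sketch under that assumption, untested against this codebase) is to reject the legacy key in a `mode="before"` validator, where the raw input dict is still visible:

```python
from typing import Optional

from pydantic import BaseModel, model_validator


class AxolotlInputConfig(BaseModel):
    # Hypothetical field; the real config model has many more.
    model_config_overrides: Optional[dict] = None

    @model_validator(mode="before")
    @classmethod
    def reject_legacy_model_config(cls, data):
        # "model_config" can't be a field name, but the raw input mapping
        # can still be inspected before field validation runs.
        if isinstance(data, dict) and "model_config" in data:
            raise ValueError(
                "`model_config` is deprecated; use `model_config_overrides`"
            )
        return data
```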
remember to return in validator
missing return
add missing relora attributes
fix test for DictDefault change
fix sys template for mistral from fastchat change in PR 2872
fix test for batch size warning
Force-pushed from d33a8f7 to 7f688b6
* WIP conversion to use pydantic for config validation
* wip, more fields, add capabilities
* wip
* update pydantic validation to match existing tests
* tweak requirements
* setup deprecated params pydantic model
* more validations
* wrap up rest of the validations
* flesh out the rest of the options from the readme into pydantic
* fix model validators as class methods: remember to return in validator; missing return; add missing relora attributes; fix test for DictDefault change; fix sys template for mistral from fastchat change in PR 2872; fix test for batch size warning
* more missing attributes for cfg
* updates from PR feedback
* fix validation for datasets and pretrain datasets
* fix test for lora check
Description

This PR migrates the validation to use Pydantic validators. We keep all of the existing tests with some modifications. I've attempted to capture all the `cfg.*` attributes I could find and make sure they are represented in the Pydantic config models. I've also added a GPUCapabilities model to abstract away the underlying hardware checks, so that a config can be checked offline before sending it off for actual training.

For now, we run the validation and convert the result back to the DictDefault that we currently use for the `cfg`. This is to minimize the blast radius of this change to strictly validation. We can consider down the line how to swap out the various uses of `cfg`, how the attributes are accessed, and whether that is compatible with Pydantic models.

Motivation and Context
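The validate-then-convert round trip described in the summary might look roughly like this (a sketch: the field subset is hypothetical, and a plain dict stands in for axolotl's DictDefault):

```python
from typing import Optional

from pydantic import BaseModel


class AxolotlInputConfig(BaseModel):
    # Tiny illustrative subset of the real config model.
    base_model: str
    sample_packing: bool = False
    micro_batch_size: Optional[int] = None


def validate_config(raw_cfg: dict) -> dict:
    # Validate with Pydantic, then hand back a plain mapping so the rest of
    # the codebase keeps its existing cfg access patterns unchanged.
    validated = AxolotlInputConfig.model_validate(raw_cfg)
    return validated.model_dump()


cfg = validate_config({"base_model": "meta-llama/Llama-2-7b-hf"})
```

Keeping the Pydantic model at the boundary like this is what limits the blast radius: downstream code never sees a model instance, only the familiar dict shape.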
How has this been tested?
Existing unit tests.
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)