[feature] sweeps #2171

Open · wants to merge 7 commits into main
132 changes: 123 additions & 9 deletions src/axolotl/cli/main.py
@@ -1,9 +1,16 @@
"""CLI definition for various axolotl commands."""
# pylint: disable=redefined-outer-name
import logging
import random
import subprocess # nosec B404
import tempfile
from copy import deepcopy
from itertools import product
from pathlib import Path
from typing import Optional

import click
import yaml

import axolotl
from axolotl.cli.utils import (
@@ -17,6 +24,76 @@
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig


def generate_sweep_configs(base_config, sweeps_config):
"""
Recursively generates all possible configurations by applying sweeps to the base config.

Args:
base_config (dict): The original configuration dictionary
sweeps_config (dict): Dictionary where keys are parameters and values are either:
- lists of values to sweep independently
- or for paired values, a list of dicts under the '_' key

Returns:
list: List of all possible configuration dictionaries

Example:
sweeps_config = {
'learning_rate': [0.1, 0.01],
'_': [
{'load_in_8bit': True, 'adapter': 'lora'},
{'load_in_4bit': True, 'adapter': 'qlora'}
]
}
"""
# Separate paired values from regular sweeps
paired_values = sweeps_config.get("_", [])
regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}

# Process regular sweeps
param_names = list(regular_sweeps.keys())
param_values = list(regular_sweeps.values())

# Generate combinations for regular sweeps
regular_combinations = list(product(*param_values)) if param_values else [()]

# Combine regular sweeps with paired values
all_combinations = []
for reg_combo in regular_combinations:
        if paired_values:
            for paired_set in paired_values:
                # Combine regular sweep parameters with the paired parameter set
                new_config = {**dict(zip(param_names, reg_combo)), **paired_set}
                all_combinations.append(new_config)
        else:
            # No paired values, so just use the regular combinations
            new_config = dict(zip(param_names, reg_combo))
            all_combinations.append(new_config)

# randomize the order of trials
random.seed(42)
random.shuffle(all_combinations)

# Generate a new config for each combination
result_configs = []
for combination in all_combinations:
new_config = deepcopy(base_config)
for param_name, param_value in combination.items():
new_config[param_name] = param_value
result_configs.append(new_config)

return result_configs


@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
@@ -43,25 +120,62 @@ def preprocess(config: str, **kwargs):
default=True,
help="Use accelerate launch for multi-GPU training",
)
@click.option(
"--sweep",
type=click.Path(exists=True, path_type=str),
help="YAML config for sweeping hyperparameters",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def train(config: str, accelerate: bool, **kwargs):
def train(config: str, accelerate: bool, sweep: Optional[str] = None, **kwargs):
"""Train or fine-tune a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}

# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()

if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
if sweep:
# load the sweep configuration yaml file
with open(sweep, "r", encoding="utf-8") as fin:
sweep_config: dict[str, list] = yaml.safe_load(fin)
        with open(config, "r", encoding="utf-8") as fin:
            base_config: dict = yaml.safe_load(fin)

# generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)

def iter_configs():
for perm in permutations:
# open temp directory for temporary configurations
with tempfile.TemporaryDirectory() as temp_dir:
with open(
Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
) as fout:
yaml.dump(perm, fout)
yield str(Path(temp_dir) / "config.yaml")

else:
from axolotl.cli.train import do_cli

do_cli(config=config, **kwargs)
def iter_configs():
yield config

for cfg_file in iter_configs():
        # handle errors from the subprocess so the rest of the sweeps can continue
try:
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
if cfg_file:
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.train import do_cli

do_cli(config=cfg_file, **kwargs)
        except subprocess.CalledProcessError as exc:
            logging.error("Failed to train/fine-tune config '%s': %s", cfg_file, exc)
            if not sweep:
                raise


@cli.command()
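For context, the sweep file passed via --sweep follows the format described in the generate_sweep_configs docstring: top-level keys map to lists of values that are swept independently, while entries under the special "_" key are applied together as a group. A minimal sketch of such a file (the parameter values here are purely illustrative, not recommendations):

learning_rate:
  - 0.0001
  - 0.0002
micro_batch_size:
  - 1
  - 2
_:
  - load_in_8bit: true
    adapter: lora
  - load_in_4bit: true
    adapter: qlora

With a base config such as config.yaml, this would be launched with something like "axolotl train config.yaml --sweep sweep.yaml". Each of the 2 × 2 × 2 = 8 generated configurations is written to a temporary YAML file; with the default --accelerate flag each one runs in its own "accelerate launch" subprocess, and a failed run is logged so the remaining sweeps can continue.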
68 changes: 68 additions & 0 deletions tests/cli/test_cli_sweeps.py
@@ -0,0 +1,68 @@
"""
unit tests for generating sweep configurations
"""
from axolotl.cli.main import generate_sweep_configs


def test_generate_sweep_configs_no_pairs():
base_config = {
"learning_rate": 0.1,
"micro_batch_size": 1,
"sample_packing": True,
}

sweeps_config = {"micro_batch_size": [1, 2, 4], "weight_decay": [0.0, 0.1]}

    configs = generate_sweep_configs(base_config, sweeps_config)

    assert len(configs) == 6

    cfg_1 = {
        "learning_rate": 0.1,
        "micro_batch_size": 2,
        "weight_decay": 0.0,
        "sample_packing": True,
    }

    assert any(cfg_1 == cfg for cfg in configs)


def test_generate_sweep_configs_with_pairs():
base_config = {
"learning_rate": 0.1,
"micro_batch_size": 1,
"sample_packing": True,
}

sweeps_config = {
"_": [
{
"micro_batch_size": 1,
"gradient_accumulation_steps": 8,
},
{
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
},
{
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
},
{
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
},
],
"weight_decay": [0.0, 0.1],
}

    configs = generate_sweep_configs(base_config, sweeps_config)

    assert len(configs) == 8

    assert all(
        cfg["gradient_accumulation_steps"] * cfg["micro_batch_size"] == 8
        for cfg in configs
    )
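To make the behaviour these tests exercise concrete, here is a quick sketch of the no-pairs case (the order of results will vary, since the function shuffles the trial order with a fixed seed):

from axolotl.cli.main import generate_sweep_configs

base_config = {"learning_rate": 0.1, "micro_batch_size": 1, "sample_packing": True}
sweeps_config = {"micro_batch_size": [1, 2, 4], "weight_decay": [0.0, 0.1]}

# Each result is a deep copy of base_config with only the swept keys overridden,
# so untouched keys such as learning_rate and sample_packing carry through.
configs = generate_sweep_configs(base_config, sweeps_config)
print(len(configs))  # 3 micro_batch_size values x 2 weight_decay values = 6
for cfg in configs:
    print(cfg["micro_batch_size"], cfg["weight_decay"], cfg["learning_rate"])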