Skip to content

Commit

Permalink
Fix Deepspeed loading (#950)
Browse files Browse the repository at this point in the history
* add check for zero3

* freeze parameters

* fixes for deepspeed loading

* fix model parameter check

* unfrozen parameters in example mixtral and logging when unfreezing
  • Loading branch information
winglian authored Dec 13, 2023
1 parent f1f60cb commit 5ea3aa3
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 1 deletion.
39 changes: 39 additions & 0 deletions deepspeed/zero3_bf16.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 0,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 0,
"stage3_max_reuse_distance": 0,
"stage3_gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": true
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
9 changes: 9 additions & 0 deletions examples/mistral/mixtral.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./qlora-out

## You can optionally freeze the entire model and unfreeze a subset of parameters
unfrozen_parameters:
# - lm_head.*
# - model.embed_tokens.*
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
# - model.layers.3[0-9]+.block_sparse_moe.experts.*

adapter: qlora
lora_model_dir:

Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
Expand Down
4 changes: 4 additions & 0 deletions src/axolotl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from axolotl.logging_config import configure_logging
from axolotl.monkeypatch import neft_embeddings
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_parameters_except
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer

Expand Down Expand Up @@ -78,6 +79,9 @@ def train(
)
resume_from_checkpoint = cfg.resume_from_checkpoint

if cfg.unfrozen_parameters:
freeze_parameters_except(model, cfg.unfrozen_parameters)

trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
Expand Down
38 changes: 38 additions & 0 deletions src/axolotl/utils/freeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
module to freeze/unfreeze parameters by name
"""
import logging
import re

from axolotl.utils.distributed import is_main_process

LOG = logging.getLogger("axolotl.utils.freeze")


def freeze_parameters_except(model, regex_patterns):
"""
Freezes all layers of the given model except for the layers that match given regex patterns.
Periods in the patterns are treated as literal periods, not as wildcard characters.
Parameters:
- model (nn.Module): The PyTorch model to be modified.
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
Returns:
None; the model is modified in place.
"""
# Escape periods and compile the regex patterns
compiled_patterns = [
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
]

# First, freeze all parameters in the model
for param in model.parameters():
param.requires_grad = False

# Unfreeze layers that match the regex patterns
for name, param in model.named_parameters():
if any(pattern.match(name) for pattern in compiled_patterns):
if is_main_process():
LOG.debug(f"unfreezing {name}")
param.requires_grad = True
4 changes: 4 additions & 0 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
PreTrainedModel,
PreTrainedTokenizerBase,
)
from transformers.deepspeed import is_deepspeed_zero3_enabled

from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
Expand Down Expand Up @@ -285,6 +286,9 @@ def load_model(
model_kwargs["max_memory"] = cfg.max_memory
model_kwargs["torch_dtype"] = cfg.torch_dtype

if is_deepspeed_zero3_enabled():
del model_kwargs["device_map"]

if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision
if cfg.gptq:
Expand Down
1 change: 1 addition & 0 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def prepare_optim_env(cfg):
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed


def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
Expand Down

0 comments on commit 5ea3aa3

Please sign in to comment.