bump transformers and set roundup_power2_divisions for more VRAM improvements, low bit ao optimizers (#1769)

* bump transformers and set roundup_power2_divisions for more VRAM improvements

* support for low bit optimizers from torch ao

* fix check for alternate optimizers and use nous models on hf for llama3

* add missing check for ao_adamw_fp8

* fix check when using custom optimizers w adamw
winglian authored Jul 19, 2024
1 parent 7830fe0 commit e4063d6
Showing 9 changed files with 64 additions and 10 deletions.
19 changes: 19 additions & 0 deletions docs/torchao.qmd
@@ -0,0 +1,19 @@
---
title: "PyTorch ao"
description: "Custom data types and layouts for training and inference"
---

### Installation

Stable release from the PyTorch index

```bash
pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124
```


Nightly release

```bash
pip install --pre torchao-nightly --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
```
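A quick way to confirm the install matches what the trainer expects (a sketch, not part of the committed doc; assumes a CUDA build of torchao and a recent PyTorch): the low-bit optimizers imported in `trainer_builder.py` below should resolve and act as drop-in AdamW replacements.

```python
import torch
from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit, AdamWFp8

# Tiny model purely for the smoke test; any of the three classes can be swapped in.
model = torch.nn.Linear(256, 256).cuda()
optimizer = AdamW8bit(model.parameters(), lr=1e-4)

loss = model(torch.randn(8, 256, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
print("torchao low-bit optimizer step OK")
```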
2 changes: 1 addition & 1 deletion examples/llama-3/fft-8b.yaml
@@ -1,4 +1,4 @@
-base_model: meta-llama/Meta-Llama-3-8B
+base_model: NousResearch/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

2 changes: 1 addition & 1 deletion examples/llama-3/instruct-lora-8b.yml
@@ -1,4 +1,4 @@
-base_model: meta-llama/Meta-Llama-3-8B-Instruct
+base_model: NousResearch/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

2 changes: 1 addition & 1 deletion examples/llama-3/lora-8b.yml
@@ -1,4 +1,4 @@
-base_model: meta-llama/Meta-Llama-3-8B
+base_model: NousResearch/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

2 changes: 1 addition & 1 deletion examples/llama-3/qlora.yml
@@ -1,4 +1,4 @@
-base_model: meta-llama/Meta-Llama-3-8B
+base_model: NousResearch/Meta-Llama-3-8B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.11.1
-transformers==4.42.3
+transformers==4.42.4
tokenizers==0.19.1
bitsandbytes==0.43.1
accelerate==0.32.0
28 changes: 26 additions & 2 deletions src/axolotl/core/trainer_builder.py
@@ -305,7 +305,8 @@ def _wrap_model(self, model, training=True, dataloader=None):
     def create_optimizer(self):
         if (
             self.args.loraplus_lr_ratio is None
-            and self.args.alternate_optimizer != "optimi_adamw"
+            and self.args.alternate_optimizer
+            not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
         ):
             return super().create_optimizer()

@@ -356,6 +357,24 @@ def create_optimizer(self):
                     optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
                 )
             )
+        elif self.args.alternate_optimizer == "ao_adamw_4bit":
+            from torchao.prototype.low_bit_optim import AdamW4bit
+
+            self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+                AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
+            )
+        elif self.args.alternate_optimizer == "ao_adamw_8bit":
+            from torchao.prototype.low_bit_optim import AdamW8bit
+
+            self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+                AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
+            )
+        elif self.args.alternate_optimizer == "ao_adamw_fp8":
+            from torchao.prototype.low_bit_optim import AdamWFp8
+
+            self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+                AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
+            )

         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
@@ -1452,7 +1471,12 @@ def build(self, total_num_steps):

         trainer_kwargs = {}

-        if self.cfg.optimizer == "optimi_adamw":
+        if self.cfg.optimizer in [
+            "optimi_adamw",
+            "ao_adamw_4bit",
+            "ao_adamw_8bit",
+            "ao_adamw_fp8",
+        ]:
             # Set default so transformers doesn't throw
             training_arguments_kwargs["optim"] = "adamw_hf"
             training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
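The three new `elif` branches differ only in the imported class; a table-driven sketch of the same mapping (illustrative only, not how the commit is written) shows the pattern at a glance. Note that `build()` above still sets `optim: adamw_hf` as a placeholder so transformers accepts the otherwise unknown optimizer name.

```python
# Illustrative mapping from the new config strings to torchao optimizer classes.
# The commit itself uses an explicit elif chain inside create_optimizer().
from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit, AdamWFp8

AO_OPTIMIZERS = {
    "ao_adamw_4bit": AdamW4bit,
    "ao_adamw_8bit": AdamW8bit,
    "ao_adamw_fp8": AdamWFp8,
}


def build_ao_optimizer(name, optimizer_grouped_parameters, **optimizer_kwargs):
    """Return the torchao low-bit optimizer for one of the new config values."""
    return AO_OPTIMIZERS[name](optimizer_grouped_parameters, **optimizer_kwargs)
```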
4 changes: 3 additions & 1 deletion src/axolotl/train.py
@@ -57,7 +57,9 @@ def train(
     torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
     if torch_major == 2 and torch_minor >= 2:
         if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
-            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+            os.environ[
+                "PYTORCH_CUDA_ALLOC_CONF"
+            ] = "expandable_segments:True,roundup_power2_divisions:16"

     # load the tokenizer first
     LOG.debug(
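The allocator setting only takes effect if it is in place before the CUDA caching allocator initializes, and axolotl leaves any pre-existing value alone. A minimal sketch of setting it yourself, equivalent in effect to the guarded assignment above:

```python
import os

# Matches the guarded assignment in train.py: an existing value wins; otherwise
# enable expandable segments and power-of-two rounding to reduce fragmentation.
os.environ.setdefault(
    "PYTORCH_CUDA_ALLOC_CONF",
    "expandable_segments:True,roundup_power2_divisions:16",
)

import torch  # the config is read when the caching allocator first runs
```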
13 changes: 11 additions & 2 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -346,7 +346,16 @@ class HyperparametersConfig(BaseModel):
     learning_rate: Union[str, float]
     weight_decay: Optional[float] = 0.0
     optimizer: Optional[
-        Union[OptimizerNames, Literal["lion_pytorch", "optimi_adamw"]]
+        Union[
+            OptimizerNames,
+            Literal[
+                "lion_pytorch",
+                "optimi_adamw",
+                "ao_adamw_4bit",
+                "ao_adamw_8bit",
+                "ao_adamw_fp8",
+            ],
+        ]
     ] = OptimizerNames.ADAMW_HF.value
     optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
         default=None, metadata={"help": "Optional arguments to supply to optimizer."}
@@ -850,7 +859,7 @@ def check_better_transformers(self):
     @model_validator(mode="after")
     def check_adamw_optimizer_params(self):
         if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and (
-            not self.optimizer or "adamw" not in self.optimizer.value
+            not self.optimizer or "adamw" not in str(self.optimizer).lower()
         ):
             LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
         return self
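The validator change accounts for the new `Literal` values: `self.optimizer` may now hold a plain string rather than an `OptimizerNames` member, and a plain string has no `.value` attribute. A two-line illustration (not from the commit):

```python
optimizer = "ao_adamw_8bit"  # plain str, no .value attribute
print("adamw" in str(optimizer).lower())  # True for both strings and enum members
```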
