Skip to content

Commit

Permalink
support galore once upstreamed into transformers (#1409)
Browse files Browse the repository at this point in the history
* support galore once upstreamed into transformers

* update module name for llama in readme and fix typing for all linear

* bump trl for deprecation fixes from newer transformers

* include galore as an extra and install in docker image

* fix optim_args type

* fix optim_args

* update dependencies for galore

* add galore to cicd dockerfile
  • Loading branch information
winglian authored Mar 19, 2024
1 parent 40a88e8 commit dd449c5
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 7 deletions.
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,26 @@ lr_div_factor: # Learning rate div factor
# - paged_adamw_8bit
# - paged_lion_32bit
# - paged_lion_8bit
# - galore_adamw
# - galore_adamw_8bit
# - galore_adafactor
# - galore_adamw_layerwise
# - galore_adamw_8bit_layerwise
# - galore_adafactor_layerwise
optimizer:
# Dictionary of arguments to pass to the optimizer
optim_args:
# For Galore Optimizers the following optim_args are available
# rank: # type: int
# update_proj_gap # type: int
# scale # type: float
# proj_type: # type: str, default = std

# The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm
optim_target_modules:
# - self_attn # for llama
# - mlp

# Specify weight decay
weight_decay:
# adamw hyperparams
Expand Down
4 changes: 2 additions & 2 deletions cicd/Dockerfile.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ RUN git fetch origin +$GITHUB_REF && \

# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
fi

# So we can test the Docker image
Expand Down
4 changes: 2 additions & 2 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ WORKDIR /workspace/axolotl

# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
fi

# So we can test the Docker image
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.9.0
transformers==4.38.2
transformers @ git+https://github.com/huggingface/transformers.git@f6261d7d81edd036fc53bfede65fe91f01a661aa
tokenizers==0.15.0
bitsandbytes>=0.43.0
accelerate==0.26.1
Expand Down Expand Up @@ -39,5 +39,5 @@ s3fs
gcsfs
# adlfs

trl>=0.7.9
trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
fastcore>=1.5.29
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,8 @@ def parse_requirements():
"lion-pytorch": [
"lion-pytorch==0.1.2",
],
"galore": [
"galore_torch",
],
},
)
15 changes: 14 additions & 1 deletion src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def __init__(
num_epochs=1,
bench_data_collator=None,
eval_data_collator=None,
**kwargs
**kwargs,
):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
Expand All @@ -239,6 +239,7 @@ def create_optimizer(self):
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)

loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
Expand Down Expand Up @@ -1150,6 +1151,18 @@ def build(self, total_num_steps):
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
Expand Down
9 changes: 9 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,15 @@ class HyperparametersConfig(BaseModel):
learning_rate: Union[str, float]
weight_decay: Optional[float] = None
optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
)
optim_target_modules: Optional[Union[List[str], Literal["all_linear"]]] = Field(
default=None,
metadata={
"help": "The target modules to optimize, i.e. the module names that you would like to train."
},
)
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[SchedulerType] = None
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
Expand Down

0 comments on commit dd449c5

Please sign in to comment.