support galore once upstreamed into transformers #1409

Merged · 8 commits · Mar 19, 2024
19 changes: 19 additions & 0 deletions README.md
@@ -883,7 +883,26 @@ lr_div_factor: # Learning rate div factor
# - paged_adamw_8bit
# - paged_lion_32bit
# - paged_lion_8bit
# - galore_adamw
# - galore_adamw_8bit
# - galore_adafactor
# - galore_adamw_layerwise
# - galore_adamw_8bit_layerwise
# - galore_adafactor_layerwise
optimizer:
# Dictionary of arguments to pass to the optimizer
optim_args:
# For Galore Optimizers the following optim_args are available
# rank: # type: int
# update_proj_gap: # type: int
# scale: # type: float
# proj_type: # type: str, default = std

# The target modules to optimize, i.e. the module names that you would like to train. Currently this is used only by the GaLore algorithm.
optim_target_modules:
# - self_attn # for llama
# - mlp

# Specify weight decay
weight_decay:
# adamw hyperparams
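Putting the options documented in this README hunk together, a GaLore run could presumably be configured as in the sketch below; the numeric values for rank, update_proj_gap, and scale are purely illustrative, not recommendations.

```yaml
# Hypothetical GaLore configuration built from the options documented above
optimizer: galore_adamw        # or galore_adamw_8bit, galore_adafactor, *_layerwise
optim_args:
  rank: 128                    # illustrative value
  update_proj_gap: 200         # illustrative value
  scale: 0.25                  # illustrative value
  proj_type: std
optim_target_modules:
  - self_attn                  # for llama
  - mlp
```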
4 changes: 2 additions & 2 deletions cicd/Dockerfile.jinja
@@ -23,9 +23,9 @@ RUN git fetch origin +$GITHUB_REF && \

# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
fi

# So we can test the Docker image
4 changes: 2 additions & 2 deletions docker/Dockerfile
@@ -21,9 +21,9 @@ WORKDIR /workspace/axolotl

# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
fi

# So we can test the Docker image
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.9.0
transformers==4.38.2
transformers @ git+https://github.com/huggingface/transformers.git@f6261d7d81edd036fc53bfede65fe91f01a661aa
tokenizers==0.15.0
bitsandbytes>=0.43.0
accelerate==0.26.1
@@ -39,5 +39,5 @@ s3fs
gcsfs
# adlfs

trl>=0.7.9
trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
fastcore>=1.5.29
3 changes: 3 additions & 0 deletions setup.py
@@ -89,5 +89,8 @@ def parse_requirements():
"lion-pytorch": [
"lion-pytorch==0.1.2",
],
"galore": [
"galore_torch",
],
},
)
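With the new extra defined, the GaLore dependency can presumably be installed locally with `pip install -e .[galore]`; the Docker images above now include it by default.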
15 changes: 14 additions & 1 deletion src/axolotl/core/trainer_builder.py
@@ -216,7 +216,7 @@ def __init__(
num_epochs=1,
bench_data_collator=None,
eval_data_collator=None,
**kwargs
**kwargs,
):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
@@ -232,6 +232,7 @@ def create_optimizer(self):
if self.optimizer is None: # pylint: disable=access-member-before-definition
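# opt_model is now forwarded as well, presumably so that transformers' GaLore support
# can resolve optim_target_modules against the model's named modules.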
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)

loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
@@ -1016,6 +1017,18 @@ def build(self, total_num_steps):
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
if self.cfg.optim_args:
if isinstance(self.cfg.optim_args, dict):
optim_args = ",".join(
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
)
else:
optim_args = self.cfg.optim_args
training_arguments_kwargs["optim_args"] = optim_args
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
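To make the optim_args plumbing concrete, here is a small sketch (with made-up values) of what the dict branch above produces; transformers later parses this comma-separated string back into keyword arguments for the GaLore optimizer.

```python
# Hypothetical optim_args taken from an axolotl config (dict form)
optim_args = {"rank": 128, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"}

# Same conversion as in build() above
optim_args_str = ",".join([f"{key}={value}" for key, value in optim_args.items()])
print(optim_args_str)  # rank=128,update_proj_gap=200,scale=0.25,proj_type=std

# These land on TrainingArguments via training_arguments_kwargs
training_arguments_kwargs = {
    "optim": "galore_adamw",
    "optim_args": optim_args_str,
    "optim_target_modules": ["self_attn", "mlp"],
}
```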
9 changes: 9 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -310,6 +310,15 @@ class HyperparametersConfig(BaseModel):
learning_rate: Union[str, float]
weight_decay: Optional[float] = None
optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
)
optim_target_modules: Optional[Union[List[str], Literal["all_linear"]]] = Field(
default=None,
metadata={
"help": "The target modules to optimize, i.e. the module names that you would like to train."
},
)
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[SchedulerType] = None
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
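A rough sketch of how the new fields are meant to be used, assuming learning_rate is the only other required field on HyperparametersConfig and using illustrative values:

```python
from axolotl.utils.config.models.input.v0_4_1 import HyperparametersConfig

# optim_args accepts either a dict or an already-formatted "k=v,k=v" string;
# optim_target_modules accepts a list of module names or the literal "all_linear".
# Assumption: no other fields are required for validation to pass.
hparams = HyperparametersConfig(
    learning_rate=2e-4,
    optimizer="galore_adamw",
    optim_args={"rank": 128, "update_proj_gap": 200, "scale": 0.25},
    optim_target_modules=["self_attn", "mlp"],
)
```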