set cpu_offload: false to reduce vram, constrain new accelerator logic to qlora + fsdp
winglian committed Mar 8, 2024
1 parent 716133c commit 5d072dd
Showing 1 changed file with 13 additions and 0 deletions.
src/axolotl/core/trainer_builder.py
@@ -196,6 +196,10 @@ class AxolotlTrainingArguments(TrainingArguments):
        default=1e-6,
        metadata={"help": "loraplus learning rate for lora embedding layers."},
    )
    qlora: bool = field(
        default=False,
        metadata={"help": "whether this is a qlora training"},
    )


class AxolotlTrainer(Trainer):
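The new `qlora` flag is just an additional dataclass field on the `TrainingArguments` subclass, so it can be passed and read back like any other training argument. A minimal sketch of the same pattern outside axolotl (the class name here is hypothetical; `output_dir` is the only required stock argument):

```python
from dataclasses import dataclass, field

from transformers import TrainingArguments


@dataclass
class MyTrainingArguments(TrainingArguments):
    # extra boolean flag, defaulting to off, mirroring the qlora field above
    qlora: bool = field(
        default=False,
        metadata={"help": "whether this is a qlora training"},
    )


args = MyTrainingArguments(output_dir="./out", qlora=True)
assert args.qlora is True
```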
@@ -477,6 +481,11 @@ def push_to_hub(self, *args, **kwargs) -> str:
    def create_accelerator_and_postprocess(self):
        rank = int(os.environ.get("LOCAL_RANK", 0))
        res = super().create_accelerator_and_postprocess()

        if self.args.qlora is False:
            return res

        # the rest of this method override is specific to fsdp + qlora (for now)
        sync_module_states = (
            str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
        )
@@ -504,6 +513,7 @@ def create_accelerator_and_postprocess(self):
        wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
        fsdp_plugin = FullyShardedDataParallelPlugin(
            auto_wrap_policy=wrapping_policy(),
            cpu_offload=False,
            use_orig_params=False,
            limit_all_gathers=True,
            param_init_fn=lambda module: module.to_empty(
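The `cpu_offload=False` argument above corresponds, at the torch level, to FSDP's `CPUOffload` setting with parameter offloading disabled, i.e. sharded parameters stay on the GPU rather than being moved to CPU between uses. A standalone illustration of that knob (not part of this commit; exactly how accelerate's plugin normalizes a plain bool may vary by version):

```python
from torch.distributed.fsdp import CPUOffload

# Parameters stay resident on the GPU; this is the behaviour cpu_offload=False selects.
no_offload = CPUOffload(offload_params=False)

# The opposite setting moves sharded parameters to CPU while they are not in use.
with_offload = CPUOffload(offload_params=True)
```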
@@ -836,6 +846,9 @@ def build(self, total_num_steps):
        if self.cfg.fsdp_config:
            training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)

        if self.cfg.adapter == "qlora":
            training_arguments_kwargs["qlora"] = True

        # deepspeed
        if self.cfg.deepspeed:
            training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed