Skip to content

Commit

Permalink
support lora
Browse files Browse the repository at this point in the history
fix

fix
  • Loading branch information
flybird11111 committed Oct 21, 2024
1 parent a6ac181 commit fae90d6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
6 changes: 5 additions & 1 deletion applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def save_checkpoint(
step: int,
batch_size: int,
coordinator: DistCoordinator,
use_lora: bool = False,
) -> None:
"""
Save model checkpoint, optimizer, LR scheduler and intermedidate running states.
Expand All @@ -51,7 +52,10 @@ def save_checkpoint(
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)

booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
if use_lora:
booster.save_lora_as_pretrained(model, os.path.join(save_dir, "modeling"))
else:
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)

booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
Expand Down
13 changes: 11 additions & 2 deletions applications/Colossal-LLaMA/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from colossal_llama.utils.froze import freeze_non_embeds_parameters
from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel
from peft import LoraConfig
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
Expand Down Expand Up @@ -166,7 +167,7 @@ def train(args) -> None:
# ======================================================
init_ctx = (
LazyInitContext(default_device=get_current_device())
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0
else nullcontext()
)
with init_ctx:
Expand All @@ -178,11 +179,16 @@ def train(args) -> None:
# Freeze part of parameters.
if args.freeze_non_embeds_params:
freeze_non_embeds_parameters(model=model)

if args.lora_rank > 0:
lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=32, lora_dropout=0.1)
model = booster.enable_lora(model, lora_config=lora_config)

# this is essential, otherwise the grad checkpoint will not work.
model.train()

if args.use_grad_checkpoint:
model.gradient_checkpointing_enable()
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

model_numel = get_model_numel(model)
Expand Down Expand Up @@ -319,6 +325,7 @@ def train(args) -> None:
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
use_lora=(args.lora_rank > 0),
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
Expand Down Expand Up @@ -389,6 +396,7 @@ def train(args) -> None:
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
use_lora=(args.lora_rank > 0),
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
Expand Down Expand Up @@ -514,6 +522,7 @@ def train(args) -> None:
parser.add_argument(
"--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
)
parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")

# Additional arguments for benchmark.
parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.")
Expand Down

0 comments on commit fae90d6

Please sign in to comment.