Commit
fix

flybird11111 committed Dec 11, 2023
1 parent 3f4aeaa commit 987728a
Showing 3 changed files with 7 additions and 8 deletions.
11 changes: 5 additions & 6 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -1,4 +1,5 @@
 import logging
+import warnings
 import os
 from functools import partial
 from pathlib import Path
@@ -335,6 +336,7 @@ def enable_lora(
         from peft import PeftModel, get_peft_model
         assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
         self.lora_enabled = True
+        warnings.warn("The hyperparameters of optimizer.param_groups[0] will be used for LoRA training. Please check the lr and betas.")
 
         if pretrained_dir is None:
             peft_model = get_peft_model(model, lora_config)
@@ -353,12 +355,9 @@ def configure(
         if self.lora_enabled:
             from peft import PeftModel
             assert isinstance(model, PeftModel), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
 
-            optim_params_nums = 0
-            for param_group in optimizer.param_groups:
-                optim_params_nums += len(param_group['params'])
-            model_params_nums = len(list(model.named_parameters()))
-            assert optim_params_nums == model_params_nums, "Optimizer should be initialized after enabling lora."
-
+            optimizer.param_groups = optimizer.param_groups[:1]
+            optimizer.param_groups[0]['params'] = list(model.parameters())
+
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.precision)
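
For readers skimming this change, a minimal, self-contained sketch (plain PyTorch, not ColossalAI code; the model and hyperparameter values are hypothetical) of what the two added lines in `configure` amount to: only the first param group survives, and it is re-pointed at all of the model's parameters, so the lr, betas, and weight decay of `optimizer.param_groups[0]` govern the whole LoRA run.

```python
# Toy illustration of the new configure() logic (hypothetical model/optimizer).
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(
    [
        {"params": [model.weight], "lr": 1e-3},
        {"params": [model.bias], "lr": 1e-5},  # this group is dropped below
    ]
)

# Mirrors the two added lines: keep only the first group and point it at all params.
optimizer.param_groups = optimizer.param_groups[:1]
optimizer.param_groups[0]["params"] = list(model.parameters())

print(len(optimizer.param_groups))      # 1
print(optimizer.param_groups[0]["lr"])  # 0.001 -- the surviving hyperparameters
```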
2 changes: 1 addition & 1 deletion docs/source/en/basics/booster_plugins.md
@@ -13,7 +13,7 @@ We currently provide the following plugins:
 
 - [Torch DDP Plugin](#torch-ddp-plugin): It is a wrapper of `torch.nn.parallel.DistributedDataParallel` and can be used to train models with data parallelism.
 - [Torch FSDP Plugin](#torch-fsdp-plugin): It is a wrapper of `torch.distributed.fsdp.FullyShardedDataParallel` and can be used to train models with zero-dp.
-- [Low Level Zero Plugin](#low-level-zero-plugin): It wraps the `colossalai.zero.low_level.LowLevelZeroOptimizer` and can be used to train models with zero-dp. It only supports zero stage-1 and stage-2.
+- [Low Level Zero Plugin](#low-level-zero-plugin): It wraps the `colossalai.zero.low_level.LowLevelZeroOptimizer` and can be used to train models with zero-dp. It only supports zero stage-1 and stage-2. The Low Level Zero Plugin supports LoRA training. Note that when you use `enable_lora` to convert the model into a LoRA model, the hyperparameters of `optimizer.param_groups[0]` are applied to LoRA training, so pay attention to the optimizer parameter settings (see the usage sketch below).
 - [Gemini Plugin](#gemini-plugin): It wraps the [Gemini](../features/zero_with_chunk.md) which implements Zero-3 with chunk-based and heterogeneous memory management.
 - [Hybrid Parallel Plugin](#hybrid-parallel-plugin): It provides a tidy interface that integrates the power of Shardformer, pipeline manager, mixed precision training, TorchDDP and Zero stage 1/2 feature. With this plugin, transformer models can be easily trained with any combination of tensor parallel, pipeline parallel and data parallel (DDP/Zero) efficiently, along with various kinds of optimization tools for acceleration and memory saving. Detailed information about supported parallel strategies and optimization tools is explained in the section below.
 
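To make the Low Level Zero bullet above concrete, here is a minimal usage sketch. It assumes `LowLevelZeroPlugin.enable_lora(model, lora_config=...)` returns the PEFT-wrapped model and that the process group has been launched with torchrun; `TinyModel`, the LoRA hyperparameters, and the launch call are illustrative rather than part of this commit, and the exact signatures may differ by version.

```python
# Hypothetical end-to-end sketch of LoRA training with LowLevelZeroPlugin.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from peft import LoraConfig


class TinyModel(torch.nn.Module):
    """Stand-in for a real transformer; `proj` is the module LoRA adapts."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


colossalai.launch_from_torch(config={})  # distributed init; exact args may differ by version

model = TinyModel()
lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, target_modules=["proj"])

plugin = LowLevelZeroPlugin(stage=2, precision="fp16")
# LoRA must be enabled *before* boosting; the plugin asserts this ordering.
model = plugin.enable_lora(model, lora_config=lora_config)

# Only the hyperparameters of param_groups[0] survive into LoRA training,
# so set the lr/betas you actually want on this (single) group.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```

When `pretrained_dir` is passed instead of a fresh `lora_config`, the `if pretrained_dir is None` branch in the first file's diff suggests the plugin loads an existing adapter rather than creating a new one.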
2 changes: 1 addition & 1 deletion docs/source/zh-Hans/basics/booster_plugins.md
@@ -14,7 +14,7 @@
 
 - [Torch DDP Plugin](#torch-ddp-插件): It wraps `torch.nn.parallel.DistributedDataParallel` and can be used to train models with data parallelism.
 - [Torch FSDP Plugin](#torch-fsdp-插件): It wraps `torch.distributed.fsdp.FullyShardedDataParallel` and can be used to train models with Zero-dp.
-- [Low Level Zero Plugin](#low-level-zero-插件): It wraps `colossalai.zero.low_level.LowLevelZeroOptimizer` and can be used to train models with Zero-dp. It only supports Zero stage 1 and stage 2.
+- [Low Level Zero Plugin](#low-level-zero-插件): It wraps `colossalai.zero.low_level.LowLevelZeroOptimizer` and can be used to train models with Zero-dp. It only supports Zero stage 1 and stage 2. The Low Level Zero plugin supports LoRA training. Note that after you use `enable_lora` to convert the model into a LoRA model, the hyperparameters of `optimizer.param_groups[0]` are used for LoRA training, so pay attention to the optimizer parameter settings.
 - [Gemini Plugin](#gemini-插件): It wraps [Gemini](../features/zero_with_chunk.md), which implements Zero-3 with chunk-based and heterogeneous memory management.
 - [Hybrid Parallel Plugin](#hybrid-parallel-插件): It provides a unified, concise interface for Shardformer, the pipeline manager, mixed-precision training, TorchDDP, and the Zero-1/Zero-2 features. With this plugin, transformer models can be trained simply and efficiently with any combination of tensor parallelism, pipeline parallelism, and data parallelism (DDP, Zero), together with various tools for optimizing training speed and memory usage. Details about these parallel strategies and optimization tools are explained in the next chapter.
 
