Commit
Showing 6 changed files with 381 additions and 1 deletion.
@@ -0,0 +1,78 @@
base_model: LnL-AI/dbrx-base-converted
trust_remote_code: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out

sequence_len: 512
sample_packing: false
pad_to_sequence_len: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

adapter: lora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# w1, w2, & v1 will hang the trainer
lora_target_modules:
  - Wqkv # attn
  - out_proj # attn
  - layer # router
  # - w1
  # - w2
  # - v1

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: false
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: DbrxBlock
  fsdp_state_dict_type: SHARDED_STATE_DICT
@@ -0,0 +1,18 @@
# DBRX MoE

Currently, for LoRA, only the `Wqkv`, `out_proj`, and `layer` Linear layers are trainable.

We are using the "converted" base models based on [this issue](https://huggingface.co/databricks/dbrx-instruct/discussions/10),
where the experts are fused into an `nn.Parameter` rather than an `nn.Linear` layer. However, the implementation
is still a bit buggy, and attempting to train a LoRA adapter over the `w1`, `w2`, and `v1` layers
causes the trainer to hang.

We recommend using [`LnL-AI/dbrx-base-converted`](https://huggingface.co/LnL-AI/dbrx-base-converted) as your base model for the time being.
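A quick way to see which modules are plain `nn.Linear` layers (and therefore usable as LoRA targets) is to build the model skeleton on the meta device and walk its module tree. This is an illustrative sketch, assuming network access to the converted checkpoint's config and remote modeling code:

```python
from accelerate import init_empty_weights
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM

# Build only the module structure; no weights are downloaded or allocated.
config = AutoConfig.from_pretrained("LnL-AI/dbrx-base-converted", trust_remote_code=True)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

# Distinct nn.Linear leaf names: the candidates for lora_target_modules.
# Per the note above, the fused expert weights (w1, w2, v1) live in
# nn.Parameter tensors, so they do not show up as Linear modules here.
linear_names = sorted(
    {name.rsplit(".", 1)[-1] for name, mod in model.named_modules() if isinstance(mod, nn.Linear)}
)
print(linear_names)
```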
Current status of the various training setups:

- 16-bit LoRA w/ FSDP
  - ✅ w/o CPU Offload - 8x80GB uses ~62GiB/gpu
  - ❌ w/ CPU Offload - `paged_adamw_8bit` optimizer errors from being on cpu
- ❓ 8-bit LoRA w/ FSDP - WIP, need to handle loading 8-bit quantized weights
- ❌ 4-bit QLoRA w/ FSDP - errors w/: `Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu`
- ✅ bf16 full finetune w/ FSDP, freezing all but the first 8 layers (8x80GB uses ~78GiB/gpu; see the freezing sketch below)
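For the full-finetune row, the layer freezing can be expressed directly in code. This sketch reuses the `model` skeleton from the previous example and assumes the converted DBRX exposes its decoder blocks at `model.transformer.blocks` (attribute path illustrative):

```python
# Freeze everything, then unfreeze only the first 8 decoder blocks, matching
# the "freezing all but the first 8 layers" full-finetune setup above.
# `model.transformer.blocks` is an assumed layout of DbrxBlock modules.
for param in model.parameters():
    param.requires_grad_(False)

for block in model.transformer.blocks[:8]:
    for param in block.parameters():
        param.requires_grad_(True)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable: {trainable / 1e9:.1f}B of {total / 1e9:.1f}B params")
```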
@@ -0,0 +1,261 @@
""" | ||
module to handle loading model on cpu/meta device for FSDP | ||
""" | ||
import os | ||
import time | ||
from typing import List, Optional, Type, Union | ||
|
||
import safetensors | ||
import torch | ||
from accelerate import init_empty_weights | ||
from bitsandbytes.nn import Linear4bit, Params4bit | ||
from fastcore.parallel import parallel | ||
from torch import Tensor, nn | ||
from tqdm import tqdm | ||
from transformers import AutoModelForCausalLM | ||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub | ||
|
||
|
||
def _replace_linear(
    model: nn.Module,
    linear_replacement: Type[nn.Module],
    quant_config: Union[dict, None] = None,
    skip_modules=None,
    **kwargs,
):
    """
    Replace linear modules with a new Linear module.
    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        linear_replacement (`torch.nn.Module`):
            The linear module that replaces the old one. Only expects standard arguments.
            If other arguments need to be passed, use a lambda.
        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
            List of module names not to convert. Defaults to `lm_head`.
    """
    if skip_modules is None:
        skip_modules = ["lm_head"]
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            _replace_linear(
                module, linear_replacement, quant_config, skip_modules, **kwargs
            )

        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
            if issubclass(linear_replacement, Linear4bit):
                model._modules[  # pylint: disable=protected-access
                    name
                ] = linear_replacement(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    **kwargs,
                )
            else:
                raise ValueError(
                    f"Unsupported linear replacement: {type(linear_replacement)}"
                )
    return model


def load_and_quantize(
    module: nn.Module,
    name: str,
    value: Tensor,
    device: torch.device = None,
    dtype: torch.dtype = None,
    skip_names: Optional[List[str]] = None,
    to_cpu: bool = False,
    to_meta: bool = False,
    verbose: bool = False,
    quant_method: str = "bnb",
):
    """
    Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
    Quantizes `Params4bit` on `device` then places on "cpu" if to_cpu=True or "meta" if to_meta=True.
    """

    if not skip_names:
        skip_names = []

    def place_on_device(value):
        # Fall back to the caller-provided device unless the tensor is being
        # routed to cpu/meta for low-memory loading. (Assigning to `device`
        # directly here would shadow the outer argument and raise
        # UnboundLocalError when neither flag is set.)
        target_device = device
        if to_meta:
            target_device = "meta"
        elif to_cpu:
            target_device = "cpu"
        return value.to(device=target_device, dtype=dtype)

    if any(skip_name in name for skip_name in skip_names):
        if verbose:
            print(f"Skipping {name} because it is in skip_names")
        return

    module_key, _, value_key = name.rpartition(".")
    try:
        submodule = module.get_submodule(module_key)
    except AttributeError as exc:
        print(f"Module {module_key} not found:\n{exc}")
        return

    try:
        if quant_method == "bnb":
            param = submodule.get_parameter(value_key)
            if isinstance(param, Params4bit):
                # With `sync_module_states=True`, a meta device Params4bit needs to be the same
                # shape as the quantized Params4bit with an initialized quant_state. However,
                # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
                # workaround quantizes Params4bit to initialize quant_state on all ranks, then
                # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
                value = type(param)(
                    value.to(device=device, dtype=dtype).data, **param.__dict__
                ).cuda(device)
                if to_meta:
                    value = type(param)(value.data.to("meta"), **value.__dict__)
                elif to_cpu:
                    value = type(param)(value.data.to("cpu"), **value.__dict__)
            else:
                value = type(param)(place_on_device(value).data)

    except AttributeError:
        # it's a buffer
        value = place_on_device(value)

    setattr(submodule, value_key, value)


def n_loading_workers(quant_method: str, param_count: float):
    """Heuristic for how many threads can safely quantize weight shards in parallel."""
    devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
    # cap 1: CPU cores available per GPU in this job
    left = int(os.cpu_count() / torch.cuda.device_count())
    # cap 2: scale a baseline worker count (8 for bnb, 4 for hqq) by GPU memory
    # relative to 40GB and inversely by model size relative to a 70B reference
    model_params_b = 70
    right = int(
        (4 if quant_method == "hqq" else 8)
        * (devprops.total_memory / 1e9 / 40)
        * (model_params_b / (param_count / 1e9))
    )
    return min(left, right)


def load_sharded_model(
    model_name,
    model_config,
    cfg,
    torch_dtype=torch.bfloat16,
    low_memory=True,
):
    if (low_memory and cfg.local_rank == 0) or not low_memory:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            use_cache=False,
            torch_dtype=torch.float32,
            _attn_implementation=model_config._attn_implementation,  # pylint: disable=protected-access
            trust_remote_code=cfg.trust_remote_code,
        )
        dtype = torch.bfloat16 if cfg.bf16 else None
        model.to(dtype=dtype, device="cpu" if low_memory else cfg.local_rank)
    else:
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                model_config,
                torch_dtype=torch_dtype,
                trust_remote_code=cfg.trust_remote_code,
            )
        if cfg.bf16:
            model.to(torch.bfloat16)
    return model


def load_sharded_model_quant(
    model_name,
    model_config,
    cfg,
    compute_dtype=torch.bfloat16,
    quant_storage=torch.float32,
    low_memory=True,
    verbose=False,
    loading_workers=2,
):
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(
            model_config,
            trust_remote_code=cfg.trust_remote_code,
        )
        if hasattr(model, "transformer"):
            model.transformer = _replace_linear(
                model.transformer,
                Linear4bit,
                compute_dtype=compute_dtype,
                quant_type="nf4",
                quant_storage=quant_storage,
            )
        else:
            # this is the more common case with HF transformers
            model.model = _replace_linear(
                model.model,
                Linear4bit,
                compute_dtype=compute_dtype,
                quant_type="nf4",
                quant_storage=quant_storage,
            )
    model.is_loaded_in_4bit = True

    # Grab the safetensors files that hold the weights
    try:
        idx = hub.cached_file(model_name, SAFE_WEIGHTS_INDEX_NAME)
        files, _ = hub.get_checkpoint_shard_files(model_name, idx)
    except OSError:
        try:
            # This means the model doesn't have a model.safetensors.index.json because it is not sharded
            files = []
            files.append(hub.cached_file(model_name, SAFE_WEIGHTS_NAME))
        except OSError as exc:
            # This means the model probably doesn't have a safetensors file
            raise exc

    # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
    # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
    def load_and_quantize_parallel(name_param, model, **kwargs):
        name, param = name_param
        load_and_quantize(model, name, param, **kwargs)

    quant_method = "bnb"
    param_count = sum((p.numel() for n, p in model.named_parameters()))

    n_workers = (
        n_loading_workers(quant_method, param_count)
        if loading_workers == -1
        else loading_workers
    )
    if cfg.local_rank == 0 and verbose:
        print(f"Using n_workers: {n_workers} for loading")

    start = time.time()
    for filename in tqdm(
        files,
        desc="Loading & Quantizing Model Shards",
        disable=cfg.local_rank != 0,
        position=0,
    ):
        weights = safetensors.torch.load_file(filename)
        parallel(
            load_and_quantize_parallel,
            iter(weights.items()),
            n_workers=n_workers,
            threadpool=True,
            model=model,
            dtype=quant_storage,
            device=cfg.local_rank,
            skip_names=[],
            to_cpu=(low_memory and cfg.local_rank == 0),
            to_meta=(low_memory and cfg.local_rank != 0),
            verbose=verbose,
            quant_method=quant_method,
        )

    if cfg.local_rank == 0 and verbose:
        print(f"Loaded model weights in {time.time()-start:.3f} seconds")
    # cleanup any extra memory usage from parallel loading
    torch.cuda.empty_cache()

    return model
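For reference, a minimal sketch of how the quantized loader above might be invoked, assuming this module is importable and the job runs with one process per GPU; the `cfg` namespace is an illustrative stand-in for axolotl's real config object:

```python
import os
from types import SimpleNamespace

import torch
from transformers import AutoConfig

# Illustrative stand-in for axolotl's config object; only the fields the
# loaders above read (local_rank, bf16, trust_remote_code) are populated.
cfg = SimpleNamespace(
    local_rank=int(os.environ.get("LOCAL_RANK", "0")),
    bf16=True,
    trust_remote_code=True,
)

model_name = "LnL-AI/dbrx-base-converted"
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

# Rank 0 keeps the real quantized weights on CPU while other ranks hold meta
# tensors, relying on FSDP's sync_module_states to broadcast them. Using bf16
# quant_storage keeps the 4-bit weights in the same dtype as the other params.
model = load_sharded_model_quant(
    model_name,
    model_config,
    cfg,
    compute_dtype=torch.bfloat16,
    quant_storage=torch.bfloat16,
    low_memory=True,
)
```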