Adam-mini #1468

Open · wants to merge 23 commits into base: sd3
2 changes: 1 addition & 1 deletion flux_train.py
@@ -310,7 +310,7 @@ def train(args):
logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")

else:
- _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+ _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize, model=flux)

# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
259 changes: 259 additions & 0 deletions library/adam_mini.py
@@ -0,0 +1,259 @@
import math
from typing import Iterable, Tuple, Union, Optional

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed._tensor import Replicate


class Adam_mini(torch.optim.Optimizer):
def __init__(
self,
named_parameters: Iterable[Tuple[str, nn.Parameter]],
lr: Union[float, torch.Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 0,
*,
model_sharding: bool = True,
dim: int = 2048,
n_heads: int = 32,
n_kv_heads: Optional[int] = None,
):
'''
named_parameters: model.named_parameters()

lr: learning rate

betas: same betas as Adam

eps: same eps as Adam

weight_decay: weight_decay coefficient

model_sharding: set to True if you are using model parallelism with more than 1 GPU, including FSDP and zero_1,2,3 in Deepspeed. Set to False if otherwise.

dim: dimension for hidden feature. Could be unspecified if you are training non-transformer models.

n_heads: number of attention heads. Could be unspecified if you are training non-transformer models.

n_kv_heads: number of head for Key and Value. Or equivalently, number of query groups in Group query Attention. Also known as "n_query_groups". If not specified, it will be equal to n_head. Could be unspecified if you are training non-transformer models.
'''
self.dim = dim
self.n_heads = n_heads
if n_kv_heads is not None:
assert n_heads % n_kv_heads == 0, f"{n_heads} {n_kv_heads}"
self.n_kv_heads = n_kv_heads
else:
self.n_kv_heads = n_heads

self.world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
self.model_sharding = model_sharding
if self.model_sharding:
print("=====>>> Adam-mini is using model_sharding")

if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if not self.dim == int(self.dim):
raise ValueError("Invalid dim value: {}".format(self.dim))
if not self.n_heads == int(self.n_heads):
raise ValueError("Invalid n_heads value: {}".format(self.n_heads))
if not self.n_kv_heads == int(self.n_kv_heads):
raise ValueError("Invalid n_kv_heads value: {}".format(self.n_kv_heads))

optim_groups = []
count_embd = count_output = count_wq = count_wk = 0
for param_name, param in named_parameters:
if not param.requires_grad:
continue
print('Adam-mini found the param block with name:', param_name)
state = {}
state["name"] = param_name
state["params"] = param
if "norm" in param_name or "ln_f" in param_name:
state["weight_decay"] = 0.0
else:
state["weight_decay"] = weight_decay
if "embed" in param_name or "wte" in param_name or "embd" in param_name or "layer.weight" in param_name:
count_embd += 1
if "lm_head.weight" in param_name or "output.weight" in param_name or "final_layer.linear.weight" in param_name or "final_layer.adaLN_modulation.1.weight" in param_name:
count_output += 1
if "q_proj.weight" in param_name or "wq.weight" in param_name or "attn_qkv.lora_down" in param_name or "attn_proj.lora_down" in param_name or "attn.qkv.weight" in param_name:
count_wq += 1
assert (self.dim * self.dim) % self.n_heads == 0, f"{self.dim} {self.n_heads}"
state["head_numel"] = self.dim * self.dim // self.n_heads
if "k_proj.weight" in param_name or "wk.weight" in param_name or "attn_qkv.lora_up" in param_name or "attn_proj.lora_up" in param_name or "mlp" in param_name or "attn.proj.weight" in param_name:
count_wk += 1
assert (self.dim * self.dim) % self.n_heads == 0, f"{self.dim} {self.n_heads}"
state["head_numel"] = self.dim * self.dim // self.n_heads
optim_groups.append(state)

print(
f'Adam-mini found {count_embd} embedding layers, {count_output} output layers, {count_wq} Queries, {count_wk} Keys.')

if count_embd == 0:
# warning
print(
"=====>>> Warning by Adam-mini: No embedding layer found. If you are training Transformers, please check the name of your embedding layer and manually add them to 'self.embd_names' of Adam-mini. You can do this by adding an additional line of code: optimizer.embd_names.add('the name of your embedding layer'). ")
if count_output == 0:
# warning
print(
"=====>>> Warning by Adam-mini: No output layer found. If you are training Transformers (without weight-tying), please check the name of your output layer and manually add them to 'self.embd_names' of Adam-mini. You can do this by adding an additional line of code: optimizer.embd_names.add('the name of your output layer'). Please ignore this warning if you are using weight-tying.")
if count_wq == 0:
# warning
print(
"=====>>> Warning by Adam-mini: No Query found. If you are training Transformers, please check the name of your Query in attention blocks and manually add them to 'self.wqk_names' of Adam-mini. You can do this by adding an additional line of code: optimizer.wqk_names.add('the name of your Query'). ")

if count_wk == 0:
# warning
print(
"=====>>> Warning by Adam-mini: No Key found. If you are training Transformers, please check the name of your Key in attention blocks and manually add them to 'self.wqk_names' of Adam-mini. You can do this by adding an additional line of code: optimizer.wqk_names.add('the name of your Key').")

if count_output + count_embd + count_wq + count_wk == 0:
print(
"=====>>> Warning by Adam-mini: you are using default PyTorch partition for Adam-mini. It can cause training instability on large-scale Transformers.")

# embd_blocks, including embd and output layers. Use normal adamW updates for these blocks
self.embd_names = {"embed", "embd", "wte", "lm_head.weight", "output.weight"}
# Query and Keys, will assign lrs by heads
self.wqk_names = {"k_proj.weight", "q_proj.weight", "wq.weight", "wk.weight"}

defaults = dict(lr=lr, beta1=betas[0], beta2=betas[1], eps=eps)
super().__init__(optim_groups, defaults)

@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
beta1 = group["beta1"]
beta2 = group["beta2"]
lr = group["lr"]
name = group["name"]
eps = group["eps"]

for p in group["params"]:

state = self.state[p]
if any(embd_name in name for embd_name in self.embd_names): # this is for embedding and output layer
if p.grad is None:
continue
if len(state) == 0:
state["m"] = torch.zeros_like(p, dtype=torch.float32)
state["step"] = 0
state["v"] = torch.zeros_like(p, dtype=torch.float32)

grad = p.grad.to(torch.float32)
state["v"].mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
state["step"] += 1
if group["weight_decay"] > 0.0:
p.mul_(1 - lr * group["weight_decay"])
state["m"].lerp_(grad, 1 - beta1)
bias_correction_1 = 1 - beta1 ** state["step"]
bias_correction_2 = 1 - beta2 ** state["step"]
bias_correction_2_sqrt = math.sqrt(bias_correction_2)
h = (state["v"].sqrt() / bias_correction_2_sqrt).add_(eps)
stepsize = lr / bias_correction_1
p.addcdiv_(state["m"], h, value=-stepsize)
elif any(wqk_name in name for wqk_name in self.wqk_names): # this is for query and key
if p.grad is None:
continue
head_numel = group["head_numel"]
if len(state) == 0:
m = torch.zeros_like(p, dtype=torch.float32)
state["m"] = m.view(-1, head_numel)
state["head"] = state["m"].size(0)
state["step"] = 0
# NOTE: We must use `zeros_like` for vmean to be a
# DTensor (not `torch.Tensor`) for DTensor parameters.
# state["vmean"] = torch.zeros(state["head"])
state["vmean"] = torch.zeros_like(state["m"][0:state["head"], 0:1])

grad = p.grad.to(torch.float32)
head = state["head"]
grad = grad.view(head, head_numel)
tmp_lr = torch.mean(grad * grad, dim=1, keepdim=True)

state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2)
state["step"] += 1
if group["weight_decay"] > 0.0:
p.mul_(1 - lr * group["weight_decay"])
state["m"].lerp_(grad, 1 - beta1)
bias_correction_1 = 1 - beta1 ** state["step"]
bias_correction_2 = 1 - beta2 ** state["step"]
bias_correction_2_sqrt = math.sqrt(bias_correction_2)
h = (state["vmean"].sqrt() / bias_correction_2_sqrt).add_(eps)
stepsize = ((1 / bias_correction_1) / h).view(head, 1)
update = (state["m"] * stepsize).view(p.size())
update.mul_(lr)
p.add_(-update)
else: # other blocks
if len(state) == 0:
block_numel = torch.tensor(p.numel(), dtype=torch.float32, device=p.device)
reduced = False
if (self.world_size > 1) and (self.model_sharding is True):
tensor_list = [torch.zeros_like(block_numel) for _ in range(self.world_size)]

dist.all_gather(tensor_list, block_numel)
s = 0
block_numel = 0
for d in tensor_list:
if (d > 0):
s = s + 1
block_numel = block_numel + d
if (s >= 2):
reduced = True

state["m"] = torch.zeros_like(p, dtype=torch.float32)
state["step"] = 0
state["reduced"] = reduced
# NOTE: We must use `new_zeros` for vmean to be a
# DTensor (not `torch.Tensor`) for DTensor parameters.
# state["vmean"] = torch.zeros(1, device=p.device)
# state["vmean"] = p.new_zeros(1)
state["vmean"] = torch.zeros_like(torch.sum(p * p))
state["block_numel"] = block_numel.item()
if p.grad is None:
tmp_lr = torch.zeros_like(torch.sum(p * p))
else:
grad = p.grad.to(torch.float32)
tmp_lr = torch.sum(grad * grad)

if (state["reduced"]):
if "device_mesh" in dir(tmp_lr):
# when tmp_lr is a DTensor in TorchTitan
lr_local = tmp_lr.to_local()
dist.all_reduce(lr_local, op=dist.ReduceOp.SUM)
tmp_lr.redistribute(placements=[Replicate()])
else:
# when tmp_lr is a standard tensor
# print(f"...... dist all reduce.......")
dist.all_reduce(tmp_lr, op=dist.ReduceOp.SUM)

if (p.grad is None):
continue
tmp_lr = tmp_lr / state["block_numel"]

if group["weight_decay"] > 0.0:
p.mul_(1 - lr * group["weight_decay"])
state["step"] += 1
state["m"].lerp_(grad, 1 - beta1)
bias_correction_1 = 1 - beta1 ** state["step"]
bias_correction_2 = 1 - beta2 ** state["step"]
bias_correction_2_sqrt = math.sqrt(bias_correction_2)
state["vmean"].mul_(beta2).add_(tmp_lr, alpha=1 - beta2)
h = (state["vmean"].sqrt() / bias_correction_2_sqrt).add_(eps)
stepsize = (1 / bias_correction_1) / h
update = state["m"] * (stepsize.to(state["m"].device))
update.mul_(lr)
p.add_(-update)
return loss
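
For reference, a minimal, self-contained sketch of how Adam_mini is meant to be driven, assuming the file above is importable as library.adam_mini. The TinyModel below and its dim/n_heads values are illustrative only (they are not part of this PR); the parameter names are chosen so that the partitioning heuristics in __init__ (embed, q_proj.weight, k_proj.weight, lm_head.weight) actually match. The point of the partitioning is that Query/Key weights get one second-moment scalar per attention head and every other non-embedding block gets a single scalar, instead of Adam's per-element v.

import torch
import torch.nn as nn

from library.adam_mini import Adam_mini  # the file added in this PR

class TinyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim, bias=False)  # matches wqk_names -> per-head vmean
        self.k_proj = nn.Linear(dim, dim, bias=False)  # matches wqk_names -> per-head vmean
        self.v_proj = nn.Linear(dim, dim, bias=False)  # falls into the "other blocks" branch
        self.norm = nn.LayerNorm(dim)                  # "norm" in name -> weight_decay = 0

class TinyModel(nn.Module):
    def __init__(self, vocab=100, dim=64):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)             # embd_names -> plain AdamW update
        self.block = TinyBlock(dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)  # embd_names -> plain AdamW update

    def forward(self, x):
        h = self.embed(x)
        h = self.block.norm(self.block.q_proj(h) + self.block.k_proj(h) + self.block.v_proj(h))
        return self.lm_head(h)

model = TinyModel()
optimizer = Adam_mini(
    model.named_parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01,
    model_sharding=False,  # single process, no FSDP/ZeRO
    dim=64,                # hidden size of the toy model
    n_heads=4,             # dim * dim must be divisible by n_heads
)

x = torch.randint(0, 100, (2, 8))
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()

This is the same call pattern that get_optimizer() in train_util.py sets up below, except that the AdamMini branch there additionally extends embd_names / wqk_names with SD3- and FLUX-specific parameter names.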
43 changes: 42 additions & 1 deletion library/train_util.py
@@ -3250,6 +3250,12 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
default=None,
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリ(ネット接続なしでの学習のため)",
)
parser.add_argument(
"--num_last_block_to_freeze",
type=int,
default=None,
help="num_last_block_to_freeze",
)


def add_optimizer_arguments(parser: argparse.ArgumentParser):
@@ -4272,7 +4278,7 @@ def task():
accelerator.load_state(dirname)


- def get_optimizer(args, trainable_params):
+ def get_optimizer(args, trainable_params, model=None):
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"

optimizer_type = args.optimizer_type
@@ -4545,6 +4551,26 @@ def get_optimizer(args, trainable_params):
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

elif optimizer_type == "AdamMini".lower():
logger.info(f"use AdamMini optimizer | {optimizer_kwargs}")
try:
import library.adam_mini as adam_mini
optimizer_class = adam_mini.Adam_mini
except ImportError:
raise ImportError("No adam-mini / adam-mini がインストールされていないようです")

named_params = [(name, param) for name, param in model.named_parameters() if param.requires_grad]

optimizer_kwargs["dim"] = 722
optimizer_kwargs["n_heads"] = 19

optimizer = optimizer_class(named_params, lr=lr, **optimizer_kwargs)
optimizer.embd_names.add("layer.weight")
optimizer.embd_names.add("final_layer.linear.weight")
optimizer.embd_names.add("final_layer.adaLN_modulation.1.weight")
optimizer.wqk_names.add("attn")
optimizer.wqk_names.add('mlp')

if optimizer is None:
# 任意のoptimizerを使う
optimizer_type = args.optimizer_type # lowerでないやつ(微妙)
@@ -5762,6 +5788,21 @@ def sample_image_inference(
pass


def freeze_blocks(model, num_last_block_to_freeze, block_name="x_block"):

filtered_blocks = [(name, param) for name, param in model.named_parameters() if block_name in name]
print(f"filtered_blocks: {len(filtered_blocks)}")

num_blocks_to_freeze = min(len(filtered_blocks), num_last_block_to_freeze)

print(f"freeze_blocks: {num_blocks_to_freeze}")

start_freezing_from = max(0, len(filtered_blocks) - num_blocks_to_freeze)

for i in range(start_freezing_from, len(filtered_blocks)):
_, param = filtered_blocks[i]
param.requires_grad = False

# endregion


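Two things are worth noting about this change to train_util.py: get_optimizer() now optionally receives the model so that the AdamMini branch can build the optimizer from model.named_parameters() (call sites that do not pass a model are unaffected), and freeze_blocks() counts matching parameters rather than whole transformer blocks, because it filters model.named_parameters() by block_name. A minimal sketch of the latter, using a hypothetical ToyMMDiT whose parameter names contain the default block_name="x_block" (nothing below is part of this PR):

import torch.nn as nn

from library.train_util import freeze_blocks  # the helper added above

class ToyMMDiT(nn.Module):
    def __init__(self, n_blocks=3, dim=8):
        super().__init__()
        # parameter names: "x_block.0.weight", "x_block.0.bias", ..., "final_layer.weight", ...
        self.x_block = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_blocks))
        self.final_layer = nn.Linear(dim, dim)

model = ToyMMDiT()

# 6 parameters match "x_block"; the last 2 of them (weight and bias of x_block.2) get frozen.
freeze_blocks(model, num_last_block_to_freeze=2, block_name="x_block")

for name, p in model.named_parameters():
    print(name, p.requires_grad)
# x_block.0.weight True
# x_block.0.bias True
# x_block.1.weight True
# x_block.1.bias True
# x_block.2.weight False
# x_block.2.bias False
# final_layer.weight True
# final_layer.bias True

In other words, --num_last_block_to_freeze freezes the last N matching tensors; freezing the last N whole blocks of a real MMDiT would require passing the number of parameters those blocks contain.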
1 change: 1 addition & 0 deletions requirements.txt
@@ -9,6 +9,7 @@ pytorch-lightning==1.9.0
bitsandbytes==0.43.3
prodigyopt==1.0
lion-pytorch==0.0.6
adam-mini==1.0.1
tensorboard
safetensors==0.4.2
# gradio==3.16.2
9 changes: 8 additions & 1 deletion sd3_train.py
@@ -368,12 +368,19 @@ def train(args):
vae.eval()
vae.to(accelerator.device, dtype=vae_dtype)

mmdit.requires_grad_(train_mmdit)
if not train_mmdit:
mmdit.to(accelerator.device, dtype=weight_dtype)  # because unet is not prepared

if args.num_last_block_to_freeze:
train_util.freeze_blocks(mmdit, num_last_block_to_freeze=args.num_last_block_to_freeze)

training_models = []
params_to_optimize = []
# if train_unet:
training_models.append(mmdit)
# if block_lrs is None:
params_to_optimize.append({"params": list(mmdit.parameters()), "lr": args.learning_rate})
params_to_optimize.append({"params": list(filter(lambda p: p.requires_grad, mmdit.parameters())), "lr": args.learning_rate})
# else:
# params_to_optimize.extend(get_block_params_to_optimize(mmdit, block_lrs))

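The switch from list(mmdit.parameters()) to a requires_grad filter is what makes the freezing above effective: without it, parameters that freeze_blocks() just froze would still be handed to the optimizer's param groups even though they never receive gradients. A minimal, self-contained illustration with a hypothetical two-layer stand-in (not part of this PR):

import torch.nn as nn

# Stand-in for the MMDiT: two linear layers, with the second one frozen.
net = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
for p in net[1].parameters():
    p.requires_grad = False

all_params = list(net.parameters())                                    # 4 tensors, frozen ones included
trainable = list(filter(lambda p: p.requires_grad, net.parameters()))  # 2 tensors

# As in the diff above: only trainable parameters reach the optimizer.
params_to_optimize = [{"params": trainable, "lr": 1e-4}]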
4 changes: 2 additions & 2 deletions train_network.py
@@ -73,7 +73,7 @@ def generate_step_logs(

lrs = lr_scheduler.get_last_lr()
for i, lr in enumerate(lrs):
- if lr_descriptions is not None:
+ if lr_descriptions is not None and i < len(lr_descriptions):
lr_desc = lr_descriptions[i]
else:
idx = i - (0 if args.network_train_unet_only else -1)
@@ -471,7 +471,7 @@ def train(self, args):
# v = len(v)
# accelerator.print(f"trainable_params: {k} = {v}")

- optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
+ optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params, network)

# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset