diff --git a/library/train_util.py b/library/train_util.py
index 760be33eb..7c945f878 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2975,7 +2975,12 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         "--optimizer_type",
         type=str,
         default="",
-        help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
+        help=(
+            "Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, "
+            "PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, "
+            "DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, "
+            "DAdaptLion, DAdaptSGD, AdaFactor, CAME."
+        ),
     )
 
     # backward compatibility
@@ -3991,7 +3996,9 @@ def task():
 
 
 def get_optimizer(args, trainable_params):
-    # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
+    # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit,
+    # PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam,
+    # DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor, CAME."
 
     optimizer_type = args.optimizer_type
     if args.use_8bit_adam:
@@ -4221,6 +4228,20 @@ def get_optimizer(args, trainable_params):
             optimizer_kwargs["relative_step"] = True
         logger.info(f"use Adafactor optimizer | {optimizer_kwargs}")
 
+    elif optimizer_type == "CAME".lower():
+        logger.info(f"use CAME optimizer | {optimizer_kwargs}")
+        try:
+            import came_pytorch
+        except ImportError:
+            raise ImportError("No came-pytorch / came-pytorchがインストールされていないようです")
+        try:
+            optimizer_class = came_pytorch.CAME
+        except AttributeError:
+            raise AttributeError(
+                "No CAME. Please install came-pytorch. / CAMEが定義されていません。came-pytorchがインストールされていないようです。"
+            )
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
         if optimizer_kwargs["relative_step"]:
             logger.info(f"relative_step is true / relative_stepがtrueです")
             if lr != 0.0:
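
For reference, here is a minimal standalone sketch of the code path the new branch enables, assuming the `came-pytorch` package is installed (`pip install came-pytorch`). The model and hyperparameter values below are illustrative stand-ins, not defaults taken from this PR; in sd-scripts terms this corresponds to `--optimizer_type CAME` plus `--optimizer_args` key=value pairs, which `get_optimizer` forwards as `optimizer_kwargs`:

```python
# Standalone sketch of what the new CAME branch does, assuming
# `pip install came-pytorch`. Values here are illustrative; the PR
# itself just forwards --optimizer_args key=value pairs as kwargs.
import torch
import came_pytorch

model = torch.nn.Linear(128, 64)  # stand-in for the trainable params

# Equivalent of: --optimizer_type CAME --optimizer_args "weight_decay=0.01"
optimizer_kwargs = {"weight_decay": 0.01}
optimizer = came_pytorch.CAME(model.parameters(), lr=1e-4, **optimizer_kwargs)

# One training step to confirm the optimizer is wired up correctly.
loss = model(torch.randn(4, 128)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

`came-pytorch` implements CAME (Confidence-guided Adaptive Memory Efficient Optimization, Luo et al., ACL 2023), which keeps Adafactor-style factored second moments plus a confidence matrix, so its memory footprint sits between Adafactor and full AdamW.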