diff --git a/library/train_util.py b/library/train_util.py index a35388fee..72b5b24db 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4887,7 +4887,11 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]: import schedulefree as sf except ImportError: raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") - if optimizer_type == "AdamWScheduleFree".lower(): + + if optimizer_type == "RAdamScheduleFree".lower(): + optimizer_class = sf.RAdamScheduleFree + logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}") + elif optimizer_type == "AdamWScheduleFree".lower(): optimizer_class = sf.AdamWScheduleFree logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}") elif optimizer_type == "SGDScheduleFree".lower():