From 8e378cf03df645cef897a342559dc5fa7f66a35d Mon Sep 17 00:00:00 2001 From: nhamanasu Date: Wed, 11 Dec 2024 19:43:44 +0900 Subject: [PATCH] add RAdamScheduleFree support --- library/train_util.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index a35388fee..72b5b24db 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4887,7 +4887,11 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]: import schedulefree as sf except ImportError: raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") - if optimizer_type == "AdamWScheduleFree".lower(): + + if optimizer_type == "RAdamScheduleFree".lower(): + optimizer_class = sf.RAdamScheduleFree + logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}") + elif optimizer_type == "AdamWScheduleFree".lower(): optimizer_class = sf.AdamWScheduleFree logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}") elif optimizer_type == "SGDScheduleFree".lower():