From 7b48c17ee99fe688da6cfde6109c07df2961fa55 Mon Sep 17 00:00:00 2001 From: DimensionSTP Date: Sat, 20 Jul 2024 06:10:17 +0900 Subject: [PATCH] feat: weight decay hparams tuning --- src/tuners/rhythm_tuner.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/tuners/rhythm_tuner.py b/src/tuners/rhythm_tuner.py index d93ae59..f488b25 100644 --- a/src/tuners/rhythm_tuner.py +++ b/src/tuners/rhythm_tuner.py @@ -109,6 +109,13 @@ def optuna_objective( high=self.hparams.lr.high, log=self.hparams.lr.log, ) + if self.hparams.weight_decay: + params["weight_decay"] = trial.suggest_float( + name="weight_decay", + low=self.hparams.weight_decay.low, + high=self.hparams.weight_decay.high, + log=self.hparams.weight_decay.log, + ) if self.hparams.period: params["period"] = trial.suggest_int( name="period", @@ -135,6 +142,7 @@ def optuna_objective( model=model, strategy=self.module_params.strategy, lr=params["lr"], + weight_decay=params["weight_decay"], period=params["period"], eta_min=params["eta_min"], interval=self.module_params.interval,