diff --git a/configs/architecture/rhythm_architecture.yaml b/configs/architecture/rhythm_architecture.yaml index 95db5ed..72daa69 100644 --- a/configs/architecture/rhythm_architecture.yaml +++ b/configs/architecture/rhythm_architecture.yaml @@ -11,7 +11,7 @@ model: strategy: ${strategy} lr: ${lr} weight_decay: ${weight_decay} -period: ${period} -eta_min: ${eta_min} +half_period: ${half_period} +eta_min_rate: ${eta_min_rate} interval: step connected_dir: ${connected_dir} diff --git a/configs/rhythm.yaml b/configs/rhythm.yaml index 6808152..ac5f41e 100644 --- a/configs/rhythm.yaml +++ b/configs/rhythm.yaml @@ -32,8 +32,8 @@ direction: bi lr: 0.0001 weight_decay: 0.01 -period: 2 -eta_min: 0.00001 +half_period: 1 +eta_min_rate: 0.1 monitor: val_rmse_loss tracking_direction: min diff --git a/configs/tuner/rhythm_tuner.yaml b/configs/tuner/rhythm_tuner.yaml index 5ac22ef..fb2acb4 100644 --- a/configs/tuner/rhythm_tuner.yaml +++ b/configs/tuner/rhythm_tuner.yaml @@ -28,13 +28,13 @@ hparams: low: 0.001 high: 0.01 log: False - period: + half_period: low: 1 high: 10 log: False - eta_min: - low: 0.00005 - high: 0.00001 + eta_min_rate: + low: 0.05 + high: 0.5 log: False module_params: diff --git a/src/architectures/rhythm_architecture.py b/src/architectures/rhythm_architecture.py index 6ae2636..87edd41 100644 --- a/src/architectures/rhythm_architecture.py +++ b/src/architectures/rhythm_architecture.py @@ -17,8 +17,8 @@ def __init__( strategy: str, lr: float, weight_decay: float, - period: int, - eta_min: float, + half_period: int, + eta_min_rate: float, interval: str, connected_dir: str, ) -> None: @@ -27,8 +27,8 @@ def __init__( self.strategy = strategy self.lr = lr self.weight_decay = weight_decay - self.period = period - self.eta_min = eta_min + self.half_period = half_period + self.eta_min_rate = eta_min_rate self.interval = interval self.connected_dir = connected_dir @@ -86,11 +86,12 @@ def configure_optimizers(self) -> Dict[str, Any]: lr=self.lr, weight_decay=self.weight_decay, ) - t_max = self.period * self.trainer.num_training_batches + t_max = self.half_period * self.trainer.num_training_batches + eta_min = self.lr * self.eta_min_rate scheduler = optim.lr_scheduler.CosineAnnealingLR( optimizer=optimizer, T_max=t_max, - eta_min=self.eta_min, + eta_min=eta_min, ) return { "optimizer": optimizer, diff --git a/src/tuners/rhythm_tuner.py b/src/tuners/rhythm_tuner.py index f488b25..d14d912 100644 --- a/src/tuners/rhythm_tuner.py +++ b/src/tuners/rhythm_tuner.py @@ -116,19 +116,19 @@ def optuna_objective( high=self.hparams.weight_decay.high, log=self.hparams.weight_decay.log, ) - if self.hparams.period: - params["period"] = trial.suggest_int( - name="period", - low=self.hparams.period.low, - high=self.hparams.period.high, - log=self.hparams.period.log, + if self.hparams.half_period: + params["half_period"] = trial.suggest_int( + name="half_period", + low=self.hparams.half_period.low, + high=self.hparams.half_period.high, + log=self.hparams.half_period.log, ) - if self.hparams.eta_min: - params["eta_min"] = trial.suggest_float( - name="eta_min", - low=self.hparams.eta_min.low, - high=self.hparams.eta_min.high, - log=self.hparams.eta_min.log, + if self.hparams.eta_min_rate: + params["eta_min_rate"] = trial.suggest_float( + name="eta_min_rate", + low=self.hparams.eta_min_rate.low, + high=self.hparams.eta_min_rate.high, + log=self.hparams.eta_min_rate.log, ) model = CustomizedRhythmNet( @@ -143,8 +143,8 @@ def optuna_objective( strategy=self.module_params.strategy, lr=params["lr"], weight_decay=params["weight_decay"], - period=params["period"], - eta_min=params["eta_min"], + half_period=params["half_period"], + eta_min_rate=params["eta_min_rate"], interval=self.module_params.interval, connected_dir=self.module_params.connected_dir, )