Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
DimensionSTP committed Jul 31, 2024
2 parents 4f1c616 + e7e26ef commit 83d5ed4
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 28 deletions.
4 changes: 2 additions & 2 deletions configs/architecture/rhythm_architecture.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ model:
strategy: ${strategy}
lr: ${lr}
weight_decay: ${weight_decay}
period: ${period}
eta_min: ${eta_min}
half_period: ${half_period}
eta_min_rate: ${eta_min_rate}
interval: step
connected_dir: ${connected_dir}
4 changes: 2 additions & 2 deletions configs/rhythm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ direction: bi

lr: 0.0001
weight_decay: 0.01
period: 2
eta_min: 0.00001
half_period: 1
eta_min_rate: 0.1

monitor: val_rmse_loss
tracking_direction: min
Expand Down
8 changes: 4 additions & 4 deletions configs/tuner/rhythm_tuner.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ hparams:
low: 0.001
high: 0.01
log: False
period:
half_period:
low: 1
high: 10
log: False
eta_min:
low: 0.00005
high: 0.00001
eta_min_rate:
low: 0.05
high: 0.5
log: False

module_params:
Expand Down
13 changes: 7 additions & 6 deletions src/architectures/rhythm_architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def __init__(
strategy: str,
lr: float,
weight_decay: float,
period: int,
eta_min: float,
half_period: int,
eta_min_rate: float,
interval: str,
connected_dir: str,
) -> None:
Expand All @@ -27,8 +27,8 @@ def __init__(
self.strategy = strategy
self.lr = lr
self.weight_decay = weight_decay
self.period = period
self.eta_min = eta_min
self.half_period = half_period
self.eta_min_rate = eta_min_rate
self.interval = interval
self.connected_dir = connected_dir

Expand Down Expand Up @@ -86,11 +86,12 @@ def configure_optimizers(self) -> Dict[str, Any]:
lr=self.lr,
weight_decay=self.weight_decay,
)
t_max = self.period * self.trainer.num_training_batches
t_max = self.half_period * self.trainer.num_training_batches
eta_min = self.lr * self.eta_min_rate
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer=optimizer,
T_max=t_max,
eta_min=self.eta_min,
eta_min=eta_min,
)
return {
"optimizer": optimizer,
Expand Down
28 changes: 14 additions & 14 deletions src/tuners/rhythm_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,19 @@ def optuna_objective(
high=self.hparams.weight_decay.high,
log=self.hparams.weight_decay.log,
)
if self.hparams.period:
params["period"] = trial.suggest_int(
name="period",
low=self.hparams.period.low,
high=self.hparams.period.high,
log=self.hparams.period.log,
if self.hparams.half_period:
params["half_period"] = trial.suggest_int(
name="half_period",
low=self.hparams.half_period.low,
high=self.hparams.half_period.high,
log=self.hparams.half_period.log,
)
if self.hparams.eta_min:
params["eta_min"] = trial.suggest_float(
name="eta_min",
low=self.hparams.eta_min.low,
high=self.hparams.eta_min.high,
log=self.hparams.eta_min.log,
if self.hparams.eta_min_rate:
params["eta_min_rate"] = trial.suggest_float(
name="eta_min_rate",
low=self.hparams.eta_min_rate.low,
high=self.hparams.eta_min_rate.high,
log=self.hparams.eta_min_rate.log,
)

model = CustomizedRhythmNet(
Expand All @@ -143,8 +143,8 @@ def optuna_objective(
strategy=self.module_params.strategy,
lr=params["lr"],
weight_decay=params["weight_decay"],
period=params["period"],
eta_min=params["eta_min"],
half_period=params["half_period"],
eta_min_rate=params["eta_min_rate"],
interval=self.module_params.interval,
connected_dir=self.module_params.connected_dir,
)
Expand Down

0 comments on commit 83d5ed4

Please sign in to comment.