From bb984031fef4a9cc92a321aa8ceae3c0bbb64419 Mon Sep 17 00:00:00 2001 From: takuseno Date: Sun, 3 Nov 2024 20:11:08 +0900 Subject: [PATCH] Reorganize optimizer tests --- tests/optimizers/__init__.py | 0 tests/optimizers/test_lr_schedulers.py | 44 +++++++++++++++++++ .../{models => optimizers}/test_optimizers.py | 0 3 files changed, 44 insertions(+) create mode 100644 tests/optimizers/__init__.py create mode 100644 tests/optimizers/test_lr_schedulers.py rename tests/{models => optimizers}/test_optimizers.py (100%) diff --git a/tests/optimizers/__init__.py b/tests/optimizers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/optimizers/test_lr_schedulers.py b/tests/optimizers/test_lr_schedulers.py new file mode 100644 index 00000000..bf9ad7c4 --- /dev/null +++ b/tests/optimizers/test_lr_schedulers.py @@ -0,0 +1,44 @@ +import numpy as np +import pytest +import torch +from torch.optim import SGD +from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR + +from d3rlpy.optimizers.lr_schedulers import ( + CosineAnnealingLRFactory, + WarmupSchedulerFactory, +) + + +@pytest.mark.parametrize("warmup_steps", [100]) +@pytest.mark.parametrize("lr", [1e-4]) +@pytest.mark.parametrize("module", [torch.nn.Linear(2, 3)]) +def test_warmup_scheduler_factory( + warmup_steps: int, lr: float, module: torch.nn.Module +) -> None: + factory = WarmupSchedulerFactory(warmup_steps) + + lr_scheduler = factory.create(SGD(module.parameters(), lr=lr)) + + assert np.allclose(lr_scheduler.get_lr()[0], lr / warmup_steps) + for _ in range(warmup_steps): + lr_scheduler.step() + assert lr_scheduler.get_lr()[0] == lr + + assert isinstance(lr_scheduler, LambdaLR) + + # check serialization and deserialization + WarmupSchedulerFactory.deserialize(factory.serialize()) + + +@pytest.mark.parametrize("T_max", [100]) +@pytest.mark.parametrize("module", [torch.nn.Linear(2, 3)]) +def test_cosine_annealing_factory(T_max: int, module: torch.nn.Module) -> None: + factory = CosineAnnealingLRFactory(T_max=T_max) + + lr_scheduler = factory.create(SGD(module.parameters())) + + assert isinstance(lr_scheduler, CosineAnnealingLR) + + # check serialization and deserialization + CosineAnnealingLRFactory.deserialize(factory.serialize()) diff --git a/tests/models/test_optimizers.py b/tests/optimizers/test_optimizers.py similarity index 100% rename from tests/models/test_optimizers.py rename to tests/optimizers/test_optimizers.py