From 7d18d16b7e4cdfe5ebd3489743aaff6250b2b35b Mon Sep 17 00:00:00 2001 From: takuseno Date: Tue, 9 Jan 2024 22:21:06 +0900 Subject: [PATCH] Remove parameter property and add get_parameter --- d3rlpy/algos/qlearning/torch/bear_impl.py | 12 ++++++------ d3rlpy/algos/qlearning/torch/cql_impl.py | 8 ++++---- d3rlpy/algos/qlearning/torch/sac_impl.py | 20 ++++++++++---------- d3rlpy/models/torch/parameters.py | 10 +++------- d3rlpy/models/torch/transformers.py | 10 +++++----- tests/models/test_builders.py | 3 ++- tests/models/torch/test_parameters.py | 6 +++--- 7 files changed, 33 insertions(+), 36 deletions(-) diff --git a/d3rlpy/algos/qlearning/torch/bear_impl.py b/d3rlpy/algos/qlearning/torch/bear_impl.py index 64bb6768..f0d1caaf 100644 --- a/d3rlpy/algos/qlearning/torch/bear_impl.py +++ b/d3rlpy/algos/qlearning/torch/bear_impl.py @@ -13,6 +13,7 @@ compute_max_with_n_actions_and_indices, compute_vae_error, forward_vae_sample_n, + get_parameter, ) from ....torch_utility import ( TorchMiniBatch, @@ -118,7 +119,7 @@ def compute_actor_loss( temp_loss=loss.temp_loss, temp=loss.temp, mmd_loss=mmd_loss, - alpha=self._modules.log_alpha.parameter.exp(), + alpha=get_parameter(self._modules.log_alpha).exp(), ) def warmup_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: @@ -130,7 +131,7 @@ def warmup_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: def _compute_mmd_loss(self, obs_t: TorchObservation) -> torch.Tensor: mmd = self._compute_mmd(obs_t) - alpha = self._modules.log_alpha.parameter.exp() + alpha = get_parameter(self._modules.log_alpha).exp() return (alpha * (mmd - self._alpha_threshold)).mean() def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: @@ -155,7 +156,7 @@ def update_alpha(self, mmd_loss: torch.Tensor) -> None: loss.backward(retain_graph=True) self._modules.alpha_optim.step() # clip for stability - self._modules.log_alpha.data.clamp_(-5.0, 10.0) + get_parameter(self._modules.log_alpha).data.clamp_(-5.0, 10.0) def _compute_mmd(self, x: TorchObservation) -> torch.Tensor: with torch.no_grad(): @@ -231,9 +232,8 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: batch_size = get_batch_size(batch.observations) max_log_prob = log_probs[torch.arange(batch_size), indices] - return ( - values - self._modules.log_temp.parameter.exp() * max_log_prob - ) + log_temp = get_parameter(self._modules.log_temp) + return values - log_temp.exp() * max_log_prob def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor: batch_size = ( diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index 8b8cb956..57405cb9 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -11,6 +11,7 @@ DiscreteEnsembleQFunctionForwarder, Parameter, build_squashed_gaussian_distribution, + get_parameter, ) from ....torch_utility import ( TorchMiniBatch, @@ -86,7 +87,7 @@ def compute_critic_loss( return CQLCriticLoss( critic_loss=loss.critic_loss + conservative_loss, conservative_loss=conservative_loss, - alpha=self._modules.log_alpha.parameter.exp(), + alpha=get_parameter(self._modules.log_alpha).exp(), ) def update_alpha(self, conservative_loss: torch.Tensor) -> None: @@ -187,9 +188,8 @@ def _compute_conservative_loss( scaled_loss = self._conservative_weight * loss # clip for stability - clipped_alpha = self._modules.log_alpha.parameter.exp().clamp(0, 1e6)[ - 0 - ][0] + log_alpha = get_parameter(self._modules.log_alpha) + clipped_alpha = log_alpha.exp().clamp(0, 1e6)[0][0] return clipped_alpha * (scaled_loss - self._alpha_threshold) diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py index d2de9866..9458c75b 100644 --- a/d3rlpy/algos/qlearning/torch/sac_impl.py +++ b/d3rlpy/algos/qlearning/torch/sac_impl.py @@ -16,6 +16,7 @@ Parameter, Policy, build_squashed_gaussian_distribution, + get_parameter, ) from ....torch_utility import Modules, TorchMiniBatch, hard_sync from ....types import Shape, TorchObservation @@ -83,14 +84,14 @@ def compute_actor_loss( 0.0, dtype=torch.float32, device=sampled_action.device ) - entropy = self._modules.log_temp.parameter.exp() * log_prob + entropy = get_parameter(self._modules.log_temp).exp() * log_prob q_t = self._q_func_forwarder.compute_expected_q( batch.observations, sampled_action, "min" ) return SACActorLoss( actor_loss=(entropy - q_t).mean(), temp_loss=temp_loss, - temp=self._modules.log_temp.parameter.exp(), + temp=get_parameter(self._modules.log_temp).exp(), ) def update_temp(self, log_prob: torch.Tensor) -> torch.Tensor: @@ -98,7 +99,7 @@ def update_temp(self, log_prob: torch.Tensor) -> torch.Tensor: self._modules.temp_optim.zero_grad() with torch.no_grad(): targ_temp = log_prob - self._action_size - loss = -(self._modules.log_temp.parameter.exp() * targ_temp).mean() + loss = -(get_parameter(self._modules.log_temp).exp() * targ_temp).mean() loss.backward() self._modules.temp_optim.step() return loss @@ -109,7 +110,7 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: self._modules.policy(batch.next_observations) ) action, log_prob = dist.sample_with_log_prob() - entropy = self._modules.log_temp.parameter.exp() * log_prob + entropy = get_parameter(self._modules.log_temp).exp() * log_prob target = self._targ_q_func_forwarder.compute_target( batch.next_observations, action, @@ -181,7 +182,7 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: if self._modules.log_temp is None: temp = torch.zeros_like(log_probs) else: - temp = self._modules.log_temp.parameter.exp() + temp = get_parameter(self._modules.log_temp).exp() entropy = temp * log_probs target = self._targ_q_func_forwarder.compute_target( batch.next_observations @@ -231,7 +232,7 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: if self._modules.log_temp is None: temp = torch.zeros_like(log_probs) else: - temp = self._modules.log_temp.parameter.exp() + temp = get_parameter(self._modules.log_temp).exp() entropy = temp * log_probs return (probs * (entropy - q_t)).sum(dim=1).mean() @@ -248,15 +249,14 @@ def update_temp(self, batch: TorchMiniBatch) -> Dict[str, float]: entropy_target = 0.98 * (-math.log(1 / self.action_size)) targ_temp = expct_log_probs + entropy_target - loss = -(self._modules.log_temp.parameter.exp() * targ_temp).mean() + loss = -(get_parameter(self._modules.log_temp).exp() * targ_temp).mean() loss.backward() self._modules.temp_optim.step() # current temperature value - cur_temp = ( - self._modules.log_temp.parameter.exp().cpu().detach().numpy()[0][0] - ) + log_temp = get_parameter(self._modules.log_temp) + cur_temp = log_temp.exp().cpu().detach().numpy()[0][0] return { "temp_loss": float(loss.cpu().detach().numpy()), diff --git a/d3rlpy/models/torch/parameters.py b/d3rlpy/models/torch/parameters.py index 52515369..b85e516b 100644 --- a/d3rlpy/models/torch/parameters.py +++ b/d3rlpy/models/torch/parameters.py @@ -3,7 +3,7 @@ import torch from torch import nn -__all__ = ["Parameter"] +__all__ = ["Parameter", "get_parameter"] class Parameter(nn.Module): # type: ignore @@ -23,10 +23,6 @@ def __call__(self) -> NoReturn: "Parameter does not support __call__. Use parameter property instead." ) - @property - def parameter(self) -> nn.Parameter: - return self._parameter - @property - def data(self) -> torch.Tensor: - return self._parameter.data +def get_parameter(parameter: Parameter) -> nn.Parameter: + return next(parameter.parameters()) diff --git a/d3rlpy/models/torch/transformers.py b/d3rlpy/models/torch/transformers.py index 1358d686..176fc5d6 100644 --- a/d3rlpy/models/torch/transformers.py +++ b/d3rlpy/models/torch/transformers.py @@ -9,7 +9,7 @@ from ...torch_utility import GEGLU from ...types import TorchObservation from .encoders import Encoder -from .parameters import Parameter +from .parameters import Parameter, get_parameter __all__ = [ "ContinuousDecisionTransformer", @@ -194,7 +194,7 @@ def forward(self, t: torch.Tensor) -> torch.Tensor: ) # (1, Tmax, N) -> (B, Tmax, N) batched_global_embedding = torch.repeat_interleave( - self._global_position_embedding.parameter, + get_parameter(self._global_position_embedding), batch_size, dim=0, ) @@ -202,7 +202,7 @@ def forward(self, t: torch.Tensor) -> torch.Tensor: global_embedding = torch.gather(batched_global_embedding, 1, last_t) # (1, 3 * Cmax, N) -> (1, T, N) - block_embedding = self._block_position_embedding.parameter[ + block_embedding = get_parameter(self._block_position_embedding)[ :, :context_size, : ] @@ -506,8 +506,8 @@ def forward( ) # add action embedding - embeddings = ( - embeddings + action_masks * self._action_pos_embed.parameter + embeddings = embeddings + action_masks * get_parameter( + self._action_pos_embed ) # (B, T, N) -> (B, T, N) diff --git a/tests/models/test_builders.py b/tests/models/test_builders.py index 55aad4ba..226f0e6e 100644 --- a/tests/models/test_builders.py +++ b/tests/models/test_builders.py @@ -23,6 +23,7 @@ from d3rlpy.models.torch import ( ContinuousEnsembleQFunctionForwarder, DiscreteEnsembleQFunctionForwarder, + get_parameter, ) from d3rlpy.models.torch.imitators import ConditionalVAE from d3rlpy.models.torch.policies import ( @@ -259,7 +260,7 @@ def test_create_parameter(shape: Sequence[int]) -> None: parameter = create_parameter(shape, x, device="cpu:0") assert len(list(parameter.parameters())) == 1 - assert np.allclose(parameter.data.detach().numpy(), x) + assert np.allclose(get_parameter(parameter).detach().numpy(), x) @pytest.mark.parametrize("observation_shape", [(100,), (4, 84, 84)]) diff --git a/tests/models/torch/test_parameters.py b/tests/models/torch/test_parameters.py index 3102d8f1..934251b8 100644 --- a/tests/models/torch/test_parameters.py +++ b/tests/models/torch/test_parameters.py @@ -3,7 +3,7 @@ import pytest import torch -from d3rlpy.models.torch.parameters import Parameter +from d3rlpy.models.torch.parameters import Parameter, get_parameter @pytest.mark.parametrize("shape", [(100,)]) @@ -11,5 +11,5 @@ def test_parameter(shape: Sequence[int]) -> None: data = torch.rand(shape) parameter = Parameter(data) - assert parameter.data.shape == shape - assert torch.all(parameter.data == data) + assert get_parameter(parameter).data.shape == shape + assert torch.all(get_parameter(parameter).data == data)