Remove parameter property and add get_parameter
takuseno committed Jan 9, 2024
commit 7d18d16 (1 parent: 9467793)
Showing 7 changed files with 33 additions and 36 deletions.
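The pattern is the same at every call site: the removed `parameter` property on `Parameter` is replaced by the new module-level `get_parameter` helper. A minimal before/after sketch (illustrative only; `modules.log_alpha` stands in for the module attributes touched in the diffs below, and the import path is the one used by the updated tests):

    # before this commit: read the wrapped nn.Parameter via the property
    alpha = modules.log_alpha.parameter.exp()

    # after this commit: fetch it through the new helper
    from d3rlpy.models.torch import get_parameter

    alpha = get_parameter(modules.log_alpha).exp()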
d3rlpy/algos/qlearning/torch/bear_impl.py (6 additions, 6 deletions)

@@ -13,6 +13,7 @@
     compute_max_with_n_actions_and_indices,
     compute_vae_error,
     forward_vae_sample_n,
+    get_parameter,
 )
 from ....torch_utility import (
     TorchMiniBatch,
@@ -118,7 +119,7 @@ def compute_actor_loss(
             temp_loss=loss.temp_loss,
             temp=loss.temp,
             mmd_loss=mmd_loss,
-            alpha=self._modules.log_alpha.parameter.exp(),
+            alpha=get_parameter(self._modules.log_alpha).exp(),
         )

     def warmup_actor(self, batch: TorchMiniBatch) -> Dict[str, float]:
@@ -130,7 +131,7 @@ def warmup_actor(self, batch: TorchMiniBatch) -> Dict[str, float]:

     def _compute_mmd_loss(self, obs_t: TorchObservation) -> torch.Tensor:
         mmd = self._compute_mmd(obs_t)
-        alpha = self._modules.log_alpha.parameter.exp()
+        alpha = get_parameter(self._modules.log_alpha).exp()
         return (alpha * (mmd - self._alpha_threshold)).mean()

     def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
@@ -155,7 +156,7 @@ def update_alpha(self, mmd_loss: torch.Tensor) -> None:
         loss.backward(retain_graph=True)
         self._modules.alpha_optim.step()
         # clip for stability
-        self._modules.log_alpha.data.clamp_(-5.0, 10.0)
+        get_parameter(self._modules.log_alpha).data.clamp_(-5.0, 10.0)

     def _compute_mmd(self, x: TorchObservation) -> torch.Tensor:
         with torch.no_grad():
@@ -231,9 +232,8 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
         batch_size = get_batch_size(batch.observations)
         max_log_prob = log_probs[torch.arange(batch_size), indices]

-        return (
-            values - self._modules.log_temp.parameter.exp() * max_log_prob
-        )
+        log_temp = get_parameter(self._modules.log_temp)
+        return values - log_temp.exp() * max_log_prob

     def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor:
         batch_size = (
d3rlpy/algos/qlearning/torch/cql_impl.py (4 additions, 4 deletions)

@@ -11,6 +11,7 @@
     DiscreteEnsembleQFunctionForwarder,
     Parameter,
     build_squashed_gaussian_distribution,
+    get_parameter,
 )
 from ....torch_utility import (
     TorchMiniBatch,
@@ -86,7 +87,7 @@ def compute_critic_loss(
         return CQLCriticLoss(
             critic_loss=loss.critic_loss + conservative_loss,
             conservative_loss=conservative_loss,
-            alpha=self._modules.log_alpha.parameter.exp(),
+            alpha=get_parameter(self._modules.log_alpha).exp(),
         )

     def update_alpha(self, conservative_loss: torch.Tensor) -> None:
@@ -187,9 +188,8 @@ def _compute_conservative_loss(
         scaled_loss = self._conservative_weight * loss

         # clip for stability
-        clipped_alpha = self._modules.log_alpha.parameter.exp().clamp(0, 1e6)[
-            0
-        ][0]
+        log_alpha = get_parameter(self._modules.log_alpha)
+        clipped_alpha = log_alpha.exp().clamp(0, 1e6)[0][0]

         return clipped_alpha * (scaled_loss - self._alpha_threshold)
d3rlpy/algos/qlearning/torch/sac_impl.py (10 additions, 10 deletions)

@@ -16,6 +16,7 @@
     Parameter,
     Policy,
     build_squashed_gaussian_distribution,
+    get_parameter,
 )
 from ....torch_utility import Modules, TorchMiniBatch, hard_sync
 from ....types import Shape, TorchObservation
@@ -83,22 +84,22 @@ def compute_actor_loss(
                 0.0, dtype=torch.float32, device=sampled_action.device
             )

-        entropy = self._modules.log_temp.parameter.exp() * log_prob
+        entropy = get_parameter(self._modules.log_temp).exp() * log_prob
         q_t = self._q_func_forwarder.compute_expected_q(
             batch.observations, sampled_action, "min"
         )
         return SACActorLoss(
             actor_loss=(entropy - q_t).mean(),
             temp_loss=temp_loss,
-            temp=self._modules.log_temp.parameter.exp(),
+            temp=get_parameter(self._modules.log_temp).exp(),
         )

     def update_temp(self, log_prob: torch.Tensor) -> torch.Tensor:
         assert self._modules.temp_optim
         self._modules.temp_optim.zero_grad()
         with torch.no_grad():
             targ_temp = log_prob - self._action_size
-        loss = -(self._modules.log_temp.parameter.exp() * targ_temp).mean()
+        loss = -(get_parameter(self._modules.log_temp).exp() * targ_temp).mean()
         loss.backward()
         self._modules.temp_optim.step()
         return loss
@@ -109,7 +110,7 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
             self._modules.policy(batch.next_observations)
         )
         action, log_prob = dist.sample_with_log_prob()
-        entropy = self._modules.log_temp.parameter.exp() * log_prob
+        entropy = get_parameter(self._modules.log_temp).exp() * log_prob
         target = self._targ_q_func_forwarder.compute_target(
             batch.next_observations,
             action,
@@ -181,7 +182,7 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
         if self._modules.log_temp is None:
             temp = torch.zeros_like(log_probs)
         else:
-            temp = self._modules.log_temp.parameter.exp()
+            temp = get_parameter(self._modules.log_temp).exp()
         entropy = temp * log_probs
         target = self._targ_q_func_forwarder.compute_target(
             batch.next_observations
@@ -231,7 +232,7 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
         if self._modules.log_temp is None:
             temp = torch.zeros_like(log_probs)
         else:
-            temp = self._modules.log_temp.parameter.exp()
+            temp = get_parameter(self._modules.log_temp).exp()
         entropy = temp * log_probs
         return (probs * (entropy - q_t)).sum(dim=1).mean()

@@ -248,15 +249,14 @@ def update_temp(self, batch: TorchMiniBatch) -> Dict[str, float]:
         entropy_target = 0.98 * (-math.log(1 / self.action_size))
         targ_temp = expct_log_probs + entropy_target

-        loss = -(self._modules.log_temp.parameter.exp() * targ_temp).mean()
+        loss = -(get_parameter(self._modules.log_temp).exp() * targ_temp).mean()

         loss.backward()
         self._modules.temp_optim.step()

         # current temperature value
-        cur_temp = (
-            self._modules.log_temp.parameter.exp().cpu().detach().numpy()[0][0]
-        )
+        log_temp = get_parameter(self._modules.log_temp)
+        cur_temp = log_temp.exp().cpu().detach().numpy()[0][0]

         return {
             "temp_loss": float(loss.cpu().detach().numpy()),
d3rlpy/models/torch/parameters.py (3 additions, 7 deletions)

@@ -3,7 +3,7 @@
 import torch
 from torch import nn

-__all__ = ["Parameter"]
+__all__ = ["Parameter", "get_parameter"]


 class Parameter(nn.Module):  # type: ignore
@@ -23,10 +23,6 @@ def __call__(self) -> NoReturn:
             "Parameter does not support __call__. Use parameter property instead."
         )

-    @property
-    def parameter(self) -> nn.Parameter:
-        return self._parameter
-
-    @property
-    def data(self) -> torch.Tensor:
-        return self._parameter.data
+
+def get_parameter(parameter: Parameter) -> nn.Parameter:
+    return next(parameter.parameters())
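For reference, a short standalone sketch of what the helper relies on: `Parameter` registers its tensor as an `nn.Parameter` on the module, so `next(parameter.parameters())` yields that exact object, and in-place updates such as the stability clip in bear_impl.py go through its `.data`. Illustrative only; it assumes nothing beyond the code shown in this commit:

    import torch
    from d3rlpy.models.torch.parameters import Parameter, get_parameter

    # wrap a learnable (1, 1) scalar, as the impls do for log temperature
    log_temp = Parameter(torch.zeros(1, 1))

    # the helper returns the underlying nn.Parameter registered on the module
    p = get_parameter(log_temp)
    assert isinstance(p, torch.nn.Parameter)
    assert len(list(log_temp.parameters())) == 1

    # typical call-site patterns from this commit
    temp = p.exp()              # differentiable read
    p.data.clamp_(-5.0, 10.0)   # in-place, gradient-free stability clip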
d3rlpy/models/torch/transformers.py (5 additions, 5 deletions)

@@ -9,7 +9,7 @@
 from ...torch_utility import GEGLU
 from ...types import TorchObservation
 from .encoders import Encoder
-from .parameters import Parameter
+from .parameters import Parameter, get_parameter

 __all__ = [
     "ContinuousDecisionTransformer",
@@ -194,15 +194,15 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
         )
         # (1, Tmax, N) -> (B, Tmax, N)
         batched_global_embedding = torch.repeat_interleave(
-            self._global_position_embedding.parameter,
+            get_parameter(self._global_position_embedding),
             batch_size,
             dim=0,
         )
         # (B, Tmax, N) -> (B, 1, N)
         global_embedding = torch.gather(batched_global_embedding, 1, last_t)

         # (1, 3 * Cmax, N) -> (1, T, N)
-        block_embedding = self._block_position_embedding.parameter[
+        block_embedding = get_parameter(self._block_position_embedding)[
             :, :context_size, :
         ]

@@ -506,8 +506,8 @@ def forward(
         )

         # add action embedding
-        embeddings = (
-            embeddings + action_masks * self._action_pos_embed.parameter
+        embeddings = embeddings + action_masks * get_parameter(
+            self._action_pos_embed
         )

         # (B, T, N) -> (B, T, N)
tests/models/test_builders.py (2 additions, 1 deletion)

@@ -23,6 +23,7 @@
 from d3rlpy.models.torch import (
     ContinuousEnsembleQFunctionForwarder,
     DiscreteEnsembleQFunctionForwarder,
+    get_parameter,
 )
 from d3rlpy.models.torch.imitators import ConditionalVAE
 from d3rlpy.models.torch.policies import (
@@ -259,7 +260,7 @@ def test_create_parameter(shape: Sequence[int]) -> None:
     parameter = create_parameter(shape, x, device="cpu:0")

     assert len(list(parameter.parameters())) == 1
-    assert np.allclose(parameter.data.detach().numpy(), x)
+    assert np.allclose(get_parameter(parameter).detach().numpy(), x)


 @pytest.mark.parametrize("observation_shape", [(100,), (4, 84, 84)])
tests/models/torch/test_parameters.py (3 additions, 3 deletions)

@@ -3,13 +3,13 @@
 import pytest
 import torch

-from d3rlpy.models.torch.parameters import Parameter
+from d3rlpy.models.torch.parameters import Parameter, get_parameter


 @pytest.mark.parametrize("shape", [(100,)])
 def test_parameter(shape: Sequence[int]) -> None:
     data = torch.rand(shape)
     parameter = Parameter(data)

-    assert parameter.data.shape == shape
-    assert torch.all(parameter.data == data)
+    assert get_parameter(parameter).data.shape == shape
+    assert torch.all(get_parameter(parameter).data == data)