Merge branch 'master' into gato
takuseno committed Jan 9, 2024
2 parents 8f0584f + 7d18d16 commit ac0d536
Showing 7 changed files with 47 additions and 30 deletions.
10 changes: 6 additions & 4 deletions d3rlpy/algos/qlearning/torch/bear_impl.py
@@ -13,6 +13,7 @@
     compute_max_with_n_actions_and_indices,
     compute_vae_error,
     forward_vae_sample_n,
+    get_parameter,
 )
 from ....torch_utility import (
     TorchMiniBatch,
@@ -118,7 +119,7 @@ def compute_actor_loss(
             temp_loss=loss.temp_loss,
             temp=loss.temp,
             mmd_loss=mmd_loss,
-            alpha=self._modules.log_alpha().exp(),
+            alpha=get_parameter(self._modules.log_alpha).exp(),
         )
 
     def warmup_actor(self, batch: TorchMiniBatch) -> Dict[str, float]:
@@ -130,7 +131,7 @@ def warmup_actor(self, batch: TorchMiniBatch) -> Dict[str, float]:
 
     def _compute_mmd_loss(self, obs_t: TorchObservation) -> torch.Tensor:
         mmd = self._compute_mmd(obs_t)
-        alpha = self._modules.log_alpha().exp()
+        alpha = get_parameter(self._modules.log_alpha).exp()
         return (alpha * (mmd - self._alpha_threshold)).mean()
 
     def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
@@ -155,7 +156,7 @@ def update_alpha(self, mmd_loss: torch.Tensor) -> None:
         loss.backward(retain_graph=True)
         self._modules.alpha_optim.step()
         # clip for stability
-        self._modules.log_alpha.data.clamp_(-5.0, 10.0)
+        get_parameter(self._modules.log_alpha).data.clamp_(-5.0, 10.0)
 
     def _compute_mmd(self, x: TorchObservation) -> torch.Tensor:
         with torch.no_grad():
@@ -231,7 +232,8 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
             batch_size = get_batch_size(batch.observations)
             max_log_prob = log_probs[torch.arange(batch_size), indices]
 
-            return values - self._modules.log_temp().exp() * max_log_prob
+            log_temp = get_parameter(self._modules.log_temp)
+            return values - log_temp.exp() * max_log_prob
 
     def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor:
         batch_size = (
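The BEAR changes above all follow the same pattern: the Lagrange multiplier log_alpha is no longer read by calling the Parameter module but by fetching the wrapped nn.Parameter through get_parameter. A minimal standalone sketch of that pattern (the (1, 1) shape and the optimizer wiring here are illustrative assumptions, not the library's builder code; the clamp range is taken from update_alpha above):

    import torch
    from d3rlpy.models.torch.parameters import Parameter, get_parameter

    # a learnable log-multiplier, assumed to be stored as a (1, 1) tensor
    log_alpha = Parameter(torch.zeros(1, 1))
    alpha_optim = torch.optim.Adam(log_alpha.parameters(), lr=1e-3)

    mmd_loss = torch.tensor(0.05)  # placeholder MMD value
    alpha_threshold = 0.05         # placeholder threshold

    # read the multiplier with get_parameter(...) instead of log_alpha()
    loss = -(get_parameter(log_alpha).exp() * (mmd_loss - alpha_threshold)).mean()

    alpha_optim.zero_grad()
    loss.backward()
    alpha_optim.step()

    # clip for stability, mirroring update_alpha in the diff
    get_parameter(log_alpha).data.clamp_(-5.0, 10.0)
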
6 changes: 4 additions & 2 deletions d3rlpy/algos/qlearning/torch/cql_impl.py
@@ -11,6 +11,7 @@
     DiscreteEnsembleQFunctionForwarder,
     Parameter,
     build_squashed_gaussian_distribution,
+    get_parameter,
 )
 from ....torch_utility import (
     TorchMiniBatch,
@@ -86,7 +87,7 @@ def compute_critic_loss(
         return CQLCriticLoss(
             critic_loss=loss.critic_loss + conservative_loss,
             conservative_loss=conservative_loss,
-            alpha=self._modules.log_alpha().exp(),
+            alpha=get_parameter(self._modules.log_alpha).exp(),
         )
 
     def update_alpha(self, conservative_loss: torch.Tensor) -> None:
@@ -187,7 +188,8 @@ def _compute_conservative_loss(
         scaled_loss = self._conservative_weight * loss
 
         # clip for stability
-        clipped_alpha = self._modules.log_alpha().exp().clamp(0, 1e6)[0][0]
+        log_alpha = get_parameter(self._modules.log_alpha)
+        clipped_alpha = log_alpha.exp().clamp(0, 1e6)[0][0]
 
         return clipped_alpha * (scaled_loss - self._alpha_threshold)
 
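In the CQL conservative term, the multiplier is additionally clipped after exponentiation and reduced to a scalar with [0][0], which implies the wrapped tensor has shape (1, 1). A short sketch of just that computation (the loss and threshold values are placeholders):

    import torch
    from d3rlpy.models.torch.parameters import Parameter, get_parameter

    log_alpha = Parameter(torch.zeros(1, 1))
    scaled_loss = torch.tensor(2.0)  # placeholder conservative loss
    alpha_threshold = 10.0           # placeholder threshold

    # clip for stability, then index down to a scalar as in the diff
    log_alpha_param = get_parameter(log_alpha)
    clipped_alpha = log_alpha_param.exp().clamp(0, 1e6)[0][0]
    conservative_loss = clipped_alpha * (scaled_loss - alpha_threshold)
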
18 changes: 10 additions & 8 deletions d3rlpy/algos/qlearning/torch/sac_impl.py
@@ -16,6 +16,7 @@
     Parameter,
     Policy,
     build_squashed_gaussian_distribution,
+    get_parameter,
 )
 from ....torch_utility import Modules, TorchMiniBatch, hard_sync
 from ....types import Shape, TorchObservation
@@ -83,22 +84,22 @@ def compute_actor_loss(
                 0.0, dtype=torch.float32, device=sampled_action.device
             )
 
-        entropy = self._modules.log_temp().exp() * log_prob
+        entropy = get_parameter(self._modules.log_temp).exp() * log_prob
         q_t = self._q_func_forwarder.compute_expected_q(
             batch.observations, sampled_action, "min"
         )
         return SACActorLoss(
             actor_loss=(entropy - q_t).mean(),
             temp_loss=temp_loss,
-            temp=self._modules.log_temp().exp(),
+            temp=get_parameter(self._modules.log_temp).exp(),
         )
 
     def update_temp(self, log_prob: torch.Tensor) -> torch.Tensor:
         assert self._modules.temp_optim
         self._modules.temp_optim.zero_grad()
         with torch.no_grad():
             targ_temp = log_prob - self._action_size
-        loss = -(self._modules.log_temp().exp() * targ_temp).mean()
+        loss = -(get_parameter(self._modules.log_temp).exp() * targ_temp).mean()
         loss.backward()
         self._modules.temp_optim.step()
         return loss
@@ -109,7 +110,7 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
                 self._modules.policy(batch.next_observations)
             )
             action, log_prob = dist.sample_with_log_prob()
-            entropy = self._modules.log_temp().exp() * log_prob
+            entropy = get_parameter(self._modules.log_temp).exp() * log_prob
             target = self._targ_q_func_forwarder.compute_target(
                 batch.next_observations,
                 action,
@@ -181,7 +182,7 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
             if self._modules.log_temp is None:
                 temp = torch.zeros_like(log_probs)
             else:
-                temp = self._modules.log_temp().exp()
+                temp = get_parameter(self._modules.log_temp).exp()
             entropy = temp * log_probs
             target = self._targ_q_func_forwarder.compute_target(
                 batch.next_observations
@@ -231,7 +232,7 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
         if self._modules.log_temp is None:
             temp = torch.zeros_like(log_probs)
         else:
-            temp = self._modules.log_temp().exp()
+            temp = get_parameter(self._modules.log_temp).exp()
         entropy = temp * log_probs
         return (probs * (entropy - q_t)).sum(dim=1).mean()
 
@@ -248,13 +249,14 @@ def update_temp(self, batch: TorchMiniBatch) -> Dict[str, float]:
             entropy_target = 0.98 * (-math.log(1 / self.action_size))
             targ_temp = expct_log_probs + entropy_target
 
-        loss = -(self._modules.log_temp().exp() * targ_temp).mean()
+        loss = -(get_parameter(self._modules.log_temp).exp() * targ_temp).mean()
 
         loss.backward()
         self._modules.temp_optim.step()
 
         # current temperature value
-        cur_temp = self._modules.log_temp().exp().cpu().detach().numpy()[0][0]
+        log_temp = get_parameter(self._modules.log_temp)
+        cur_temp = log_temp.exp().cpu().detach().numpy()[0][0]
 
         return {
             "temp_loss": float(loss.cpu().detach().numpy()),
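The SAC and discrete SAC changes swap the same call-style reads of log_temp for get_parameter. One detail from update_temp above is that the current temperature is also exported as a plain float for logging. A standalone sketch of that read, with the (1, 1) parameter shape assumed for illustration:

    import torch
    from d3rlpy.models.torch.parameters import Parameter, get_parameter

    log_temp = Parameter(torch.zeros(1, 1))  # (1, 1) shape assumed

    # current temperature value, mirroring the tail of update_temp
    log_temp_param = get_parameter(log_temp)
    cur_temp = log_temp_param.exp().cpu().detach().numpy()[0][0]
    metrics = {"temp": float(cur_temp)}
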
22 changes: 14 additions & 8 deletions d3rlpy/models/torch/parameters.py
@@ -1,7 +1,9 @@
+from typing import NoReturn
+
 import torch
 from torch import nn
 
-__all__ = ["Parameter"]
+__all__ = ["Parameter", "get_parameter"]
 
 
 class Parameter(nn.Module):  # type: ignore
@@ -11,12 +13,16 @@ def __init__(self, data: torch.Tensor):
         super().__init__()
         self._parameter = nn.Parameter(data)
 
-    def forward(self) -> torch.Tensor:
-        return self._parameter
+    def forward(self) -> NoReturn:
+        raise NotImplementedError(
+            "Parameter does not support __call__. Use parameter property instead."
+        )
 
-    def __call__(self) -> torch.Tensor:
-        return super().__call__()
+    def __call__(self) -> NoReturn:
+        raise NotImplementedError(
+            "Parameter does not support __call__. Use parameter property instead."
+        )
 
-    @property
-    def data(self) -> torch.Tensor:
-        return self._parameter.data
+
+def get_parameter(parameter: Parameter) -> nn.Parameter:
+    return next(parameter.parameters())
12 changes: 8 additions & 4 deletions d3rlpy/models/torch/transformers.py
@@ -9,7 +9,7 @@
 from ...torch_utility import GEGLU
 from ...types import TorchObservation
 from .encoders import Encoder
-from .parameters import Parameter
+from .parameters import Parameter, get_parameter
 
 __all__ = [
     "ContinuousDecisionTransformer",
@@ -194,15 +194,17 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
         )
         # (1, Tmax, N) -> (B, Tmax, N)
         batched_global_embedding = torch.repeat_interleave(
-            self._global_position_embedding(),
+            get_parameter(self._global_position_embedding),
             batch_size,
             dim=0,
         )
         # (B, Tmax, N) -> (B, 1, N)
         global_embedding = torch.gather(batched_global_embedding, 1, last_t)
 
         # (1, 3 * Cmax, N) -> (1, T, N)
-        block_embedding = self._block_position_embedding()[:, :context_size, :]
+        block_embedding = get_parameter(self._block_position_embedding)[
+            :, :context_size, :
+        ]
 
         # (B, 1, N) + (1, T, N) -> (B, T, N)
         return global_embedding + block_embedding
@@ -501,7 +503,9 @@ def forward(
         )
 
         # add action embedding
-        embeddings = embeddings + action_masks * self._action_pos_embed()
+        embeddings = embeddings + action_masks * get_parameter(
+            self._action_pos_embed
+        )
 
         # (B, T, N) -> (B, T, N)
         h = self._gpt2(embeddings)
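In the decision-transformer models, the position embeddings are stored as Parameter modules, so broadcasting and slicing now also go through get_parameter. A sketch of the embedding access in the forward methods above, with made-up sizes (Tmax, Cmax, N, the batch size, and the context size are illustrative):

    import torch
    from d3rlpy.models.torch.parameters import Parameter, get_parameter

    Tmax, Cmax, N, batch_size, context_size = 32, 10, 16, 4, 5
    global_position_embedding = Parameter(torch.zeros(1, Tmax, N))
    block_position_embedding = Parameter(torch.zeros(1, 3 * Cmax, N))

    # (1, Tmax, N) -> (B, Tmax, N)
    batched_global_embedding = torch.repeat_interleave(
        get_parameter(global_position_embedding), batch_size, dim=0
    )

    # (1, 3 * Cmax, N) -> (1, T, N)
    block_embedding = get_parameter(block_position_embedding)[:, :context_size, :]
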
3 changes: 2 additions & 1 deletion tests/models/test_builders.py
@@ -23,6 +23,7 @@
 from d3rlpy.models.torch import (
     ContinuousEnsembleQFunctionForwarder,
     DiscreteEnsembleQFunctionForwarder,
+    get_parameter,
 )
 from d3rlpy.models.torch.imitators import ConditionalVAE
 from d3rlpy.models.torch.policies import (
@@ -259,7 +260,7 @@ def test_create_parameter(shape: Sequence[int]) -> None:
     parameter = create_parameter(shape, x, device="cpu:0")
 
     assert len(list(parameter.parameters())) == 1
-    assert np.allclose(parameter().detach().numpy(), x)
+    assert np.allclose(get_parameter(parameter).detach().numpy(), x)
 
 
 @pytest.mark.parametrize("observation_shape", [(100,), (4, 84, 84)])
6 changes: 3 additions & 3 deletions tests/models/torch/test_parameters.py
@@ -3,13 +3,13 @@
 import pytest
 import torch
 
-from d3rlpy.models.torch.parameters import Parameter
+from d3rlpy.models.torch.parameters import Parameter, get_parameter
 
 
 @pytest.mark.parametrize("shape", [(100,)])
 def test_parameter(shape: Sequence[int]) -> None:
     data = torch.rand(shape)
     parameter = Parameter(data)
 
-    assert parameter().shape == shape
-    assert torch.all(parameter() == data)
+    assert get_parameter(parameter).data.shape == shape
+    assert torch.all(get_parameter(parameter).data == data)
