Remove ConditionalVAE
takuseno committed Jan 13, 2024
1 parent 69f12d6 commit ee65b1f
Showing 13 changed files with 204 additions and 184 deletions.
23 changes: 17 additions & 6 deletions d3rlpy/algos/qlearning/bcq.py
@@ -4,10 +4,11 @@
from ...constants import ActionSpace
from ...models.builders import (
create_categorical_policy,
create_conditional_vae,
create_continuous_q_function,
create_deterministic_residual_policy,
create_discrete_q_function,
create_vae_decoder,
create_vae_encoder,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
@@ -200,7 +201,7 @@ def inner_create_impl(
n_ensembles=self._config.n_critics,
device=self._device,
)
imitator = create_conditional_vae(
vae_encoder = create_vae_encoder(
observation_shape=observation_shape,
action_size=action_size,
latent_size=2 * action_size,
@@ -209,26 +210,36 @@
encoder_factory=self._config.imitator_encoder_factory,
device=self._device,
)
vae_decoder = create_vae_decoder(
observation_shape=observation_shape,
action_size=action_size,
latent_size=2 * action_size,
encoder_factory=self._config.imitator_encoder_factory,
device=self._device,
)

actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(), lr=self._config.actor_learning_rate
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(), lr=self._config.critic_learning_rate
)
imitator_optim = self._config.imitator_optim_factory.create(
imitator.named_modules(), lr=self._config.imitator_learning_rate
vae_optim = self._config.imitator_optim_factory.create(
list(vae_encoder.named_modules())
+ list(vae_decoder.named_modules()),
lr=self._config.imitator_learning_rate,
)

modules = BCQModules(
policy=policy,
targ_policy=targ_policy,
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
imitator=imitator,
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
actor_optim=actor_optim,
critic_optim=critic_optim,
imitator_optim=imitator_optim,
vae_optim=vae_optim,
)

self._impl = BCQImpl(
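The same refactor appears in BEAR and PLAS below: the single ConditionalVAE imitator is split into a VAEEncoder/VAEDecoder pair, and both modules are handed to one optimizer. A minimal stand-alone sketch of that one-optimizer-over-two-modules pattern in plain torch (the Linear stand-ins and the learning rate are hypothetical, not d3rlpy's actual modules or factory internals):

import torch
from torch import nn

# Stand-ins for the VAEEncoder / VAEDecoder modules built by
# create_vae_encoder / create_vae_decoder; the sizes are hypothetical.
vae_encoder = nn.Linear(8, 4)
vae_decoder = nn.Linear(4, 2)

# One optimizer covers both modules, mirroring
# list(vae_encoder.named_modules()) + list(vae_decoder.named_modules())
# being passed to the optimizer factory in the hunks above.
vae_params = list(vae_encoder.parameters()) + list(vae_decoder.parameters())
vae_optim = torch.optim.Adam(vae_params, lr=1e-3)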
23 changes: 17 additions & 6 deletions d3rlpy/algos/qlearning/bear.py
@@ -4,10 +4,11 @@
from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...models.builders import (
create_conditional_vae,
create_continuous_q_function,
create_normal_policy,
create_parameter,
create_vae_decoder,
create_vae_encoder,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
@@ -178,7 +179,7 @@ def inner_create_impl(
n_ensembles=self._config.n_critics,
device=self._device,
)
imitator = create_conditional_vae(
vae_encoder = create_vae_encoder(
observation_shape=observation_shape,
action_size=action_size,
latent_size=2 * action_size,
@@ -187,6 +188,13 @@
encoder_factory=self._config.imitator_encoder_factory,
device=self._device,
)
vae_decoder = create_vae_decoder(
observation_shape=observation_shape,
action_size=action_size,
latent_size=2 * action_size,
encoder_factory=self._config.imitator_encoder_factory,
device=self._device,
)
log_temp = create_parameter(
(1, 1),
math.log(self._config.initial_temperature),
@@ -202,8 +210,10 @@
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(), lr=self._config.critic_learning_rate
)
imitator_optim = self._config.imitator_optim_factory.create(
imitator.named_modules(), lr=self._config.imitator_learning_rate
vae_optim = self._config.imitator_optim_factory.create(
list(vae_encoder.named_modules())
+ list(vae_decoder.named_modules()),
lr=self._config.imitator_learning_rate,
)
temp_optim = self._config.temp_optim_factory.create(
log_temp.named_modules(), lr=self._config.temp_learning_rate
@@ -216,12 +226,13 @@
policy=policy,
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
imitator=imitator,
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
log_temp=log_temp,
log_alpha=log_alpha,
actor_optim=actor_optim,
critic_optim=critic_optim,
imitator_optim=imitator_optim,
vae_optim=vae_optim,
temp_optim=temp_optim,
alpha_optim=alpha_optim,
)
43 changes: 32 additions & 11 deletions d3rlpy/algos/qlearning/plas.py
@@ -3,10 +3,11 @@
from ...base import DeviceArg, LearnableConfig, register_learnable
from ...constants import ActionSpace
from ...models.builders import (
create_conditional_vae,
create_continuous_q_function,
create_deterministic_policy,
create_deterministic_residual_policy,
create_vae_decoder,
create_vae_encoder,
)
from ...models.encoders import EncoderFactory, make_encoder_field
from ...models.optimizers import OptimizerFactory, make_optimizer_field
@@ -135,7 +136,7 @@ def inner_create_impl(
n_ensembles=self._config.n_critics,
device=self._device,
)
imitator = create_conditional_vae(
vae_encoder = create_vae_encoder(
observation_shape=observation_shape,
action_size=action_size,
latent_size=2 * action_size,
@@ -144,26 +145,36 @@
encoder_factory=self._config.imitator_encoder_factory,
device=self._device,
)
vae_decoder = create_vae_decoder(
observation_shape=observation_shape,
action_size=action_size,
latent_size=2 * action_size,
encoder_factory=self._config.imitator_encoder_factory,
device=self._device,
)

actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(), lr=self._config.actor_learning_rate
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(), lr=self._config.critic_learning_rate
)
imitator_optim = self._config.critic_optim_factory.create(
imitator.named_modules(), lr=self._config.imitator_learning_rate
vae_optim = self._config.critic_optim_factory.create(
list(vae_encoder.named_modules())
+ list(vae_decoder.named_modules()),
lr=self._config.imitator_learning_rate,
)

modules = PLASModules(
policy=policy,
targ_policy=targ_policy,
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
imitator=imitator,
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
actor_optim=actor_optim,
critic_optim=critic_optim,
imitator_optim=imitator_optim,
vae_optim=vae_optim,
)

self._impl = PLASImpl(
@@ -272,7 +283,7 @@ def inner_create_impl(
n_ensembles=self._config.n_critics,
device=self._device,
)
imitator = create_conditional_vae(
vae_encoder = create_vae_encoder(
observation_shape=observation_shape,
action_size=action_size,
latent_size=2 * action_size,
@@ -281,6 +292,13 @@
encoder_factory=self._config.imitator_encoder_factory,
device=self._device,
)
vae_decoder = create_vae_decoder(
observation_shape=observation_shape,
action_size=action_size,
latent_size=2 * action_size,
encoder_factory=self._config.imitator_encoder_factory,
device=self._device,
)
perturbation = create_deterministic_residual_policy(
observation_shape=observation_shape,
action_size=action_size,
@@ -304,21 +322,24 @@
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(), lr=self._config.critic_learning_rate
)
imitator_optim = self._config.critic_optim_factory.create(
imitator.named_modules(), lr=self._config.imitator_learning_rate
vae_optim = self._config.critic_optim_factory.create(
list(vae_encoder.named_modules())
+ list(vae_decoder.named_modules()),
lr=self._config.imitator_learning_rate,
)

modules = PLASWithPerturbationModules(
policy=policy,
targ_policy=targ_policy,
q_funcs=q_funcs,
targ_q_funcs=targ_q_funcs,
imitator=imitator,
vae_encoder=vae_encoder,
vae_decoder=vae_decoder,
perturbation=perturbation,
targ_perturbation=targ_perturbation,
actor_optim=actor_optim,
critic_optim=critic_optim,
imitator_optim=imitator_optim,
vae_optim=vae_optim,
)

self._impl = PLASWithPerturbationImpl(
27 changes: 12 additions & 15 deletions d3rlpy/algos/qlearning/torch/bcq_impl.py
@@ -9,14 +9,14 @@
from ....models.torch import (
ActionOutput,
CategoricalPolicy,
ConditionalVAE,
ContinuousEnsembleQFunctionForwarder,
DeterministicResidualPolicy,
DiscreteEnsembleQFunctionForwarder,
VAEDecoder,
VAEEncoder,
compute_discrete_imitation_loss,
compute_max_with_n_actions,
compute_vae_error,
forward_vae_decode,
)
from ....torch_utility import (
TorchMiniBatch,
@@ -42,8 +42,9 @@
class BCQModules(DDPGBaseModules):
policy: DeterministicResidualPolicy
targ_policy: DeterministicResidualPolicy
imitator: ConditionalVAE
imitator_optim: Optimizer
vae_encoder: VAEEncoder
vae_decoder: VAEDecoder
vae_optim: Optimizer


class BCQImpl(DDPGBaseImpl):
@@ -95,16 +96,17 @@ def compute_actor_loss(
return DDPGBaseActorLoss(-value[0].mean())

def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
self._modules.imitator_optim.zero_grad()
self._modules.vae_optim.zero_grad()
loss = compute_vae_error(
vae=self._modules.imitator,
vae_encoder=self._modules.vae_encoder,
vae_decoder=self._modules.vae_decoder,
x=batch.observations,
action=batch.actions,
beta=self._beta,
)
loss.backward()
self._modules.imitator_optim.step()
return {"imitator_loss": float(loss.cpu().detach().numpy())}
self._modules.vae_optim.step()
return {"vae_loss": float(loss.cpu().detach().numpy())}

def _repeat_observation(self, x: TorchObservation) -> TorchObservation:
# (batch_size, *obs_shape) -> (batch_size, n, *obs_shape)
@@ -126,11 +128,7 @@ def _sample_repeated_action(
)
clipped_latent = latent.clamp(-0.5, 0.5)
# sample action
sampled_action = forward_vae_decode(
vae=self._modules.imitator,
x=flattened_x,
latent=clipped_latent,
)
sampled_action = self._modules.vae_decoder(flattened_x, clipped_latent)
# add residual action
policy = self._modules.targ_policy if target else self._modules.policy
action = policy(flattened_x, sampled_action)
@@ -196,8 +194,7 @@ def inner_update(
batch_size, 2 * self._action_size, device=self._device
)
clipped_latent = latent.clamp(-0.5, 0.5)
sampled_action = forward_vae_decode(
vae=self._modules.imitator,
sampled_action = self._modules.vae_decoder(
x=batch.observations,
latent=clipped_latent,
)
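In the BCQ implementation above, action sampling now calls the decoder module directly — sampled_action = self._modules.vae_decoder(flattened_x, clipped_latent) — instead of routing through forward_vae_decode. A rough sketch of that sampling step with a stand-in decoder (the tanh placeholder is hypothetical; only the 2 * action_size latent width and the [-0.5, 0.5] clamp come from the hunks above):

import torch

action_size = 2                              # hypothetical
batch_size = 32                              # hypothetical
observations = torch.randn(batch_size, 8)    # hypothetical observation batch

# Stand-in for d3rlpy's VAEDecoder: any callable mapping (x, latent) -> action.
def vae_decoder(x: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
    return torch.tanh(latent[:, :action_size])  # placeholder decode

latent = torch.randn(batch_size, 2 * action_size)
clipped_latent = latent.clamp(-0.5, 0.5)      # same clamp as in bcq_impl.py
sampled_action = vae_decoder(observations, clipped_latent)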
22 changes: 13 additions & 9 deletions d3rlpy/algos/qlearning/torch/bear_impl.py
@@ -6,9 +6,10 @@

from ....models.torch import (
ActionOutput,
ConditionalVAE,
ContinuousEnsembleQFunctionForwarder,
Parameter,
VAEDecoder,
VAEEncoder,
build_squashed_gaussian_distribution,
compute_max_with_n_actions_and_indices,
compute_vae_error,
@@ -43,9 +44,10 @@ def _laplacian_kernel(

@dataclasses.dataclass(frozen=True)
class BEARModules(SACModules):
imitator: ConditionalVAE
vae_encoder: VAEEncoder
vae_decoder: VAEDecoder
log_alpha: Parameter
imitator_optim: Optimizer
vae_optim: Optimizer
alpha_optim: Optional[Optimizer]


@@ -135,15 +137,16 @@ def _compute_mmd_loss(self, obs_t: TorchObservation) -> torch.Tensor:
return (alpha * (mmd - self._alpha_threshold)).mean()

def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
self._modules.imitator_optim.zero_grad()
self._modules.vae_optim.zero_grad()
loss = self.compute_imitator_loss(batch)
loss.backward()
self._modules.imitator_optim.step()
self._modules.vae_optim.step()
return {"imitator_loss": float(loss.cpu().detach().numpy())}

def compute_imitator_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
return compute_vae_error(
vae=self._modules.imitator,
vae_encoder=self._modules.vae_encoder,
vae_decoder=self._modules.vae_decoder,
x=batch.observations,
action=batch.actions,
beta=self._vae_kl_weight,
@@ -161,9 +164,10 @@ def update_alpha(self, mmd_loss: torch.Tensor) -> None:
def _compute_mmd(self, x: TorchObservation) -> torch.Tensor:
with torch.no_grad():
behavior_actions = forward_vae_sample_n(
self._modules.imitator,
x,
self._n_mmd_action_samples,
vae_decoder=self._modules.vae_decoder,
x=x,
latent_size=2 * self._action_size,
n=self._n_mmd_action_samples,
with_squash=False,
)
dist = build_squashed_gaussian_distribution(self._modules.policy(x))
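In the BEAR MMD computation above, forward_vae_sample_n now receives the decoder and the latent size explicitly. Conceptually it draws n latent samples per observation and decodes each into an action; the sketch below illustrates that idea with a stand-in decoder (it mirrors the keyword signature shown in the hunk, not d3rlpy's internal implementation, and omits the with_squash handling):

import torch

def sample_n_actions(vae_decoder, x: torch.Tensor, latent_size: int, n: int) -> torch.Tensor:
    # x: (batch, obs_dim) -> actions: (batch, n, action_dim)
    batch_size = x.shape[0]
    repeated_x = x.repeat_interleave(n, dim=0)  # repeat each observation n times
    latent = torch.randn(batch_size * n, latent_size, device=x.device)
    flat_actions = vae_decoder(repeated_x, latent)
    return flat_actions.view(batch_size, n, -1)

# Usage with a hypothetical stand-in decoder:
action_size = 2

def stand_in_decoder(x: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
    return torch.tanh(latent[:, :action_size])

actions = sample_n_actions(
    stand_in_decoder, torch.randn(32, 8), latent_size=2 * action_size, n=10
)  # -> shape (32, 10, 2)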