From ee65b1f47371efd0ae3e2aaebc2946cf2f9df56d Mon Sep 17 00:00:00 2001
From: takuseno
Date: Sat, 13 Jan 2024 14:22:33 +0900
Subject: [PATCH] Remove ConditionalVAE

---
 d3rlpy/algos/qlearning/bcq.py             | 23 ++++--
 d3rlpy/algos/qlearning/bear.py            | 23 ++++--
 d3rlpy/algos/qlearning/plas.py            | 43 ++++++++---
 d3rlpy/algos/qlearning/torch/bcq_impl.py  | 27 +++----
 d3rlpy/algos/qlearning/torch/bear_impl.py | 22 +++---
 d3rlpy/algos/qlearning/torch/plas_impl.py | 38 +++++-----
 d3rlpy/models/builders.py                 | 46 ++++++-----
 d3rlpy/models/torch/imitators.py          | 76 +++++--------------
 reproductions/offline/bear.py             |  2 +-
 reproductions/offline/plas.py             |  2 +-
 .../offline/plas_with_perturbation.py      |  2 +-
 tests/models/test_builders.py             | 41 ++++++++--
 tests/models/torch/test_imitators.py      | 43 +++--------
 13 files changed, 204 insertions(+), 184 deletions(-)

diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py
index 7cc50255..a66b3d43 100644
--- a/d3rlpy/algos/qlearning/bcq.py
+++ b/d3rlpy/algos/qlearning/bcq.py
@@ -4,10 +4,11 @@
 from ...constants import ActionSpace
 from ...models.builders import (
     create_categorical_policy,
-    create_conditional_vae,
     create_continuous_q_function,
     create_deterministic_residual_policy,
     create_discrete_q_function,
+    create_vae_decoder,
+    create_vae_encoder,
 )
 from ...models.encoders import EncoderFactory, make_encoder_field
 from ...models.optimizers import OptimizerFactory, make_optimizer_field
@@ -200,7 +201,7 @@ def inner_create_impl(
             n_ensembles=self._config.n_critics,
             device=self._device,
         )
-        imitator = create_conditional_vae(
+        vae_encoder = create_vae_encoder(
             observation_shape=observation_shape,
             action_size=action_size,
             latent_size=2 * action_size,
@@ -209,6 +210,13 @@
             encoder_factory=self._config.imitator_encoder_factory,
             device=self._device,
         )
+        vae_decoder = create_vae_decoder(
+            observation_shape=observation_shape,
+            action_size=action_size,
+            latent_size=2 * action_size,
+            encoder_factory=self._config.imitator_encoder_factory,
+            device=self._device,
+        )
 
         actor_optim = self._config.actor_optim_factory.create(
             policy.named_modules(), lr=self._config.actor_learning_rate
@@ -216,8 +224,10 @@
         critic_optim = self._config.critic_optim_factory.create(
             q_funcs.named_modules(), lr=self._config.critic_learning_rate
         )
-        imitator_optim = self._config.imitator_optim_factory.create(
-            imitator.named_modules(), lr=self._config.imitator_learning_rate
+        vae_optim = self._config.imitator_optim_factory.create(
+            list(vae_encoder.named_modules())
+            + list(vae_decoder.named_modules()),
+            lr=self._config.imitator_learning_rate,
         )
 
         modules = BCQModules(
@@ -225,10 +235,11 @@ def inner_create_impl(
             targ_policy=targ_policy,
             q_funcs=q_funcs,
             targ_q_funcs=targ_q_funcs,
-            imitator=imitator,
+            vae_encoder=vae_encoder,
+            vae_decoder=vae_decoder,
             actor_optim=actor_optim,
             critic_optim=critic_optim,
-            imitator_optim=imitator_optim,
+            vae_optim=vae_optim,
         )
 
         self._impl = BCQImpl(
diff --git a/d3rlpy/algos/qlearning/bear.py b/d3rlpy/algos/qlearning/bear.py
index edc17910..b02bf35c 100644
--- a/d3rlpy/algos/qlearning/bear.py
+++ b/d3rlpy/algos/qlearning/bear.py
@@ -4,10 +4,11 @@
 from ...base import DeviceArg, LearnableConfig, register_learnable
 from ...constants import ActionSpace
 from ...models.builders import (
-    create_conditional_vae,
     create_continuous_q_function,
     create_normal_policy,
     create_parameter,
+    create_vae_decoder,
+    create_vae_encoder,
 )
 from ...models.encoders import EncoderFactory, make_encoder_field
 from ...models.optimizers import OptimizerFactory, make_optimizer_field
@@ -178,7 +179,7 @@ def inner_create_impl(
             n_ensembles=self._config.n_critics,
             device=self._device,
         )
-        imitator = create_conditional_vae(
+        vae_encoder = create_vae_encoder(
             observation_shape=observation_shape,
             action_size=action_size,
             latent_size=2 * action_size,
@@ -187,6 +188,13 @@
             encoder_factory=self._config.imitator_encoder_factory,
             device=self._device,
         )
+        vae_decoder = create_vae_decoder(
+            observation_shape=observation_shape,
+            action_size=action_size,
+            latent_size=2 * action_size,
+            encoder_factory=self._config.imitator_encoder_factory,
+            device=self._device,
+        )
         log_temp = create_parameter(
             (1, 1),
             math.log(self._config.initial_temperature),
@@ -202,8 +210,10 @@
         critic_optim = self._config.critic_optim_factory.create(
             q_funcs.named_modules(), lr=self._config.critic_learning_rate
         )
-        imitator_optim = self._config.imitator_optim_factory.create(
-            imitator.named_modules(), lr=self._config.imitator_learning_rate
+        vae_optim = self._config.imitator_optim_factory.create(
+            list(vae_encoder.named_modules())
+            + list(vae_decoder.named_modules()),
+            lr=self._config.imitator_learning_rate,
         )
         temp_optim = self._config.temp_optim_factory.create(
             log_temp.named_modules(), lr=self._config.temp_learning_rate
@@ -216,12 +226,13 @@
             policy=policy,
             q_funcs=q_funcs,
             targ_q_funcs=targ_q_funcs,
-            imitator=imitator,
+            vae_encoder=vae_encoder,
+            vae_decoder=vae_decoder,
             log_temp=log_temp,
             log_alpha=log_alpha,
             actor_optim=actor_optim,
             critic_optim=critic_optim,
-            imitator_optim=imitator_optim,
+            vae_optim=vae_optim,
             temp_optim=temp_optim,
             alpha_optim=alpha_optim,
         )
diff --git a/d3rlpy/algos/qlearning/plas.py b/d3rlpy/algos/qlearning/plas.py
index dfad71d1..0db66d53 100644
--- a/d3rlpy/algos/qlearning/plas.py
+++ b/d3rlpy/algos/qlearning/plas.py
@@ -3,10 +3,11 @@
 from ...base import DeviceArg, LearnableConfig, register_learnable
 from ...constants import ActionSpace
 from ...models.builders import (
-    create_conditional_vae,
     create_continuous_q_function,
     create_deterministic_policy,
     create_deterministic_residual_policy,
+    create_vae_decoder,
+    create_vae_encoder,
 )
 from ...models.encoders import EncoderFactory, make_encoder_field
 from ...models.optimizers import OptimizerFactory, make_optimizer_field
@@ -135,7 +136,7 @@ def inner_create_impl(
             n_ensembles=self._config.n_critics,
             device=self._device,
         )
-        imitator = create_conditional_vae(
+        vae_encoder = create_vae_encoder(
             observation_shape=observation_shape,
             action_size=action_size,
             latent_size=2 * action_size,
@@ -144,6 +145,13 @@
             encoder_factory=self._config.imitator_encoder_factory,
             device=self._device,
         )
+        vae_decoder = create_vae_decoder(
+            observation_shape=observation_shape,
+            action_size=action_size,
+            latent_size=2 * action_size,
+            encoder_factory=self._config.imitator_encoder_factory,
+            device=self._device,
+        )
 
         actor_optim = self._config.actor_optim_factory.create(
             policy.named_modules(), lr=self._config.actor_learning_rate
@@ -151,8 +159,10 @@
         critic_optim = self._config.critic_optim_factory.create(
             q_funcs.named_modules(), lr=self._config.critic_learning_rate
         )
-        imitator_optim = self._config.critic_optim_factory.create(
-            imitator.named_modules(), lr=self._config.imitator_learning_rate
+        vae_optim = self._config.critic_optim_factory.create(
+            list(vae_encoder.named_modules())
+            + list(vae_decoder.named_modules()),
+            lr=self._config.imitator_learning_rate,
        )
 
         modules = PLASModules(
@@ -160,10 +170,11 @@ def inner_create_impl(
             targ_policy=targ_policy,
             q_funcs=q_funcs,
             targ_q_funcs=targ_q_funcs,
-            imitator=imitator,
+            vae_encoder=vae_encoder,
+            vae_decoder=vae_decoder,
             actor_optim=actor_optim,
             critic_optim=critic_optim,
-            imitator_optim=imitator_optim,
+            vae_optim=vae_optim,
         )
 
         self._impl = PLASImpl(
@@ -272,7 +283,7 @@ def inner_create_impl(
             n_ensembles=self._config.n_critics,
             device=self._device,
         )
-        imitator = create_conditional_vae(
+        vae_encoder = create_vae_encoder(
             observation_shape=observation_shape,
             action_size=action_size,
             latent_size=2 * action_size,
@@ -281,6 +292,13 @@
             encoder_factory=self._config.imitator_encoder_factory,
             device=self._device,
         )
+        vae_decoder = create_vae_decoder(
+            observation_shape=observation_shape,
+            action_size=action_size,
+            latent_size=2 * action_size,
+            encoder_factory=self._config.imitator_encoder_factory,
+            device=self._device,
+        )
         perturbation = create_deterministic_residual_policy(
             observation_shape=observation_shape,
             action_size=action_size,
@@ -304,8 +322,10 @@
         critic_optim = self._config.critic_optim_factory.create(
             q_funcs.named_modules(), lr=self._config.critic_learning_rate
         )
-        imitator_optim = self._config.critic_optim_factory.create(
-            imitator.named_modules(), lr=self._config.imitator_learning_rate
+        vae_optim = self._config.critic_optim_factory.create(
+            list(vae_encoder.named_modules())
+            + list(vae_decoder.named_modules()),
+            lr=self._config.imitator_learning_rate,
         )
 
         modules = PLASWithPerturbationModules(
@@ -313,12 +333,13 @@
             targ_policy=targ_policy,
             q_funcs=q_funcs,
             targ_q_funcs=targ_q_funcs,
-            imitator=imitator,
+            vae_encoder=vae_encoder,
+            vae_decoder=vae_decoder,
             perturbation=perturbation,
             targ_perturbation=targ_perturbation,
             actor_optim=actor_optim,
             critic_optim=critic_optim,
-            imitator_optim=imitator_optim,
+            vae_optim=vae_optim,
         )
 
         self._impl = PLASWithPerturbationImpl(
diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py
index 5b519e1d..a16074a6 100644
--- a/d3rlpy/algos/qlearning/torch/bcq_impl.py
+++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py
@@ -9,14 +9,14 @@
 from ....models.torch import (
     ActionOutput,
     CategoricalPolicy,
-    ConditionalVAE,
     ContinuousEnsembleQFunctionForwarder,
     DeterministicResidualPolicy,
     DiscreteEnsembleQFunctionForwarder,
+    VAEDecoder,
+    VAEEncoder,
     compute_discrete_imitation_loss,
     compute_max_with_n_actions,
     compute_vae_error,
-    forward_vae_decode,
 )
 from ....torch_utility import (
     TorchMiniBatch,
@@ -42,8 +42,9 @@
 class BCQModules(DDPGBaseModules):
     policy: DeterministicResidualPolicy
     targ_policy: DeterministicResidualPolicy
-    imitator: ConditionalVAE
-    imitator_optim: Optimizer
+    vae_encoder: VAEEncoder
+    vae_decoder: VAEDecoder
+    vae_optim: Optimizer
 
 
 class BCQImpl(DDPGBaseImpl):
@@ -95,16 +96,17 @@ def compute_actor_loss(
         return DDPGBaseActorLoss(-value[0].mean())
 
     def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
-        self._modules.imitator_optim.zero_grad()
+        self._modules.vae_optim.zero_grad()
         loss = compute_vae_error(
-            vae=self._modules.imitator,
+            vae_encoder=self._modules.vae_encoder,
+            vae_decoder=self._modules.vae_decoder,
             x=batch.observations,
             action=batch.actions,
             beta=self._beta,
         )
         loss.backward()
-        self._modules.imitator_optim.step()
-        return {"imitator_loss": float(loss.cpu().detach().numpy())}
+        self._modules.vae_optim.step()
+        return {"vae_loss": float(loss.cpu().detach().numpy())}
 
     def _repeat_observation(self, x: TorchObservation) -> TorchObservation:
         # (batch_size, *obs_shape) -> (batch_size, n, *obs_shape)
@@ -126,11 +128,7 @@ def _sample_repeated_action(
         )
         clipped_latent = latent.clamp(-0.5, 0.5)
         # sample action
-        sampled_action = forward_vae_decode(
-            vae=self._modules.imitator,
-            x=flattened_x,
-            latent=clipped_latent,
-        )
+        sampled_action = self._modules.vae_decoder(flattened_x, clipped_latent)
         # add residual action
         policy = self._modules.targ_policy if target else self._modules.policy
         action = policy(flattened_x, sampled_action)
@@ -196,8 +194,7 @@ def inner_update(
             batch_size, 2 * self._action_size, device=self._device
         )
         clipped_latent = latent.clamp(-0.5, 0.5)
-        sampled_action = forward_vae_decode(
-            vae=self._modules.imitator,
+        sampled_action = self._modules.vae_decoder(
             x=batch.observations,
             latent=clipped_latent,
         )
diff --git a/d3rlpy/algos/qlearning/torch/bear_impl.py b/d3rlpy/algos/qlearning/torch/bear_impl.py
index f0d1caaf..7332b6fc 100644
--- a/d3rlpy/algos/qlearning/torch/bear_impl.py
+++ b/d3rlpy/algos/qlearning/torch/bear_impl.py
@@ -6,9 +6,10 @@
 
 from ....models.torch import (
     ActionOutput,
-    ConditionalVAE,
     ContinuousEnsembleQFunctionForwarder,
     Parameter,
+    VAEDecoder,
+    VAEEncoder,
     build_squashed_gaussian_distribution,
     compute_max_with_n_actions_and_indices,
     compute_vae_error,
@@ -43,9 +44,10 @@ def _laplacian_kernel(
 
 
 @dataclasses.dataclass(frozen=True)
 class BEARModules(SACModules):
-    imitator: ConditionalVAE
+    vae_encoder: VAEEncoder
+    vae_decoder: VAEDecoder
     log_alpha: Parameter
-    imitator_optim: Optimizer
+    vae_optim: Optimizer
     alpha_optim: Optional[Optimizer]
 
@@ -135,15 +137,16 @@ def _compute_mmd_loss(self, obs_t: TorchObservation) -> torch.Tensor:
         return (alpha * (mmd - self._alpha_threshold)).mean()
 
     def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
-        self._modules.imitator_optim.zero_grad()
+        self._modules.vae_optim.zero_grad()
         loss = self.compute_imitator_loss(batch)
         loss.backward()
-        self._modules.imitator_optim.step()
+        self._modules.vae_optim.step()
         return {"imitator_loss": float(loss.cpu().detach().numpy())}
 
     def compute_imitator_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
         return compute_vae_error(
-            vae=self._modules.imitator,
+            vae_encoder=self._modules.vae_encoder,
+            vae_decoder=self._modules.vae_decoder,
             x=batch.observations,
             action=batch.actions,
             beta=self._vae_kl_weight,
@@ -161,9 +164,10 @@ def update_alpha(self, mmd_loss: torch.Tensor) -> None:
     def _compute_mmd(self, x: TorchObservation) -> torch.Tensor:
         with torch.no_grad():
             behavior_actions = forward_vae_sample_n(
-                self._modules.imitator,
-                x,
-                self._n_mmd_action_samples,
+                vae_decoder=self._modules.vae_decoder,
+                x=x,
+                latent_size=2 * self._action_size,
+                n=self._n_mmd_action_samples,
                 with_squash=False,
             )
         dist = build_squashed_gaussian_distribution(self._modules.policy(x))
diff --git a/d3rlpy/algos/qlearning/torch/plas_impl.py b/d3rlpy/algos/qlearning/torch/plas_impl.py
index 74abf45a..8d61f08c 100644
--- a/d3rlpy/algos/qlearning/torch/plas_impl.py
+++ b/d3rlpy/algos/qlearning/torch/plas_impl.py
@@ -6,12 +6,12 @@
 
 from ....models.torch import (
     ActionOutput,
-    ConditionalVAE,
     ContinuousEnsembleQFunctionForwarder,
     DeterministicPolicy,
     DeterministicResidualPolicy,
+    VAEDecoder,
+    VAEEncoder,
     compute_vae_error,
-    forward_vae_decode,
 )
 from ....torch_utility import TorchMiniBatch, soft_sync
 from ....types import Shape, TorchObservation
@@ -29,8 +29,9 @@
 class PLASModules(DDPGBaseModules):
     policy: DeterministicPolicy
     targ_policy: DeterministicPolicy
-    imitator: ConditionalVAE
-    imitator_optim: Optimizer
+    vae_encoder: VAEEncoder
+    vae_decoder: VAEDecoder
+    vae_optim: Optimizer
 
 
 class PLASImpl(DDPGBaseImpl):
@@ -68,24 +69,23 @@ def __init__(
         self._warmup_steps = warmup_steps
 
     def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
-        self._modules.imitator_optim.zero_grad()
+        self._modules.vae_optim.zero_grad()
         loss = compute_vae_error(
-            vae=self._modules.imitator,
+            vae_encoder=self._modules.vae_encoder,
+            vae_decoder=self._modules.vae_decoder,
             x=batch.observations,
             action=batch.actions,
             beta=self._beta,
         )
         loss.backward()
-        self._modules.imitator_optim.step()
-        return {"imitator_loss": float(loss.cpu().detach().numpy())}
+        self._modules.vae_optim.step()
+        return {"vae_loss": float(loss.cpu().detach().numpy())}
 
     def compute_actor_loss(
         self, batch: TorchMiniBatch, action: ActionOutput
     ) -> DDPGBaseActorLoss:
         latent_actions = 2.0 * action.squashed_mu
-        actions = forward_vae_decode(
-            self._modules.imitator, batch.observations, latent_actions
-        )
+        actions = self._modules.vae_decoder(batch.observations, latent_actions)
         loss = -self._q_func_forwarder.compute_expected_q(
             batch.observations, actions, "none"
         )[0].mean()
@@ -93,7 +93,7 @@ def compute_actor_loss(
 
     def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor:
         latent_actions = 2.0 * self._modules.policy(x).squashed_mu
-        return forward_vae_decode(self._modules.imitator, x, latent_actions)
+        return self._modules.vae_decoder(x, latent_actions)
 
     def inner_sample_action(self, x: TorchObservation) -> torch.Tensor:
         return self.inner_predict_best_action(x)
@@ -104,8 +104,8 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
                 2.0
                 * self._modules.targ_policy(batch.next_observations).squashed_mu
             )
-            actions = forward_vae_decode(
-                self._modules.imitator, batch.next_observations, latent_actions
+            actions = self._modules.vae_decoder(
+                batch.next_observations, latent_actions
             )
             return self._targ_q_func_forwarder.compute_target(
                 batch.next_observations,
@@ -175,9 +175,7 @@ def compute_actor_loss(
         self, batch: TorchMiniBatch, action: ActionOutput
     ) -> DDPGBaseActorLoss:
         latent_actions = 2.0 * action.squashed_mu
-        actions = forward_vae_decode(
-            self._modules.imitator, batch.observations, latent_actions
-        )
+        actions = self._modules.vae_decoder(batch.observations, latent_actions)
         residual_actions = self._modules.perturbation(
             batch.observations, actions
         ).squashed_mu
@@ -188,7 +186,7 @@ def compute_actor_loss(
 
     def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor:
         latent_actions = 2.0 * self._modules.policy(x).squashed_mu
-        actions = forward_vae_decode(self._modules.imitator, x, latent_actions)
+        actions = self._modules.vae_decoder(x, latent_actions)
         return self._modules.perturbation(x, actions).squashed_mu
 
     def inner_sample_action(self, x: TorchObservation) -> torch.Tensor:
@@ -200,8 +198,8 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor:
                 2.0
                 * self._modules.targ_policy(batch.next_observations).squashed_mu
             )
-            actions = forward_vae_decode(
-                self._modules.imitator, batch.next_observations, latent_actions
+            actions = self._modules.vae_decoder(
+                batch.next_observations, latent_actions
             )
             residual_actions = self._modules.targ_perturbation(
                 batch.next_observations, actions
diff --git a/d3rlpy/models/builders.py b/d3rlpy/models/builders.py
index e6ec8d4a..82b34c79 100644
--- a/d3rlpy/models/builders.py
+++ b/d3rlpy/models/builders.py
@@ -9,7 +9,6 @@
 from .q_functions import QFunctionFactory
 from .torch import (
     CategoricalPolicy,
-    ConditionalVAE,
     ContinuousDecisionTransformer,
     ContinuousEnsembleQFunctionForwarder,
     DeterministicPolicy,
@@ -35,7 +34,8 @@
     "create_deterministic_residual_policy",
     "create_categorical_policy",
     "create_normal_policy",
-    "create_conditional_vae",
+    "create_vae_encoder",
+    "create_vae_decoder",
     "create_value_function",
     "create_parameter",
     "create_continuous_decision_transformer",
@@ -195,7 +195,7 @@ def create_categorical_policy(
     return policy
 
 
-def create_conditional_vae(
+def create_vae_encoder(
     observation_shape: Shape,
     action_size: int,
     latent_size: int,
@@ -203,34 +203,40 @@ def create_conditional_vae(
     device: str,
     min_logstd: float = -20.0,
     max_logstd: float = 2.0,
-) -> ConditionalVAE:
-    encoder_encoder = encoder_factory.create_with_action(
-        observation_shape, action_size
-    )
-    decoder_encoder = encoder_factory.create_with_action(
-        observation_shape, latent_size
-    )
+) -> VAEEncoder:
+    encoder = encoder_factory.create_with_action(observation_shape, action_size)
     encoder_hidden_size = compute_output_size(
-        [observation_shape, (action_size,)], encoder_encoder
-    )
-    decoder_hidden_size = compute_output_size(
-        [observation_shape, (latent_size,)], decoder_encoder
+        [observation_shape, (action_size,)], encoder
     )
-    encoder = VAEEncoder(
-        encoder=encoder_encoder,
+    vae_encoder = VAEEncoder(
+        encoder=encoder,
         hidden_size=encoder_hidden_size,
         latent_size=latent_size,
         min_logstd=min_logstd,
         max_logstd=max_logstd,
     )
+    vae_encoder.to(device)
+    return vae_encoder
+
+
+def create_vae_decoder(
+    observation_shape: Shape,
+    action_size: int,
+    latent_size: int,
+    encoder_factory: EncoderFactory,
+    device: str,
+) -> VAEDecoder:
+    encoder = encoder_factory.create_with_action(observation_shape, latent_size)
+    decoder_hidden_size = compute_output_size(
+        [observation_shape, (latent_size,)], encoder
+    )
     decoder = VAEDecoder(
-        encoder=decoder_encoder,
+        encoder=encoder,
         hidden_size=decoder_hidden_size,
         action_size=action_size,
     )
-    policy = ConditionalVAE(encoder=encoder, decoder=decoder)
-    policy.to(device)
-    return policy
+    decoder.to(device)
+    return decoder
 
 
 def create_value_function(
diff --git a/d3rlpy/models/torch/imitators.py b/d3rlpy/models/torch/imitators.py
index 65479ab5..0e904d3f 100644
--- a/d3rlpy/models/torch/imitators.py
+++ b/d3rlpy/models/torch/imitators.py
@@ -19,9 +19,6 @@
 __all__ = [
     "VAEEncoder",
     "VAEDecoder",
-    "ConditionalVAE",
-    "forward_vae_encode",
-    "forward_vae_decode",
     "forward_vae_sample",
     "forward_vae_sample_n",
     "compute_vae_error",
@@ -104,64 +101,27 @@ def action_size(self) -> int:
         return self._action_size
 
 
-class ConditionalVAE(nn.Module):  # type: ignore
-    _encoder: VAEEncoder
-    _decoder: VAEDecoder
-    _beta: float
-
-    def __init__(self, encoder: VAEEncoder, decoder: VAEDecoder):
-        super().__init__()
-        self._encoder = encoder
-        self._decoder = decoder
-
-    def forward(
-        self, x: TorchObservation, action: torch.Tensor
-    ) -> torch.Tensor:
-        dist = self._encoder(x, action)
-        return self._decoder(x, dist.rsample())
-
-    def __call__(
-        self, x: TorchObservation, action: torch.Tensor
-    ) -> torch.Tensor:
-        return cast(torch.Tensor, super().__call__(x, action))
-
-    @property
-    def encoder(self) -> VAEEncoder:
-        return self._encoder
-
-    @property
-    def decoder(self) -> VAEDecoder:
-        return self._decoder
-
-
-def forward_vae_encode(
-    vae: ConditionalVAE, x: TorchObservation, action: torch.Tensor
-) -> Normal:
-    return vae.encoder(x, action)
-
-
-def forward_vae_decode(
-    vae: ConditionalVAE, x: TorchObservation, latent: torch.Tensor
-) -> torch.Tensor:
-    return vae.decoder(x, latent)
-
-
 def forward_vae_sample(
-    vae: ConditionalVAE, x: TorchObservation, with_squash: bool = True
+    vae_decoder: VAEDecoder,
+    x: TorchObservation,
+    latent_size: int,
+    with_squash: bool = True,
 ) -> torch.Tensor:
     batch_size = get_batch_size(x)
-    latent = torch.randn(
-        (batch_size, vae.encoder.latent_size), device=get_device(x)
-    )
+    latent = torch.randn((batch_size, latent_size), device=get_device(x))
     # to prevent extreme numbers
-    return vae.decoder(x, latent.clamp(-0.5, 0.5), with_squash=with_squash)
+    return vae_decoder(x, latent.clamp(-0.5, 0.5), with_squash=with_squash)
 
 
 def forward_vae_sample_n(
-    vae: ConditionalVAE, x: TorchObservation, n: int, with_squash: bool = True
+    vae_decoder: VAEDecoder,
+    x: TorchObservation,
+    latent_size: int,
+    n: int,
+    with_squash: bool = True,
 ) -> torch.Tensor:
     batch_size = get_batch_size(x)
-    flat_latent_shape = (n * batch_size, vae.encoder.latent_size)
+    flat_latent_shape = (n * batch_size, latent_size)
     flat_latent = torch.randn(flat_latent_shape, device=get_device(x))
     # to prevent extreme numbers
     clipped_latent = flat_latent.clamp(-0.5, 0.5)
@@ -177,7 +137,7 @@ def forward_vae_sample_n(
     # (n, batch, obs) -> (n * batch, obs)
     flat_x = [_x.reshape(-1, *_x.shape[2:]) for _x in repeated_x]
 
-    flat_actions = vae.decoder(flat_x, clipped_latent, with_squash=with_squash)
+    flat_actions = vae_decoder(flat_x, clipped_latent, with_squash=with_squash)
 
     # (n * batch, action) -> (n, batch, action)
     actions = flat_actions.view(n, batch_size, -1)
@@ -187,11 +147,15 @@
 
 
 def compute_vae_error(
-    vae: ConditionalVAE, x: TorchObservation, action: torch.Tensor, beta: float
+    vae_encoder: VAEEncoder,
+    vae_decoder: VAEDecoder,
+    x: TorchObservation,
+    action: torch.Tensor,
+    beta: float,
 ) -> torch.Tensor:
-    dist = vae.encoder(x, action)
+    dist = vae_encoder(x, action)
     kl_loss = kl_divergence(dist, Normal(0.0, 1.0)).mean()
-    y = vae.decoder(x, dist.rsample())
+    y = vae_decoder(x, dist.rsample())
     return F.mse_loss(y, action) + cast(torch.Tensor, beta * kl_loss)
 
 
diff --git a/reproductions/offline/bear.py b/reproductions/offline/bear.py
index cbeaf578..184ca194 100644
--- a/reproductions/offline/bear.py
+++ b/reproductions/offline/bear.py
@@ -18,7 +18,7 @@ def main() -> None:
 
     vae_encoder = d3rlpy.models.encoders.VectorEncoderFactory([750, 750])
 
-    if "halfcheetah" in env.unwrapped.spec.id.lower():
+    if "halfcheetah" in args.dataset:
         kernel = "gaussian"
     else:
         kernel = "laplacian"
diff --git a/reproductions/offline/plas.py b/reproductions/offline/plas.py
index 6dd394c2..a1d55ede 100644
--- a/reproductions/offline/plas.py
+++ b/reproductions/offline/plas.py
@@ -16,7 +16,7 @@ def main() -> None:
     d3rlpy.seed(args.seed)
     d3rlpy.envs.seed_env(env, args.seed)
 
-    if "medium-replay" in env.unwrapped.spec.id.lower():
+    if "medium-replay" in args.dataset:
         vae_encoder = d3rlpy.models.encoders.VectorEncoderFactory([128, 128])
     else:
         vae_encoder = d3rlpy.models.encoders.VectorEncoderFactory([750, 750])
diff --git a/reproductions/offline/plas_with_perturbation.py b/reproductions/offline/plas_with_perturbation.py
index 15e8a28f..cc15e9eb 100644
--- a/reproductions/offline/plas_with_perturbation.py
+++ b/reproductions/offline/plas_with_perturbation.py
@@ -31,7 +31,7 @@ def main() -> None:
     d3rlpy.seed(args.seed)
     d3rlpy.envs.seed_env(env, args.seed)
 
-    if "medium-replay" in env.unwrapped.spec.id.lower():
+    if "medium-replay" in args.dataset:
         vae_encoder = d3rlpy.models.encoders.VectorEncoderFactory([128, 128])
     else:
         vae_encoder = d3rlpy.models.encoders.VectorEncoderFactory([750, 750])
diff --git a/tests/models/test_builders.py b/tests/models/test_builders.py
index 226f0e6e..3c1855eb 100644
--- a/tests/models/test_builders.py
+++ b/tests/models/test_builders.py
@@ -7,7 +7,6 @@
 from d3rlpy.constants import PositionEncodingType
 from d3rlpy.models.builders import (
     create_categorical_policy,
-    create_conditional_vae,
     create_continuous_decision_transformer,
     create_continuous_q_function,
     create_deterministic_policy,
@@ -16,6 +15,8 @@
     create_discrete_q_function,
     create_normal_policy,
     create_parameter,
+    create_vae_decoder,
+    create_vae_encoder,
     create_value_function,
 )
 from d3rlpy.models.encoders import DefaultEncoderFactory, EncoderFactory
@@ -25,7 +26,7 @@
     DiscreteEnsembleQFunctionForwarder,
     get_parameter,
 )
-from d3rlpy.models.torch.imitators import ConditionalVAE
+from d3rlpy.models.torch.imitators import VAEDecoder, VAEEncoder
 from d3rlpy.models.torch.policies import (
     CategoricalPolicy,
     DeterministicPolicy,
@@ -212,14 +213,14 @@ def test_create_continuous_q_function(
 @pytest.mark.parametrize("latent_size", [32])
 @pytest.mark.parametrize("batch_size", [32])
 @pytest.mark.parametrize("encoder_factory", [DefaultEncoderFactory()])
-def test_create_conditional_vae(
+def test_create_vae_encoder(
     observation_shape: Sequence[int],
     action_size: int,
     latent_size: int,
     batch_size: int,
    encoder_factory: EncoderFactory,
 ) -> None:
-    vae = create_conditional_vae(
+    vae_encoder = create_vae_encoder(
         observation_shape,
         action_size,
         latent_size,
@@ -227,11 +228,39 @@ def test_create_conditional_vae(
         device="cpu:0",
     )
 
-    assert isinstance(vae, ConditionalVAE)
+    assert isinstance(vae_encoder, VAEEncoder)
 
     x = torch.rand((batch_size, *observation_shape))
     action = torch.rand(batch_size, action_size)
-    y = vae(x, action)
+    dist = vae_encoder(x, action)
+    assert dist.mean.shape == (batch_size, latent_size)
+
+
+@pytest.mark.parametrize("observation_shape", [(4, 84, 84), (100,)])
+@pytest.mark.parametrize("action_size", [2])
+@pytest.mark.parametrize("latent_size", [32])
+@pytest.mark.parametrize("batch_size", [32])
+@pytest.mark.parametrize("encoder_factory", [DefaultEncoderFactory()])
+def test_create_vae_decoder(
+    observation_shape: Sequence[int],
+    action_size: int,
+    latent_size: int,
+    batch_size: int,
+    encoder_factory: EncoderFactory,
+) -> None:
+    vae_decoder = create_vae_decoder(
+        observation_shape,
+        action_size,
+        latent_size,
+        encoder_factory,
+        device="cpu:0",
+    )
+
+    assert isinstance(vae_decoder, VAEDecoder)
+
+    x = torch.rand((batch_size, *observation_shape))
+    latent = torch.rand(batch_size, latent_size)
+    y = vae_decoder(x, latent)
     assert y.shape == (batch_size, action_size)
 
 
diff --git a/tests/models/torch/test_imitators.py b/tests/models/torch/test_imitators.py
index 38b52288..6c1f33b9 100644
--- a/tests/models/torch/test_imitators.py
+++ b/tests/models/torch/test_imitators.py
@@ -2,15 +2,12 @@
 import torch
 
 from d3rlpy.models.torch.imitators import (
-    ConditionalVAE,
     VAEDecoder,
     VAEEncoder,
     compute_deterministic_imitation_loss,
     compute_discrete_imitation_loss,
     compute_stochastic_imitation_loss,
     compute_vae_error,
-    forward_vae_decode,
-    forward_vae_encode,
     forward_vae_sample,
     forward_vae_sample_n,
 )
@@ -57,11 +54,13 @@ def test_vae_encoder(
 @pytest.mark.parametrize("action_size", [2])
 @pytest.mark.parametrize("latent_size", [32])
 @pytest.mark.parametrize("batch_size", [32])
+@pytest.mark.parametrize("n", [100])
 def test_vae_decoder(
     observation_shape: Shape,
     action_size: int,
     latent_size: int,
     batch_size: int,
+    n: int,
 ) -> None:
     encoder = DummyEncoderWithAction(observation_shape, latent_size)
     vae_decoder = VAEDecoder(
@@ -76,6 +75,14 @@
     action = vae_decoder(x, latent)
     assert action.shape == (batch_size, action_size)
 
+    # check forward_vae_sample
+    y = forward_vae_sample(vae_decoder, x, latent_size)
+    assert y.shape == (batch_size, action_size)
+
+    # check forward_vae_sample_n
+    y = forward_vae_sample_n(vae_decoder, x, latent_size, n)
+    assert y.shape == (batch_size, n, action_size)
+
     # check layer connections
     check_parameter_updates(vae_decoder, (x, latent))
 
@@ -84,14 +91,12 @@
 @pytest.mark.parametrize("action_size", [2])
 @pytest.mark.parametrize("latent_size", [32])
 @pytest.mark.parametrize("batch_size", [32])
-@pytest.mark.parametrize("n", [100])
 @pytest.mark.parametrize("beta", [0.5])
 def test_conditional_vae(
     observation_shape: Shape,
     action_size: int,
     latent_size: int,
     batch_size: int,
-    n: int,
     beta: float,
 ) -> None:
     encoder_encoder = DummyEncoderWithAction(observation_shape, action_size)
@@ -106,40 +111,14 @@
         hidden_size=decoder_encoder.get_feature_size(),
         action_size=action_size,
     )
-    vae = ConditionalVAE(
-        encoder=vae_encoder,
-        decoder=vae_decoder,
-    )
-
     # check output shape
     x = create_torch_observations(observation_shape, batch_size)
     action = torch.rand(batch_size, action_size)
-    y = vae(x, action)
-    assert y.shape == (batch_size, action_size)
-
-    # test encode
-    dist = forward_vae_encode(vae, x, action)
-    assert dist.mean.shape == (batch_size, latent_size)
-
-    # test decode
-    y = forward_vae_decode(vae, x, dist.sample())
-    assert y.shape == (batch_size, action_size)
-
-    # test decode sample
-    y = forward_vae_sample(vae, x)
-    assert y.shape == (batch_size, action_size)
-
-    # test decode sample n
-    y = forward_vae_sample_n(vae, x, n)
-    assert y.shape == (batch_size, n, action_size)
 
     # test compute error
-    error = compute_vae_error(vae, x, action, beta)
+    error = compute_vae_error(vae_encoder, vae_decoder, x, action, beta)
     assert error.ndim == 0
-
-    # check layer connections
-    check_parameter_updates(vae, (x, action))
 
 
 @pytest.mark.parametrize("observation_shape", [(100,), ((100,), (200,))])
 @pytest.mark.parametrize("action_size", [2])