diff --git a/d3rlpy/algos/qlearning/awac.py b/d3rlpy/algos/qlearning/awac.py index 7e1ae56c..8f8a899c 100644 --- a/d3rlpy/algos/qlearning/awac.py +++ b/d3rlpy/algos/qlearning/awac.py @@ -86,8 +86,10 @@ class AWACConfig(LearnableConfig): n_action_samples: int = 1 n_critics: int = 2 - def create(self, device: DeviceArg = False) -> "AWAC": - return AWAC(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "AWAC": + return AWAC(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -106,6 +108,7 @@ def inner_create_impl( max_logstd=0.0, use_std_parameter=True, device=self._device, + enable_ddp=self._enable_ddp, ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, @@ -114,6 +117,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, @@ -122,6 +126,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) actor_optim = self._config.actor_optim_factory.create( diff --git a/d3rlpy/algos/qlearning/base.py b/d3rlpy/algos/qlearning/base.py index 42529576..51e38161 100644 --- a/d3rlpy/algos/qlearning/base.py +++ b/d3rlpy/algos/qlearning/base.py @@ -385,7 +385,6 @@ def fit( evaluators: Optional[Dict[str, EvaluatorProtocol]] = None, callback: Optional[Callable[[Self, int, int], None]] = None, epoch_callback: Optional[Callable[[Self, int, int], None]] = None, - enable_ddp: bool = False, ) -> List[Tuple[int, Dict[str, float]]]: """Trains with given dataset. @@ -414,7 +413,6 @@ def fit( epoch_callback: Callable function that takes ``(algo, epoch, total_step)``, which is called at the end of every epoch. - enable_ddp: Flag to wrap models with DataDistributedParallel. Returns: List of result tuples (epoch, metrics) per epoch. @@ -434,7 +432,6 @@ def fit( evaluators=evaluators, callback=callback, epoch_callback=epoch_callback, - enable_ddp=enable_ddp, ) ) return results @@ -454,7 +451,6 @@ def fitter( evaluators: Optional[Dict[str, EvaluatorProtocol]] = None, callback: Optional[Callable[[Self, int, int], None]] = None, epoch_callback: Optional[Callable[[Self, int, int], None]] = None, - enable_ddp: bool = False, ) -> Generator[Tuple[int, Dict[str, float]], None, None]: """Iterate over epochs steps to train with the given dataset. At each iteration algo methods and properties can be changed or queried. @@ -486,7 +482,6 @@ def fitter( epoch_callback: Callable function that takes ``(algo, epoch, total_step)``, which is called at the end of every epoch. - enable_ddp: Flag to wrap models with DataDistributedParallel. Returns: Iterator yielding current epoch and metrics dict. 
@@ -522,11 +517,6 @@ def fitter( else: LOG.warning("Skip building models since they're already built.") - # wrap all PyTorch modules with DataDistributedParallel - if enable_ddp: - assert self._impl - self._impl.wrap_models_by_ddp() - # save hyperparameters save_config(self, logger) diff --git a/d3rlpy/algos/qlearning/bc.py b/d3rlpy/algos/qlearning/bc.py index c614e80d..fd34adbc 100644 --- a/d3rlpy/algos/qlearning/bc.py +++ b/d3rlpy/algos/qlearning/bc.py @@ -57,8 +57,10 @@ class BCConfig(LearnableConfig): optim_factory: OptimizerFactory = make_optimizer_field() encoder_factory: EncoderFactory = make_encoder_field() - def create(self, device: DeviceArg = False) -> "BC": - return BC(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "BC": + return BC(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -75,6 +77,7 @@ def inner_create_impl( action_size, self._config.encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) elif self._config.policy_type == "stochastic": imitator = create_normal_policy( @@ -84,6 +87,7 @@ def inner_create_impl( min_logstd=-4.0, max_logstd=15.0, device=self._device, + enable_ddp=self._enable_ddp, ) else: raise ValueError(f"invalid policy_type: {self._config.policy_type}") @@ -141,8 +145,10 @@ class DiscreteBCConfig(LearnableConfig): encoder_factory: EncoderFactory = make_encoder_field() beta: float = 0.5 - def create(self, device: DeviceArg = False) -> "DiscreteBC": - return DiscreteBC(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "DiscreteBC": + return DiscreteBC(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -158,6 +164,7 @@ def inner_create_impl( action_size, self._config.encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) optim = self._config.optim_factory.create( diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py index d905f7a0..a23ab2b5 100644 --- a/d3rlpy/algos/qlearning/bcq.py +++ b/d3rlpy/algos/qlearning/bcq.py @@ -160,8 +160,10 @@ class BCQConfig(LearnableConfig): rl_start_step: int = 0 beta: float = 0.5 - def create(self, device: DeviceArg = False) -> "BCQ": - return BCQ(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "BCQ": + return BCQ(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -178,6 +180,7 @@ def inner_create_impl( self._config.action_flexibility, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) targ_policy = create_deterministic_residual_policy( observation_shape, @@ -185,6 +188,7 @@ def inner_create_impl( self._config.action_flexibility, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, @@ -193,6 +197,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, @@ -201,6 +206,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) vae_encoder = create_vae_encoder( observation_shape=observation_shape, @@ -210,6 +216,7 @@ def inner_create_impl( max_logstd=15.0, encoder_factory=self._config.imitator_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) vae_decoder = 
create_vae_decoder( observation_shape=observation_shape, @@ -217,6 +224,7 @@ def inner_create_impl( latent_size=2 * action_size, encoder_factory=self._config.imitator_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) actor_optim = self._config.actor_optim_factory.create( @@ -337,8 +345,10 @@ class DiscreteBCQConfig(LearnableConfig): target_update_interval: int = 8000 share_encoder: bool = True - def create(self, device: DeviceArg = False) -> "DiscreteBCQ": - return DiscreteBCQ(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "DiscreteBCQ": + return DiscreteBCQ(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -356,6 +366,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_discrete_q_function( observation_shape, @@ -364,6 +375,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) # share convolutional layers if observation is pixel @@ -384,6 +396,7 @@ def inner_create_impl( action_size, self._config.encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) q_func_params = list(q_funcs.named_modules()) diff --git a/d3rlpy/algos/qlearning/bear.py b/d3rlpy/algos/qlearning/bear.py index d6dfc7be..b741d792 100644 --- a/d3rlpy/algos/qlearning/bear.py +++ b/d3rlpy/algos/qlearning/bear.py @@ -146,8 +146,10 @@ class BEARConfig(LearnableConfig): vae_kl_weight: float = 0.5 warmup_steps: int = 40000 - def create(self, device: DeviceArg = False) -> "BEAR": - return BEAR(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "BEAR": + return BEAR(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -163,6 +165,7 @@ def inner_create_impl( action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, @@ -171,6 +174,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, @@ -179,6 +183,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) vae_encoder = create_vae_encoder( observation_shape=observation_shape, @@ -188,6 +193,7 @@ def inner_create_impl( max_logstd=15.0, encoder_factory=self._config.imitator_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) vae_decoder = create_vae_decoder( observation_shape=observation_shape, @@ -195,14 +201,19 @@ def inner_create_impl( latent_size=2 * action_size, encoder_factory=self._config.imitator_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) log_temp = create_parameter( (1, 1), math.log(self._config.initial_temperature), device=self._device, + enable_ddp=self._enable_ddp, ) log_alpha = create_parameter( - (1, 1), math.log(self._config.initial_alpha), device=self._device + (1, 1), + math.log(self._config.initial_alpha), + device=self._device, + enable_ddp=self._enable_ddp, ) actor_optim = self._config.actor_optim_factory.create( diff --git a/d3rlpy/algos/qlearning/cal_ql.py b/d3rlpy/algos/qlearning/cal_ql.py index d2a636dd..181d4112 100644 --- 
a/d3rlpy/algos/qlearning/cal_ql.py +++ b/d3rlpy/algos/qlearning/cal_ql.py @@ -71,8 +71,10 @@ class CalQLConfig(CQLConfig): max_q_backup (bool): Flag to sample max Q-values for target. """ - def create(self, device: DeviceArg = False) -> "CalQL": - return CalQL(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "CalQL": + return CalQL(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -92,6 +94,7 @@ def inner_create_impl( action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) q_funcs, q_func_fowarder = create_continuous_q_function( observation_shape, @@ -100,6 +103,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, @@ -108,14 +112,19 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) log_temp = create_parameter( (1, 1), math.log(self._config.initial_temperature), device=self._device, + enable_ddp=self._enable_ddp, ) log_alpha = create_parameter( - (1, 1), math.log(self._config.initial_alpha), device=self._device + (1, 1), + math.log(self._config.initial_alpha), + device=self._device, + enable_ddp=self._enable_ddp, ) actor_optim = self._config.actor_optim_factory.create( diff --git a/d3rlpy/algos/qlearning/cql.py b/d3rlpy/algos/qlearning/cql.py index bbb40e03..a0cd482f 100644 --- a/d3rlpy/algos/qlearning/cql.py +++ b/d3rlpy/algos/qlearning/cql.py @@ -125,8 +125,10 @@ class CQLConfig(LearnableConfig): soft_q_backup: bool = False max_q_backup: bool = False - def create(self, device: DeviceArg = False) -> "CQL": - return CQL(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "CQL": + return CQL(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -146,6 +148,7 @@ def inner_create_impl( action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) q_funcs, q_func_fowarder = create_continuous_q_function( observation_shape, @@ -154,6 +157,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, @@ -162,14 +166,19 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) log_temp = create_parameter( (1, 1), math.log(self._config.initial_temperature), device=self._device, + enable_ddp=self._enable_ddp, ) log_alpha = create_parameter( - (1, 1), math.log(self._config.initial_alpha), device=self._device + (1, 1), + math.log(self._config.initial_alpha), + device=self._device, + enable_ddp=self._enable_ddp, ) actor_optim = self._config.actor_optim_factory.create( @@ -275,8 +284,10 @@ class DiscreteCQLConfig(LearnableConfig): target_update_interval: int = 8000 alpha: float = 1.0 - def create(self, device: DeviceArg = False) -> "DiscreteCQL": - return DiscreteCQL(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "DiscreteCQL": + return DiscreteCQL(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -294,6 +305,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, 
device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_discrete_q_function( observation_shape, @@ -302,6 +314,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) optim = self._config.optim_factory.create( diff --git a/d3rlpy/algos/qlearning/crr.py b/d3rlpy/algos/qlearning/crr.py index fd29cc56..38aac0d6 100644 --- a/d3rlpy/algos/qlearning/crr.py +++ b/d3rlpy/algos/qlearning/crr.py @@ -121,8 +121,10 @@ class CRRConfig(LearnableConfig): target_update_interval: int = 100 update_actor_interval: int = 1 - def create(self, device: DeviceArg = False) -> "CRR": - return CRR(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "CRR": + return CRR(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -138,12 +140,14 @@ def inner_create_impl( action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) targ_policy = create_normal_policy( observation_shape, action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, @@ -152,6 +156,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, @@ -160,6 +165,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) actor_optim = self._config.actor_optim_factory.create( diff --git a/d3rlpy/algos/qlearning/ddpg.py b/d3rlpy/algos/qlearning/ddpg.py index e2dc6d1e..9e273234 100644 --- a/d3rlpy/algos/qlearning/ddpg.py +++ b/d3rlpy/algos/qlearning/ddpg.py @@ -82,8 +82,10 @@ class DDPGConfig(LearnableConfig): tau: float = 0.005 n_critics: int = 1 - def create(self, device: DeviceArg = False) -> "DDPG": - return DDPG(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "DDPG": + return DDPG(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -99,12 +101,14 @@ def inner_create_impl( action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) targ_policy = create_deterministic_policy( observation_shape, action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, @@ -113,6 +117,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, @@ -121,6 +126,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) actor_optim = self._config.actor_optim_factory.create( diff --git a/d3rlpy/algos/qlearning/dqn.py b/d3rlpy/algos/qlearning/dqn.py index ff729d5a..97447649 100644 --- a/d3rlpy/algos/qlearning/dqn.py +++ b/d3rlpy/algos/qlearning/dqn.py @@ -55,8 +55,10 @@ class DQNConfig(LearnableConfig): n_critics: int = 1 target_update_interval: int = 8000 - def create(self, device: DeviceArg = False) -> "DQN": - return DQN(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool 
= False + ) -> "DQN": + return DQN(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -74,6 +76,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_forwarder = create_discrete_q_function( observation_shape, @@ -82,6 +85,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) optim = self._config.optim_factory.create( @@ -158,8 +162,10 @@ class DoubleDQNConfig(DQNConfig): n_critics: int = 1 target_update_interval: int = 8000 - def create(self, device: DeviceArg = False) -> "DoubleDQN": - return DoubleDQN(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "DoubleDQN": + return DoubleDQN(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -177,6 +183,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_forwarder = create_discrete_q_function( observation_shape, @@ -185,6 +192,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) optim = self._config.optim_factory.create( diff --git a/d3rlpy/algos/qlearning/iql.py b/d3rlpy/algos/qlearning/iql.py index 4f1ce04c..8065c2a4 100644 --- a/d3rlpy/algos/qlearning/iql.py +++ b/d3rlpy/algos/qlearning/iql.py @@ -97,8 +97,10 @@ class IQLConfig(LearnableConfig): weight_temp: float = 3.0 max_weight: float = 100.0 - def create(self, device: DeviceArg = False) -> "IQL": - return IQL(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "IQL": + return IQL(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -117,6 +119,7 @@ def inner_create_impl( max_logstd=2.0, use_std_parameter=True, device=self._device, + enable_ddp=self._enable_ddp, ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, @@ -125,6 +128,7 @@ def inner_create_impl( MeanQFunctionFactory(), n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, @@ -133,11 +137,13 @@ def inner_create_impl( MeanQFunctionFactory(), n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) value_func = create_value_function( observation_shape, self._config.value_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) actor_optim = self._config.actor_optim_factory.create( diff --git a/d3rlpy/algos/qlearning/nfq.py b/d3rlpy/algos/qlearning/nfq.py index 245b473c..12abec5f 100644 --- a/d3rlpy/algos/qlearning/nfq.py +++ b/d3rlpy/algos/qlearning/nfq.py @@ -57,8 +57,10 @@ class NFQConfig(LearnableConfig): gamma: float = 0.99 n_critics: int = 1 - def create(self, device: DeviceArg = False) -> "NFQ": - return NFQ(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "NFQ": + return NFQ(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -76,6 +78,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_discrete_q_function( observation_shape, @@ -84,6 +87,7 @@ def inner_create_impl( self._config.q_func_factory, 
n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) optim = self._config.optim_factory.create( diff --git a/d3rlpy/algos/qlearning/plas.py b/d3rlpy/algos/qlearning/plas.py index 05b5d625..2ab4c364 100644 --- a/d3rlpy/algos/qlearning/plas.py +++ b/d3rlpy/algos/qlearning/plas.py @@ -97,8 +97,10 @@ class PLASConfig(LearnableConfig): warmup_steps: int = 500000 beta: float = 0.5 - def create(self, device: DeviceArg = False) -> "PLAS": - return PLAS(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "PLAS": + return PLAS(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -114,12 +116,14 @@ def inner_create_impl( 2 * action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) targ_policy = create_deterministic_policy( observation_shape, 2 * action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, @@ -128,6 +132,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, @@ -136,6 +141,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) vae_encoder = create_vae_encoder( observation_shape=observation_shape, @@ -145,6 +151,7 @@ def inner_create_impl( max_logstd=15.0, encoder_factory=self._config.imitator_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) vae_decoder = create_vae_decoder( observation_shape=observation_shape, @@ -152,6 +159,7 @@ def inner_create_impl( latent_size=2 * action_size, encoder_factory=self._config.imitator_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) actor_optim = self._config.actor_optim_factory.create( @@ -243,8 +251,10 @@ class PLASWithPerturbationConfig(PLASConfig): action_flexibility: float = 0.05 - def create(self, device: DeviceArg = False) -> "PLASWithPerturbation": - return PLASWithPerturbation(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "PLASWithPerturbation": + return PLASWithPerturbation(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -262,12 +272,14 @@ def inner_create_impl( 2 * action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) targ_policy = create_deterministic_policy( observation_shape, 2 * action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, @@ -276,6 +288,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, @@ -284,6 +297,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) vae_encoder = create_vae_encoder( observation_shape=observation_shape, @@ -293,6 +307,7 @@ def inner_create_impl( max_logstd=15.0, encoder_factory=self._config.imitator_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) vae_decoder = create_vae_decoder( 
observation_shape=observation_shape, @@ -300,6 +315,7 @@ def inner_create_impl( latent_size=2 * action_size, encoder_factory=self._config.imitator_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) perturbation = create_deterministic_residual_policy( observation_shape=observation_shape, @@ -307,6 +323,7 @@ def inner_create_impl( scale=self._config.action_flexibility, encoder_factory=self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) targ_perturbation = create_deterministic_residual_policy( observation_shape=observation_shape, @@ -314,6 +331,7 @@ def inner_create_impl( scale=self._config.action_flexibility, encoder_factory=self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) named_modules = list(policy.named_modules()) diff --git a/d3rlpy/algos/qlearning/random_policy.py b/d3rlpy/algos/qlearning/random_policy.py index 8190c0da..0623a7c5 100644 --- a/d3rlpy/algos/qlearning/random_policy.py +++ b/d3rlpy/algos/qlearning/random_policy.py @@ -35,7 +35,7 @@ class RandomPolicyConfig(LearnableConfig): distribution: str = "uniform" normal_std: float = 1.0 - def create(self, device: DeviceArg = False) -> "RandomPolicy": # type: ignore + def create(self, device: DeviceArg = False, enable_ddp: bool = False) -> "RandomPolicy": # type: ignore return RandomPolicy(self) @staticmethod @@ -47,7 +47,7 @@ class RandomPolicy(QLearningAlgoBase[None, RandomPolicyConfig]): # type: ignore _action_size: int def __init__(self, config: RandomPolicyConfig): - super().__init__(config, False, None) + super().__init__(config, False, False, None) self._action_size = 1 def inner_create_impl( @@ -98,7 +98,7 @@ class DiscreteRandomPolicyConfig(LearnableConfig): ``fit`` and ``fit_online`` methods will raise exceptions. 
""" - def create(self, device: DeviceArg = False) -> "DiscreteRandomPolicy": # type: ignore + def create(self, device: DeviceArg = False, enable_ddp: bool = False) -> "DiscreteRandomPolicy": # type: ignore return DiscreteRandomPolicy(self) @staticmethod @@ -110,7 +110,7 @@ class DiscreteRandomPolicy(QLearningAlgoBase[None, DiscreteRandomPolicyConfig]): _action_size: int def __init__(self, config: DiscreteRandomPolicyConfig): - super().__init__(config, False, None) + super().__init__(config, False, False, None) self._action_size = 1 def inner_create_impl( diff --git a/d3rlpy/algos/qlearning/rebrac.py b/d3rlpy/algos/qlearning/rebrac.py index 4d7f1e7f..8334dc79 100644 --- a/d3rlpy/algos/qlearning/rebrac.py +++ b/d3rlpy/algos/qlearning/rebrac.py @@ -90,8 +90,10 @@ class ReBRACConfig(LearnableConfig): critic_beta: float = 0.01 update_actor_interval: int = 2 - def create(self, device: DeviceArg = False) -> "ReBRAC": - return ReBRAC(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "ReBRAC": + return ReBRAC(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -107,12 +109,14 @@ def inner_create_impl( action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) targ_policy = create_deterministic_policy( observation_shape, action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, @@ -121,6 +125,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, @@ -129,6 +134,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) actor_optim = self._config.actor_optim_factory.create( diff --git a/d3rlpy/algos/qlearning/sac.py b/d3rlpy/algos/qlearning/sac.py index c31cb69e..e3bb4efb 100644 --- a/d3rlpy/algos/qlearning/sac.py +++ b/d3rlpy/algos/qlearning/sac.py @@ -111,8 +111,10 @@ class SACConfig(LearnableConfig): n_critics: int = 2 initial_temperature: float = 1.0 - def create(self, device: DeviceArg = False) -> "SAC": - return SAC(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "SAC": + return SAC(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -128,6 +130,7 @@ def inner_create_impl( action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, @@ -136,6 +139,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, @@ -144,11 +148,13 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) log_temp = create_parameter( (1, 1), math.log(self._config.initial_temperature), device=self._device, + enable_ddp=self._enable_ddp, ) actor_optim = self._config.actor_optim_factory.create( @@ -260,8 +266,10 @@ class DiscreteSACConfig(LearnableConfig): initial_temperature: float = 1.0 target_update_interval: int = 8000 - def create(self, device: DeviceArg = False) -> "DiscreteSAC": - return 
DiscreteSAC(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "DiscreteSAC": + return DiscreteSAC(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -279,6 +287,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_discrete_q_function( observation_shape, @@ -287,18 +296,21 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) policy = create_categorical_policy( observation_shape, action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) if self._config.initial_temperature > 0: log_temp = create_parameter( (1, 1), math.log(self._config.initial_temperature), device=self._device, + enable_ddp=self._enable_ddp, ) else: log_temp = None diff --git a/d3rlpy/algos/qlearning/td3.py b/d3rlpy/algos/qlearning/td3.py index 58ba5757..885434c4 100644 --- a/d3rlpy/algos/qlearning/td3.py +++ b/d3rlpy/algos/qlearning/td3.py @@ -91,8 +91,10 @@ class TD3Config(LearnableConfig): target_smoothing_clip: float = 0.5 update_actor_interval: int = 2 - def create(self, device: DeviceArg = False) -> "TD3": - return TD3(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "TD3": + return TD3(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -108,12 +110,14 @@ def inner_create_impl( action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) targ_policy = create_deterministic_policy( observation_shape, action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, @@ -122,6 +126,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, @@ -130,6 +135,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) actor_optim = self._config.actor_optim_factory.create( diff --git a/d3rlpy/algos/qlearning/td3_plus_bc.py b/d3rlpy/algos/qlearning/td3_plus_bc.py index 55d08c25..f3c223a8 100644 --- a/d3rlpy/algos/qlearning/td3_plus_bc.py +++ b/d3rlpy/algos/qlearning/td3_plus_bc.py @@ -83,8 +83,10 @@ class TD3PlusBCConfig(LearnableConfig): alpha: float = 2.5 update_actor_interval: int = 2 - def create(self, device: DeviceArg = False) -> "TD3PlusBC": - return TD3PlusBC(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "TD3PlusBC": + return TD3PlusBC(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -100,12 +102,14 @@ def inner_create_impl( action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) targ_policy = create_deterministic_policy( observation_shape, action_size, self._config.actor_encoder_factory, device=self._device, + enable_ddp=self._enable_ddp, ) q_funcs, q_func_forwarder = create_continuous_q_function( observation_shape, @@ -114,6 +118,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = 
create_continuous_q_function( observation_shape, @@ -122,6 +127,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) actor_optim = self._config.actor_optim_factory.create( diff --git a/d3rlpy/algos/transformer/base.py b/d3rlpy/algos/transformer/base.py index f06e9ce4..0aff030d 100644 --- a/d3rlpy/algos/transformer/base.py +++ b/d3rlpy/algos/transformer/base.py @@ -389,7 +389,6 @@ def fit( eval_action_sampler: Optional[TransformerActionSampler] = None, save_interval: int = 1, callback: Optional[Callable[[Self, int, int], None]] = None, - enable_ddp: bool = False, ) -> None: """Trains with given dataset. @@ -410,7 +409,6 @@ def fit( save_interval: Interval to save parameters. callback: Callable function that takes ``(algo, epoch, total_step)`` , which is called every step. - enable_ddp: Flag to wrap models with DataDistributedParallel. """ LOG.info("dataset info", dataset_info=dataset.dataset_info) @@ -443,11 +441,6 @@ def fit( else: LOG.warning("Skip building models since they're already built.") - # wrap all PyTorch modules with DataDistributedParallel - if enable_ddp: - assert self._impl - self._impl.wrap_models_by_ddp() - # save hyperparameters save_config(self, logger) diff --git a/d3rlpy/algos/transformer/decision_transformer.py b/d3rlpy/algos/transformer/decision_transformer.py index a5dd2298..e7e51f01 100644 --- a/d3rlpy/algos/transformer/decision_transformer.py +++ b/d3rlpy/algos/transformer/decision_transformer.py @@ -83,8 +83,10 @@ class DecisionTransformerConfig(TransformerConfig): clip_grad_norm: float = 0.25 compile: bool = False - def create(self, device: DeviceArg = False) -> "DecisionTransformer": - return DecisionTransformer(self, device) + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "DecisionTransformer": + return DecisionTransformer(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -111,6 +113,7 @@ def inner_create_impl( activation_type=self._config.activation_type, position_encoding_type=self._config.position_encoding_type, device=self._device, + enable_ddp=self._enable_ddp, ) optim = self._config.optim_factory.create( transformer.named_modules(), lr=self._config.learning_rate @@ -198,9 +201,9 @@ class DiscreteDecisionTransformerConfig(TransformerConfig): compile: bool = False def create( - self, device: DeviceArg = False + self, device: DeviceArg = False, enable_ddp: bool = False ) -> "DiscreteDecisionTransformer": - return DiscreteDecisionTransformer(self, device) + return DiscreteDecisionTransformer(self, device, enable_ddp) @staticmethod def get_type() -> str: @@ -230,6 +233,7 @@ def inner_create_impl( embed_activation_type=self._config.embed_activation_type, position_encoding_type=self._config.position_encoding_type, device=self._device, + enable_ddp=self._enable_ddp, ) optim = self._config.optim_factory.create( transformer.named_modules(), lr=self._config.learning_rate diff --git a/d3rlpy/base.py b/d3rlpy/base.py index bed3f38d..2b526373 100644 --- a/d3rlpy/base.py +++ b/d3rlpy/base.py @@ -86,12 +86,6 @@ def device(self) -> str: def modules(self) -> Modules: return self._modules - def wrap_models_by_ddp(self) -> None: - self._modules = self._modules.wrap_models_by_ddp() - - def unwrap_models_by_ddp(self) -> None: - self._modules = self._modules.unwrap_models_by_ddp() - @dataclasses.dataclass() class LearnableConfig(DynamicConfig): @@ -104,7 +98,7 @@ class LearnableConfig(DynamicConfig): reward_scaler: 
Optional[RewardScaler] = make_reward_scaler_field() def create( - self, device: DeviceArg = False + self, device: DeviceArg = False, enable_ddp: bool = False ) -> "LearnableBase[ImplBase, LearnableConfig]": r"""Returns algorithm object. @@ -113,6 +107,8 @@ def create( boolean and True, ``cuda:0`` will be used. If the value is integer, ``cuda:`` will be used. If the value is string in torch device style, the specified device will be used. + enable_ddp (bool): Flag to wrap models with DDP to enable Data + Distributed Parallel training. Returns: algorithm object. @@ -210,6 +206,7 @@ def load_learnable( class LearnableBase(Generic[TImpl_co, TConfig_co], metaclass=ABCMeta): _config: TConfig_co _device: str + _enable_ddp: bool _impl: Optional[TImpl_co] _grad_step: int @@ -217,6 +214,7 @@ def __init__( self, config: TConfig_co, device: DeviceArg, + enable_ddp: bool, impl: Optional[TImpl_co] = None, ): if self.get_action_type() == ActionSpace.DISCRETE: @@ -225,6 +223,7 @@ def __init__( ), "action_scaler cannot be used with discrete action-space algorithms." self._config = config self._device = _process_device(device) + self._enable_ddp = enable_ddp self._impl = impl self._grad_step = 0 diff --git a/d3rlpy/models/builders.py b/d3rlpy/models/builders.py index 82b34c79..3238b6ed 100644 --- a/d3rlpy/models/builders.py +++ b/d3rlpy/models/builders.py @@ -4,6 +4,7 @@ from torch import nn from ..constants import PositionEncodingType +from ..torch_utility import wrap_model_by_ddp from ..types import Shape from .encoders import EncoderFactory from .q_functions import QFunctionFactory @@ -49,6 +50,7 @@ def create_discrete_q_function( encoder_factory: EncoderFactory, q_func_factory: QFunctionFactory, device: str, + enable_ddp: bool, n_ensembles: int = 1, ) -> Tuple[nn.ModuleList, DiscreteEnsembleQFunctionForwarder]: if q_func_factory.share_encoder: @@ -67,10 +69,13 @@ def create_discrete_q_function( q_func, forwarder = q_func_factory.create_discrete( encoder, hidden_size, action_size ) + q_func.to(device) + if enable_ddp: + q_func = wrap_model_by_ddp(q_func) + forwarder.set_q_func(q_func) q_funcs.append(q_func) forwarders.append(forwarder) q_func_modules = nn.ModuleList(q_funcs) - q_func_modules.to(device) ensemble_forwarder = DiscreteEnsembleQFunctionForwarder( forwarders, action_size ) @@ -83,6 +88,7 @@ def create_continuous_q_function( encoder_factory: EncoderFactory, q_func_factory: QFunctionFactory, device: str, + enable_ddp: bool, n_ensembles: int = 1, ) -> Tuple[nn.ModuleList, ContinuousEnsembleQFunctionForwarder]: if q_func_factory.share_encoder: @@ -109,10 +115,13 @@ def create_continuous_q_function( q_func, forwarder = q_func_factory.create_continuous( encoder, hidden_size ) + q_func.to(device) + if enable_ddp: + q_func = wrap_model_by_ddp(q_func) + forwarder.set_q_func(q_func) q_funcs.append(q_func) forwarders.append(forwarder) q_func_modules = nn.ModuleList(q_funcs) - q_func_modules.to(device) ensemble_forwarder = ContinuousEnsembleQFunctionForwarder( forwarders, action_size ) @@ -124,6 +133,7 @@ def create_deterministic_policy( action_size: int, encoder_factory: EncoderFactory, device: str, + enable_ddp: bool, ) -> DeterministicPolicy: encoder = encoder_factory.create(observation_shape) hidden_size = compute_output_size([observation_shape], encoder) @@ -133,6 +143,8 @@ def create_deterministic_policy( action_size=action_size, ) policy.to(device) + if enable_ddp: + policy = wrap_model_by_ddp(policy) return policy @@ -142,6 +154,7 @@ def create_deterministic_residual_policy( scale: float, 
encoder_factory: EncoderFactory, device: str, + enable_ddp: bool, ) -> DeterministicResidualPolicy: encoder = encoder_factory.create_with_action(observation_shape, action_size) hidden_size = compute_output_size( @@ -154,6 +167,8 @@ def create_deterministic_residual_policy( scale=scale, ) policy.to(device) + if enable_ddp: + policy = wrap_model_by_ddp(policy) return policy @@ -162,6 +177,7 @@ def create_normal_policy( action_size: int, encoder_factory: EncoderFactory, device: str, + enable_ddp: bool, min_logstd: float = -20.0, max_logstd: float = 2.0, use_std_parameter: bool = False, @@ -177,6 +193,8 @@ def create_normal_policy( use_std_parameter=use_std_parameter, ) policy.to(device) + if enable_ddp: + policy = wrap_model_by_ddp(policy) return policy @@ -185,6 +203,7 @@ def create_categorical_policy( action_size: int, encoder_factory: EncoderFactory, device: str, + enable_ddp: bool, ) -> CategoricalPolicy: encoder = encoder_factory.create(observation_shape) hidden_size = compute_output_size([observation_shape], encoder) @@ -192,6 +211,8 @@ def create_categorical_policy( encoder=encoder, hidden_size=hidden_size, action_size=action_size ) policy.to(device) + if enable_ddp: + policy = wrap_model_by_ddp(policy) return policy @@ -201,6 +222,7 @@ def create_vae_encoder( latent_size: int, encoder_factory: EncoderFactory, device: str, + enable_ddp: bool, min_logstd: float = -20.0, max_logstd: float = 2.0, ) -> VAEEncoder: @@ -216,6 +238,8 @@ def create_vae_encoder( max_logstd=max_logstd, ) vae_encoder.to(device) + if enable_ddp: + vae_encoder = wrap_model_by_ddp(vae_encoder) return vae_encoder @@ -225,6 +249,7 @@ def create_vae_decoder( latent_size: int, encoder_factory: EncoderFactory, device: str, + enable_ddp: bool, ) -> VAEDecoder: encoder = encoder_factory.create_with_action(observation_shape, latent_size) decoder_hidden_size = compute_output_size( @@ -236,25 +261,34 @@ def create_vae_decoder( action_size=action_size, ) decoder.to(device) + if enable_ddp: + decoder = wrap_model_by_ddp(decoder) return decoder def create_value_function( - observation_shape: Shape, encoder_factory: EncoderFactory, device: str + observation_shape: Shape, + encoder_factory: EncoderFactory, + device: str, + enable_ddp: bool, ) -> ValueFunction: encoder = encoder_factory.create(observation_shape) hidden_size = compute_output_size([observation_shape], encoder) value_func = ValueFunction(encoder, hidden_size) value_func.to(device) + if enable_ddp: + value_func = wrap_model_by_ddp(value_func) return value_func def create_parameter( - shape: Sequence[int], initial_value: float, device: str + shape: Sequence[int], initial_value: float, device: str, enable_ddp: bool ) -> Parameter: data = torch.full(shape, initial_value, dtype=torch.float32) parameter = Parameter(data) parameter.to(device) + if enable_ddp: + parameter = wrap_model_by_ddp(parameter) return parameter @@ -291,6 +325,7 @@ def create_continuous_decision_transformer( activation_type: str, position_encoding_type: PositionEncodingType, device: str, + enable_ddp: bool, ) -> ContinuousDecisionTransformer: encoder = encoder_factory.create(observation_shape) hidden_size = compute_output_size([observation_shape], encoder) @@ -316,6 +351,8 @@ def create_continuous_decision_transformer( activation=create_activation(activation_type), ) transformer.to(device) + if enable_ddp: + transformer = wrap_model_by_ddp(transformer) return transformer @@ -334,6 +371,7 @@ def create_discrete_decision_transformer( embed_activation_type: str, position_encoding_type: 
PositionEncodingType, device: str, + enable_ddp: bool, ) -> DiscreteDecisionTransformer: encoder = encoder_factory.create(observation_shape) hidden_size = compute_output_size([observation_shape], encoder) @@ -360,4 +398,6 @@ def create_discrete_decision_transformer( embed_activation=create_activation(embed_activation_type), ) transformer.to(device) + if enable_ddp: + transformer = wrap_model_by_ddp(transformer) return transformer diff --git a/d3rlpy/models/torch/q_functions/base.py b/d3rlpy/models/torch/q_functions/base.py index dd5a735a..d2d35543 100644 --- a/d3rlpy/models/torch/q_functions/base.py +++ b/d3rlpy/models/torch/q_functions/base.py @@ -80,6 +80,10 @@ def compute_target( ) -> torch.Tensor: pass + @abstractmethod + def set_q_func(self, q_func: ContinuousQFunction) -> None: + pass + class DiscreteQFunctionForwarder(metaclass=ABCMeta): @abstractmethod @@ -104,3 +108,7 @@ def compute_target( self, x: TorchObservation, action: Optional[torch.Tensor] = None ) -> torch.Tensor: pass + + @abstractmethod + def set_q_func(self, q_func: DiscreteQFunction) -> None: + pass diff --git a/d3rlpy/models/torch/q_functions/iqn_q_function.py b/d3rlpy/models/torch/q_functions/iqn_q_function.py index ab892d0f..2cd79ae4 100644 --- a/d3rlpy/models/torch/q_functions/iqn_q_function.py +++ b/d3rlpy/models/torch/q_functions/iqn_q_function.py @@ -170,6 +170,9 @@ def compute_target( return quantiles return pick_quantile_value_by_action(quantiles, action) + def set_q_func(self, q_func: DiscreteQFunction) -> None: + self._q_func = q_func + class ContinuousIQNQFunction(ContinuousQFunction, nn.Module): # type: ignore _encoder: EncoderWithAction @@ -275,3 +278,6 @@ def compute_target( quantiles = self._q_func(x, action).quantiles assert quantiles is not None return quantiles + + def set_q_func(self, q_func: ContinuousQFunction) -> None: + self._q_func = q_func diff --git a/d3rlpy/models/torch/q_functions/mean_q_function.py b/d3rlpy/models/torch/q_functions/mean_q_function.py index 8f1859ba..7f61ff3f 100644 --- a/d3rlpy/models/torch/q_functions/mean_q_function.py +++ b/d3rlpy/models/torch/q_functions/mean_q_function.py @@ -82,6 +82,9 @@ def compute_target( self._q_func(x).q_value, action, keepdim=True ) + def set_q_func(self, q_func: DiscreteQFunction) -> None: + self._q_func = q_func + class ContinuousMeanQFunction(ContinuousQFunction): _encoder: EncoderWithAction @@ -136,3 +139,6 @@ def compute_target( self, x: TorchObservation, action: torch.Tensor ) -> torch.Tensor: return self._q_func(x, action).q_value + + def set_q_func(self, q_func: ContinuousQFunction) -> None: + self._q_func = q_func diff --git a/d3rlpy/models/torch/q_functions/qr_q_function.py b/d3rlpy/models/torch/q_functions/qr_q_function.py index 600c9c90..318e207a 100644 --- a/d3rlpy/models/torch/q_functions/qr_q_function.py +++ b/d3rlpy/models/torch/q_functions/qr_q_function.py @@ -118,6 +118,9 @@ def compute_target( return quantiles return pick_quantile_value_by_action(quantiles, action) + def set_q_func(self, q_func: DiscreteQFunction) -> None: + self._q_func = q_func + class ContinuousQRQFunction(ContinuousQFunction): _encoder: EncoderWithAction @@ -198,3 +201,6 @@ def compute_target( quantiles = self._q_func(x, action).quantiles assert quantiles is not None return quantiles + + def set_q_func(self, q_func: ContinuousQFunction) -> None: + self._q_func = q_func diff --git a/d3rlpy/ope/fqe.py b/d3rlpy/ope/fqe.py index c4d553d9..b077688b 100644 --- a/d3rlpy/ope/fqe.py +++ b/d3rlpy/ope/fqe.py @@ -71,7 +71,9 @@ class FQEConfig(LearnableConfig): 
n_critics: int = 1 target_update_interval: int = 100 - def create(self, device: DeviceArg = False) -> "_FQEBase": + def create( + self, device: DeviceArg = False, enable_ddp: bool = False + ) -> "_FQEBase": raise NotImplementedError( "Config object must be directly given to constructor" ) @@ -91,9 +93,10 @@ def __init__( algo: QLearningAlgoBase[QLearningAlgoImplBase, LearnableConfig], config: FQEConfig, device: DeviceArg = False, + enable_ddp: bool = False, impl: Optional[FQEBaseImpl] = None, ): - super().__init__(config, device, impl) + super().__init__(config, device, enable_ddp, impl) self._algo = algo def save_policy(self, fname: str) -> None: @@ -152,6 +155,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function( observation_shape, @@ -160,6 +164,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) optim = self._config.optim_factory.create( q_funcs.named_modules(), lr=self._config.learning_rate @@ -228,6 +233,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) targ_q_funcs, targ_q_func_forwarder = create_discrete_q_function( observation_shape, @@ -236,6 +242,7 @@ def inner_create_impl( self._config.q_func_factory, n_ensembles=self._config.n_critics, device=self._device, + enable_ddp=self._enable_ddp, ) optim = self._config.optim_factory.create( q_funcs.named_modules(), lr=self._config.learning_rate diff --git a/d3rlpy/torch_utility.py b/d3rlpy/torch_utility.py index 1e8990e7..2b231784 100644 --- a/d3rlpy/torch_utility.py +++ b/d3rlpy/torch_utility.py @@ -17,7 +17,6 @@ from torch import nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer -from typing_extensions import Self from .dataclass_utils import asdict_without_copy from .dataset import TrajectoryMiniBatch, TransitionMiniBatch @@ -31,6 +30,8 @@ "map_location", "TorchMiniBatch", "TorchTrajectoryMiniBatch", + "wrap_model_by_ddp", + "unwrap_ddp_model", "Checkpointer", "Modules", "convert_to_torch", @@ -302,6 +303,25 @@ def from_batch( ) +_TModule = TypeVar("_TModule", bound=nn.Module) + + +def wrap_model_by_ddp(model: _TModule) -> _TModule: + device_id = next(model.parameters()).device.index + return DDP(model, device_ids=[device_id] if device_id else None) # type: ignore + + +def unwrap_ddp_model(model: _TModule) -> _TModule: + if isinstance(model, DDP): + model = model.module + if isinstance(model, nn.ModuleList): + module_list = nn.ModuleList() + for v in model: + module_list.append(unwrap_ddp_model(v)) + model = module_list + return model + + class Checkpointer: _modules: Dict[str, Union[nn.Module, Optimizer]] _device: str @@ -313,7 +333,12 @@ def __init__( self._device = device def save(self, f: BinaryIO) -> None: - states = {k: v.state_dict() for k, v in self._modules.items()} + # unwrap DDP + modules = { + k: unwrap_ddp_model(v) if isinstance(v, nn.Module) else v + for k, v in self._modules.items() + } + states = {k: v.state_dict() for k, v in modules.items()} torch.save(states, f) def load(self, f: BinaryIO) -> None: @@ -363,23 +388,6 @@ def reset_optimizer_states(self) -> None: if isinstance(v, torch.optim.Optimizer): v.state = collections.defaultdict(dict) - def wrap_models_by_ddp(self) -> Self: - dict_values = asdict_without_copy(self) - 
for k, v in dict_values.items(): - if isinstance(v, nn.Module): - device_id = next(v.parameters()).device.index - dict_values[k] = DDP( - v, device_ids=[device_id] if device_id else None - ) - return self.__class__(**dict_values) - - def unwrap_models_by_ddp(self) -> Self: - dict_values = asdict_without_copy(self) - for k, v in dict_values.items(): - if isinstance(v, DDP): - dict_values[k] = v.module - return self.__class__(**dict_values) - TCallable = TypeVar("TCallable") diff --git a/examples/custom_algo.py b/examples/custom_algo.py index e2ee2f32..885df3a1 100644 --- a/examples/custom_algo.py +++ b/examples/custom_algo.py @@ -107,8 +107,10 @@ class CustomAlgoConfig(d3rlpy.base.LearnableConfig): target_update_interval: int = 100 gamma: float = 0.99 - def create(self, device: d3rlpy.base.DeviceArg = False) -> "CustomAlgo": - return CustomAlgo(self, device) + def create( + self, device: d3rlpy.base.DeviceArg = False, enable_ddp: bool = False + ) -> "CustomAlgo": + return CustomAlgo(self, device, enable_ddp) @staticmethod def get_type() -> str: diff --git a/examples/distributed_offline_training.py b/examples/distributed_offline_training.py index ba14166a..ba6d7471 100644 --- a/examples/distributed_offline_training.py +++ b/examples/distributed_offline_training.py @@ -27,7 +27,7 @@ def main() -> None: actor_learning_rate=1e-3, critic_learning_rate=1e-3, alpha_learning_rate=1e-3, - ).create(device=device) + ).create(device=device, enable_ddp=True) # prepare dataset dataset, env = d3rlpy.datasets.get_pendulum() @@ -50,7 +50,6 @@ def main() -> None: evaluators=evaluators, logger_adapter=logger_adapter, show_progress=rank == 0, - enable_ddp=True, ) d3rlpy.distributed.destroy_process_group() diff --git a/tests/algos/qlearning/torch/test_utility.py b/tests/algos/qlearning/torch/test_utility.py index bb85b165..1f16be83 100644 --- a/tests/algos/qlearning/torch/test_utility.py +++ b/tests/algos/qlearning/torch/test_utility.py @@ -31,6 +31,7 @@ def test_sample_q_values_with_policy( action_size=action_size, encoder_factory=DummyEncoderFactory(), device="cpu:0", + enable_ddp=False, ) _, q_func_forwarder = create_continuous_q_function( observation_shape=observation_shape, @@ -39,6 +40,7 @@ def test_sample_q_values_with_policy( q_func_factory=MeanQFunctionFactory(), n_ensembles=n_critics, device="cpu:0", + enable_ddp=False, ) observations = create_torch_observations(observation_shape, batch_size) diff --git a/tests/models/test_builders.py b/tests/models/test_builders.py index 3c1855eb..f52a8713 100644 --- a/tests/models/test_builders.py +++ b/tests/models/test_builders.py @@ -51,7 +51,11 @@ def test_create_deterministic_policy( encoder_factory: EncoderFactory, ) -> None: policy = create_deterministic_policy( - observation_shape, action_size, encoder_factory, device="cpu:0" + observation_shape, + action_size, + encoder_factory, + device="cpu:0", + enable_ddp=False, ) assert isinstance(policy, DeterministicPolicy) @@ -74,7 +78,12 @@ def test_create_deterministic_residual_policy( encoder_factory: EncoderFactory, ) -> None: policy = create_deterministic_residual_policy( - observation_shape, action_size, scale, encoder_factory, device="cpu:0" + observation_shape, + action_size, + scale, + encoder_factory, + device="cpu:0", + enable_ddp=False, ) assert isinstance(policy, DeterministicResidualPolicy) @@ -96,7 +105,11 @@ def test_create_normal_policy( encoder_factory: EncoderFactory, ) -> None: policy = create_normal_policy( - observation_shape, action_size, encoder_factory, device="cpu:0" + observation_shape, + 
action_size, + encoder_factory, + device="cpu:0", + enable_ddp=False, ) assert isinstance(policy, NormalPolicy) @@ -117,7 +130,11 @@ def test_create_categorical_policy( encoder_factory: EncoderFactory, ) -> None: policy = create_categorical_policy( - observation_shape, action_size, encoder_factory, device="cpu:0" + observation_shape, + action_size, + encoder_factory, + device="cpu:0", + enable_ddp=False, ) assert isinstance(policy, CategoricalPolicy) @@ -149,6 +166,7 @@ def test_create_discrete_q_function( encoder_factory, q_func_factory, device="cpu:0", + enable_ddp=False, n_ensembles=n_ensembles, ) @@ -189,6 +207,7 @@ def test_create_continuous_q_function( encoder_factory, q_func_factory, device="cpu:0", + enable_ddp=False, n_ensembles=n_ensembles, ) @@ -226,6 +245,7 @@ def test_create_vae_encoder( latent_size, encoder_factory, device="cpu:0", + enable_ddp=False, ) assert isinstance(vae_encoder, VAEEncoder) @@ -254,6 +274,7 @@ def test_create_vae_decoder( latent_size, encoder_factory, device="cpu:0", + enable_ddp=False, ) assert isinstance(vae_decoder, VAEDecoder) @@ -273,7 +294,7 @@ def test_create_value_function( batch_size: int, ) -> None: v_func = create_value_function( - observation_shape, encoder_factory, device="cpu:0" + observation_shape, encoder_factory, device="cpu:0", enable_ddp=False ) assert isinstance(v_func, ValueFunction) @@ -286,7 +307,7 @@ def test_create_value_function( @pytest.mark.parametrize("shape", [(100,)]) def test_create_parameter(shape: Sequence[int]) -> None: x = np.random.random() - parameter = create_parameter(shape, x, device="cpu:0") + parameter = create_parameter(shape, x, device="cpu:0", enable_ddp=False) assert len(list(parameter.parameters())) == 1 assert np.allclose(get_parameter(parameter).detach().numpy(), x) @@ -333,6 +354,7 @@ def test_create_continuous_decision_transformer( activation_type=activation_type, position_encoding_type=position_encoding_type, device="cpu:0", + enable_ddp=False, ) assert isinstance(transformer, ContinuousDecisionTransformer) @@ -388,6 +410,7 @@ def test_create_discrete_decision_transformer( embed_activation_type=activation_type, position_encoding_type=position_encoding_type, device="cpu:0", + enable_ddp=False, ) assert isinstance(transformer, DiscreteDecisionTransformer) diff --git a/tests/models/torch/test_q_functions.py b/tests/models/torch/test_q_functions.py index b715113d..ab054765 100644 --- a/tests/models/torch/test_q_functions.py +++ b/tests/models/torch/test_q_functions.py @@ -41,6 +41,7 @@ def test_compute_max_with_n_actions( q_func_factory, n_ensembles=n_ensembles, device="cpu:0", + enable_ddp=False, ) x = create_torch_observations(observation_shape, batch_size) actions = torch.rand(batch_size, n_actions, action_size)
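
With this change, DDP is requested when the algorithm object is constructed instead of when `fit()` is called. Below is a minimal end-to-end sketch of the intended workflow, modeled on `examples/distributed_offline_training.py` above; the choice of `CQLConfig`, the learning rates, and the step counts are illustrative, and `d3rlpy.distributed.init_process_group` is assumed to return the process rank the way the rest of that example script uses it:

```python
import d3rlpy


def main() -> None:
    # one process per GPU, launched with torchrun or an equivalent launcher
    rank = d3rlpy.distributed.init_process_group("nccl")
    device = f"cuda:{rank}"

    # enable_ddp moved from fit() to create(): every model built by
    # inner_create_impl is wrapped with DistributedDataParallel up front
    cql = d3rlpy.algos.CQLConfig(
        actor_learning_rate=1e-3,
        critic_learning_rate=1e-3,
        alpha_learning_rate=1e-3,
    ).create(device=device, enable_ddp=True)

    dataset, _ = d3rlpy.datasets.get_pendulum()

    # fit() no longer accepts enable_ddp
    cql.fit(
        dataset,
        n_steps=10000,
        n_steps_per_epoch=1000,
        show_progress=rank == 0,
    )

    d3rlpy.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
```

Launched with `torchrun --nproc_per_node=<n_gpus> <script>.py`, each process builds its own DDP-wrapped models on its own device and `fit()` needs no extra flags.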
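
Inside the library, the builders now handle device placement and DDP wrapping per model instead of wrapping everything afterwards via `Modules.wrap_models_by_ddp()`; for Q-functions, the new `set_q_func` hook lets the builder re-point each forwarder at the wrapped module so forward passes and gradient synchronization go through DDP. A small sketch of calling one of the updated builders directly; `VectorEncoderFactory` and the shapes are arbitrary choices, and `enable_ddp=False` keeps it runnable without an initialized process group:

```python
from d3rlpy.models.builders import create_continuous_q_function
from d3rlpy.models.encoders import VectorEncoderFactory
from d3rlpy.models.q_functions import MeanQFunctionFactory

# builds an ensemble of two Q-functions on CPU; with enable_ddp=True each
# member would be wrapped with DistributedDataParallel and the forwarder
# re-pointed at the wrapped module via set_q_func
q_funcs, forwarder = create_continuous_q_function(
    observation_shape=(8,),
    action_size=2,
    encoder_factory=VectorEncoderFactory(),
    q_func_factory=MeanQFunctionFactory(),
    device="cpu:0",
    enable_ddp=False,
    n_ensembles=2,
)
print(len(q_funcs))  # -> 2
```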
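
Checkpointing changes accordingly: `Checkpointer.save` strips DDP wrappers before calling `state_dict()`, so parameters saved from a distributed run keep the same key names (no `module.` prefix) as a single-device run and can be loaded by either. A toy sketch of the helper it relies on; the `nn.Linear` modules are hypothetical stand-ins for d3rlpy's policy and Q-function networks, and without DDP the call is a pass-through:

```python
import io

import torch
from torch import nn

from d3rlpy.torch_utility import unwrap_ddp_model

# an ensemble-like container, mirroring the nn.ModuleList of Q-functions
q_funcs = nn.ModuleList([nn.Linear(3, 1), nn.Linear(3, 1)])

# unwrap_ddp_model recurses into ModuleList and unwraps any member that is a
# DistributedDataParallel instance; nothing is wrapped here, so it is a no-op
states = {"q_funcs": unwrap_ddp_model(q_funcs).state_dict()}

buf = io.BytesIO()
torch.save(states, buf)
print(sorted(states["q_funcs"].keys()))  # ['0.bias', '0.weight', '1.bias', '1.weight']
```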