Redesign data parallel distributed support
takuseno committed Sep 21, 2024
1 parent 5ab2b46 commit 8e579b6
Showing 33 changed files with 344 additions and 114 deletions.
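
In short, this commit moves the `enable_ddp` flag from `fit()`/`fitter()` (see the d3rlpy/algos/qlearning/base.py diff below) to each config's `create()` method, and threads it through every network factory call. A minimal before/after sketch of the user-facing change; the example dataset and `fit()` arguments are illustrative and not part of this diff:

# Before/after sketch of the API change in this commit.
# The dataset helper below is only an illustrative d3rlpy example dataset,
# and enable_ddp=True additionally requires a distributed process group
# (see the launch sketch at the end of this page).
import d3rlpy

dataset, _ = d3rlpy.datasets.get_pendulum()

# Before: DDP wrapping was requested at fit() time.
#   awac = d3rlpy.algos.AWACConfig().create(device="cuda:0")
#   awac.fit(dataset, n_steps=10000, enable_ddp=True)

# After: DDP is requested when the algorithm (and its networks) are created.
awac = d3rlpy.algos.AWACConfig().create(device="cuda:0", enable_ddp=True)
awac.fit(dataset, n_steps=10000)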
9 changes: 7 additions & 2 deletions d3rlpy/algos/qlearning/awac.py
@@ -86,8 +86,10 @@ class AWACConfig(LearnableConfig):
     n_action_samples: int = 1
     n_critics: int = 2
 
-    def create(self, device: DeviceArg = False) -> "AWAC":
-        return AWAC(self, device)
+    def create(
+        self, device: DeviceArg = False, enable_ddp: bool = False
+    ) -> "AWAC":
+        return AWAC(self, device, enable_ddp)
 
     @staticmethod
     def get_type() -> str:
@@ -106,6 +108,7 @@ def inner_create_impl(
             max_logstd=0.0,
             use_std_parameter=True,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
         q_funcs, q_func_forwarder = create_continuous_q_function(
             observation_shape,
@@ -114,6 +117,7 @@
             self._config.q_func_factory,
             n_ensembles=self._config.n_critics,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
         targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
             observation_shape,
@@ -122,6 +126,7 @@
             self._config.q_func_factory,
             n_ensembles=self._config.n_critics,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
 
         actor_optim = self._config.actor_optim_factory.create(
10 changes: 0 additions & 10 deletions d3rlpy/algos/qlearning/base.py
@@ -385,7 +385,6 @@ def fit(
         evaluators: Optional[Dict[str, EvaluatorProtocol]] = None,
         callback: Optional[Callable[[Self, int, int], None]] = None,
         epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
-        enable_ddp: bool = False,
     ) -> List[Tuple[int, Dict[str, float]]]:
         """Trains with given dataset.
 
@@ -414,7 +413,6 @@
             epoch_callback: Callable function that takes
                 ``(algo, epoch, total_step)``, which is called at the end of
                 every epoch.
-            enable_ddp: Flag to wrap models with DataDistributedParallel.
 
         Returns:
             List of result tuples (epoch, metrics) per epoch.
@@ -434,7 +432,6 @@
                 evaluators=evaluators,
                 callback=callback,
                 epoch_callback=epoch_callback,
-                enable_ddp=enable_ddp,
             )
         )
         return results
@@ -454,7 +451,6 @@ def fitter(
         evaluators: Optional[Dict[str, EvaluatorProtocol]] = None,
         callback: Optional[Callable[[Self, int, int], None]] = None,
         epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
-        enable_ddp: bool = False,
     ) -> Generator[Tuple[int, Dict[str, float]], None, None]:
         """Iterate over epochs steps to train with the given dataset. At each
         iteration algo methods and properties can be changed or queried.
@@ -486,7 +482,6 @@
             epoch_callback: Callable function that takes
                 ``(algo, epoch, total_step)``, which is called at the end of
                 every epoch.
-            enable_ddp: Flag to wrap models with DataDistributedParallel.
 
         Returns:
             Iterator yielding current epoch and metrics dict.
@@ -522,11 +517,6 @@
         else:
             LOG.warning("Skip building models since they're already built.")
 
-        # wrap all PyTorch modules with DataDistributedParallel
-        if enable_ddp:
-            assert self._impl
-            self._impl.wrap_models_by_ddp()
-
         # save hyperparameters
         save_config(self, logger)
 
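
With the base.py change above, the post-hoc `wrap_models_by_ddp()` call is gone; instead every factory receives `enable_ddp` and can wrap each network as it is built. A rough sketch of that idea, purely as an assumption about what such a factory helper could look like (the name `wrap_model_by_ddp_if_needed` and its internals are not taken from this commit):

# Hedged sketch of what a network factory can now do with enable_ddp.
# The helper name and internals are illustrative assumptions; the actual
# d3rlpy factory code is not part of this diff.
import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_model_by_ddp_if_needed(
    module: torch.nn.Module, device: str, enable_ddp: bool
) -> torch.nn.Module:
    """Move a freshly built module to its device and optionally wrap it with DDP."""
    module.to(device)
    if enable_ddp:
        # Assumes torch.distributed.init_process_group() was already called by
        # the launcher (e.g. torchrun) and that `device` is a CUDA device
        # such as "cuda:0".
        module = DDP(module, device_ids=[torch.device(device).index])
    return module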
15 changes: 11 additions & 4 deletions d3rlpy/algos/qlearning/bc.py
@@ -57,8 +57,10 @@ class BCConfig(LearnableConfig):
     optim_factory: OptimizerFactory = make_optimizer_field()
     encoder_factory: EncoderFactory = make_encoder_field()
 
-    def create(self, device: DeviceArg = False) -> "BC":
-        return BC(self, device)
+    def create(
+        self, device: DeviceArg = False, enable_ddp: bool = False
+    ) -> "BC":
+        return BC(self, device, enable_ddp)
 
     @staticmethod
     def get_type() -> str:
@@ -75,6 +77,7 @@ def inner_create_impl(
                 action_size,
                 self._config.encoder_factory,
                 device=self._device,
+                enable_ddp=self._enable_ddp,
             )
         elif self._config.policy_type == "stochastic":
             imitator = create_normal_policy(
@@ -84,6 +87,7 @@
                 min_logstd=-4.0,
                 max_logstd=15.0,
                 device=self._device,
+                enable_ddp=self._enable_ddp,
             )
         else:
             raise ValueError(f"invalid policy_type: {self._config.policy_type}")
@@ -141,8 +145,10 @@ class DiscreteBCConfig(LearnableConfig):
     encoder_factory: EncoderFactory = make_encoder_field()
     beta: float = 0.5
 
-    def create(self, device: DeviceArg = False) -> "DiscreteBC":
-        return DiscreteBC(self, device)
+    def create(
+        self, device: DeviceArg = False, enable_ddp: bool = False
+    ) -> "DiscreteBC":
+        return DiscreteBC(self, device, enable_ddp)
 
     @staticmethod
     def get_type() -> str:
@@ -158,6 +164,7 @@
             action_size,
             self._config.encoder_factory,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
 
         optim = self._config.optim_factory.create(
21 changes: 17 additions & 4 deletions d3rlpy/algos/qlearning/bcq.py
@@ -160,8 +160,10 @@ class BCQConfig(LearnableConfig):
     rl_start_step: int = 0
     beta: float = 0.5
 
-    def create(self, device: DeviceArg = False) -> "BCQ":
-        return BCQ(self, device)
+    def create(
+        self, device: DeviceArg = False, enable_ddp: bool = False
+    ) -> "BCQ":
+        return BCQ(self, device, enable_ddp)
 
     @staticmethod
     def get_type() -> str:
@@ -178,13 +180,15 @@ def inner_create_impl(
             self._config.action_flexibility,
             self._config.actor_encoder_factory,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
         targ_policy = create_deterministic_residual_policy(
             observation_shape,
             action_size,
             self._config.action_flexibility,
             self._config.actor_encoder_factory,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
         q_funcs, q_func_forwarder = create_continuous_q_function(
             observation_shape,
@@ -193,6 +197,7 @@
             self._config.q_func_factory,
             n_ensembles=self._config.n_critics,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
         targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
             observation_shape,
@@ -201,6 +206,7 @@
             self._config.q_func_factory,
             n_ensembles=self._config.n_critics,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
         vae_encoder = create_vae_encoder(
             observation_shape=observation_shape,
@@ -210,13 +216,15 @@
             max_logstd=15.0,
             encoder_factory=self._config.imitator_encoder_factory,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
         vae_decoder = create_vae_decoder(
             observation_shape=observation_shape,
             action_size=action_size,
             latent_size=2 * action_size,
             encoder_factory=self._config.imitator_encoder_factory,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
 
         actor_optim = self._config.actor_optim_factory.create(
@@ -337,8 +345,10 @@ class DiscreteBCQConfig(LearnableConfig):
     target_update_interval: int = 8000
     share_encoder: bool = True
 
-    def create(self, device: DeviceArg = False) -> "DiscreteBCQ":
-        return DiscreteBCQ(self, device)
+    def create(
+        self, device: DeviceArg = False, enable_ddp: bool = False
+    ) -> "DiscreteBCQ":
+        return DiscreteBCQ(self, device, enable_ddp)
 
     @staticmethod
     def get_type() -> str:
@@ -356,6 +366,7 @@
             self._config.q_func_factory,
             n_ensembles=self._config.n_critics,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
         targ_q_funcs, targ_q_func_forwarder = create_discrete_q_function(
             observation_shape,
@@ -364,6 +375,7 @@
             self._config.q_func_factory,
             n_ensembles=self._config.n_critics,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
 
         # share convolutional layers if observation is pixel
@@ -384,6 +396,7 @@
             action_size,
             self._config.encoder_factory,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
 
         q_func_params = list(q_funcs.named_modules())
17 changes: 14 additions & 3 deletions d3rlpy/algos/qlearning/bear.py
@@ -146,8 +146,10 @@ class BEARConfig(LearnableConfig):
     vae_kl_weight: float = 0.5
     warmup_steps: int = 40000
 
-    def create(self, device: DeviceArg = False) -> "BEAR":
-        return BEAR(self, device)
+    def create(
+        self, device: DeviceArg = False, enable_ddp: bool = False
+    ) -> "BEAR":
+        return BEAR(self, device, enable_ddp)
 
     @staticmethod
     def get_type() -> str:
@@ -163,6 +165,7 @@ def inner_create_impl(
             action_size,
             self._config.actor_encoder_factory,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
         q_funcs, q_func_forwarder = create_continuous_q_function(
             observation_shape,
@@ -171,6 +174,7 @@
             self._config.q_func_factory,
             n_ensembles=self._config.n_critics,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
         targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
             observation_shape,
@@ -179,6 +183,7 @@
             self._config.q_func_factory,
             n_ensembles=self._config.n_critics,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
         vae_encoder = create_vae_encoder(
             observation_shape=observation_shape,
@@ -188,21 +193,27 @@
             max_logstd=15.0,
             encoder_factory=self._config.imitator_encoder_factory,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
         vae_decoder = create_vae_decoder(
             observation_shape=observation_shape,
             action_size=action_size,
             latent_size=2 * action_size,
             encoder_factory=self._config.imitator_encoder_factory,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
         log_temp = create_parameter(
             (1, 1),
             math.log(self._config.initial_temperature),
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
         log_alpha = create_parameter(
-            (1, 1), math.log(self._config.initial_alpha), device=self._device
+            (1, 1),
+            math.log(self._config.initial_alpha),
+            device=self._device,
+            enable_ddp=self._enable_ddp,
         )
 
         actor_optim = self._config.actor_optim_factory.create(
15 changes: 12 additions & 3 deletions d3rlpy/algos/qlearning/cal_ql.py
@@ -71,8 +71,10 @@ class CalQLConfig(CQLConfig):
         max_q_backup (bool): Flag to sample max Q-values for target.
     """
 
-    def create(self, device: DeviceArg = False) -> "CalQL":
-        return CalQL(self, device)
+    def create(
+        self, device: DeviceArg = False, enable_ddp: bool = False
+    ) -> "CalQL":
+        return CalQL(self, device, enable_ddp)
 
     @staticmethod
     def get_type() -> str:
@@ -92,6 +94,7 @@ def inner_create_impl(
             action_size,
             self._config.actor_encoder_factory,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
         q_funcs, q_func_fowarder = create_continuous_q_function(
             observation_shape,
@@ -100,6 +103,7 @@
             self._config.q_func_factory,
             n_ensembles=self._config.n_critics,
             device=self._device,
+            enable_ddp=self._enable_ddp,
        )
         targ_q_funcs, targ_q_func_forwarder = create_continuous_q_function(
             observation_shape,
@@ -108,14 +112,19 @@
             self._config.q_func_factory,
             n_ensembles=self._config.n_critics,
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
         log_temp = create_parameter(
             (1, 1),
             math.log(self._config.initial_temperature),
             device=self._device,
+            enable_ddp=self._enable_ddp,
         )
         log_alpha = create_parameter(
-            (1, 1), math.log(self._config.initial_alpha), device=self._device
+            (1, 1),
+            math.log(self._config.initial_alpha),
+            device=self._device,
+            enable_ddp=self._enable_ddp,
        )
 
         actor_optim = self._config.actor_optim_factory.create(
(Diff truncated: the remaining changed files are not shown here.)
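
For completeness, a hedged end-to-end sketch of how training with the redesigned API could be launched under torchrun; the process-group setup, LOCAL_RANK handling, and example dataset are standard PyTorch/d3rlpy patterns assumed here rather than code from this commit:

# Launch with, for example:
#   torchrun --nproc_per_node=2 train_ddp.py
import os

import torch.distributed as dist

import d3rlpy


def main() -> None:
    # torchrun sets RANK/WORLD_SIZE/LOCAL_RANK environment variables.
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])

    dataset, _ = d3rlpy.datasets.get_pendulum()  # illustrative example dataset

    # enable_ddp is now passed at construction time instead of fit().
    awac = d3rlpy.algos.AWACConfig().create(
        device=f"cuda:{local_rank}", enable_ddp=True
    )
    awac.fit(dataset, n_steps=10000, n_steps_per_epoch=1000)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()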
