diff --git a/skrl/agents/jax/base.py b/skrl/agents/jax/base.py index 561d2701..fbcb9fa5 100644 --- a/skrl/agents/jax/base.py +++ b/skrl/agents/jax/base.py @@ -53,7 +53,10 @@ def __init__(self, if device is None: self.device = jax.devices()[0] else: - self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0] + self.device = device + if type(device) == str: + device_type, device_index = f"{device}:0".split(':')[:2] + self.device = jax.devices(device_type)[int(device_index)] if type(memory) is list: self.memory = memory[0] diff --git a/skrl/envs/wrappers/jax/base.py b/skrl/envs/wrappers/jax/base.py index 0be8b742..b1136a75 100644 --- a/skrl/envs/wrappers/jax/base.py +++ b/skrl/envs/wrappers/jax/base.py @@ -20,12 +20,19 @@ def __init__(self, env: Any) -> None: self._env = env # device (faster than @property) - self.device = jax.devices()[0] + self.device = None if hasattr(self._env, "device"): - try: - self.device = jax.devices(self._env.device.split(':')[0] if type(self._env.device) == str else self._env.device.type)[0] - except RuntimeError: - pass + if type(self._env.device) == str: + device_type, device_index = f"{self._env.device}:0".split(':')[:2] + try: + self.device = jax.devices(device_type)[int(device_index)] + except RuntimeError: + self.device = None + else: + self.device = self._env.device + if self.device is None: + self.device = jax.devices()[0] + # spaces try: self._action_space = self._env.single_action_space @@ -135,12 +142,18 @@ def __init__(self, env: Any) -> None: self._env = env # device (faster than @property) - self.device = jax.devices()[0] + self.device = None if hasattr(self._env, "device"): - try: - self.device = jax.devices(self._env.device.split(':')[0] if type(self._env.device) == str else self._env.device.type)[0] - except RuntimeError: - pass + if type(self._env.device) == str: + device_type, device_index = f"{self._env.device}:0".split(':')[:2] + try: + self.device = jax.devices(device_type)[int(device_index)] + except RuntimeError: + self.device = None + else: + self.device = self._env.device + if self.device is None: + self.device = jax.devices()[0] self.possible_agents = [] diff --git a/skrl/memories/jax/base.py b/skrl/memories/jax/base.py index 41cea1b1..90783836 100644 --- a/skrl/memories/jax/base.py +++ b/skrl/memories/jax/base.py @@ -70,7 +70,10 @@ def __init__(self, if device is None: self.device = jax.devices()[0] else: - self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0] + self.device = device + if type(device) == str: + device_type, device_index = f"{device}:0".split(':')[:2] + self.device = jax.devices(device_type)[int(device_index)] # internal variables self.filled = False diff --git a/skrl/models/jax/base.py b/skrl/models/jax/base.py index 30db0e83..7d121e25 100644 --- a/skrl/models/jax/base.py +++ b/skrl/models/jax/base.py @@ -79,7 +79,10 @@ def __call__(self, inputs, role): if device is None: self.device = jax.devices()[0] else: - self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0] + self.device = device + if type(device) == str: + device_type, device_index = f"{device}:0".split(':')[:2] + self.device = jax.devices(device_type)[int(device_index)] self.observation_space = observation_space self.action_space = action_space diff --git a/skrl/multi_agents/jax/base.py b/skrl/multi_agents/jax/base.py index 77619b81..e177610d 100644 --- a/skrl/multi_agents/jax/base.py +++ b/skrl/multi_agents/jax/base.py @@ -60,7 +60,10 @@ def __init__(self, if device is None: self.device = jax.devices()[0] else: - self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0] + self.device = device + if type(device) == str: + device_type, device_index = f"{device}:0".split(':')[:2] + self.device = jax.devices(device_type)[int(device_index)] # convert the models to their respective device for _models in self.models.values(): diff --git a/skrl/resources/noises/jax/base.py b/skrl/resources/noises/jax/base.py index 2df54e20..e51e6fd3 100644 --- a/skrl/resources/noises/jax/base.py +++ b/skrl/resources/noises/jax/base.py @@ -31,7 +31,10 @@ def sample(self, size): if device is None: self.device = jax.devices()[0] else: - self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0] + self.device = device + if type(device) == str: + device_type, device_index = f"{device}:0".split(':')[:2] + self.device = jax.devices(device_type)[int(device_index)] def sample_like(self, tensor: Union[np.ndarray, jax.Array]) -> Union[np.ndarray, jax.Array]: """Sample a noise with the same size (shape) as the input tensor diff --git a/skrl/resources/preprocessors/jax/running_standard_scaler.py b/skrl/resources/preprocessors/jax/running_standard_scaler.py index 1d319605..ace65f06 100644 --- a/skrl/resources/preprocessors/jax/running_standard_scaler.py +++ b/skrl/resources/preprocessors/jax/running_standard_scaler.py @@ -95,7 +95,10 @@ def __init__(self, if device is None: self.device = jax.devices()[0] else: - self.device = device if isinstance(device, jax.Device) else jax.devices(device)[0] + self.device = device + if type(device) == str: + device_type, device_index = f"{device}:0".split(':')[:2] + self.device = jax.devices(device_type)[int(device_index)] size = self._get_space_size(size)