Apply black and codespell pre-commit hooks (#222)
* Configure pre-commit hooks

* Configure black in pre-commit and pyproject.toml

* Configure codespell in pre-commit

* Ignore formatting config sections

* Apply black format to skrl folder

* Apply black format to tests folder

* Apply codespell

* Update CHANGELOG
Toni-SM authored Nov 5, 2024
1 parent 88ac11f commit 4db2956
Showing 174 changed files with 7,058 additions and 4,100 deletions.
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
@@ -14,18 +14,22 @@ repos:
- id: end-of-file-fixer
- id: name-tests-test
args: ["--pytest-test-first"]
exclude: ^(tests/strategies.py|tests/utils.py)
- id: no-commit-to-branch
- id: trailing-whitespace
- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
hooks:
- id: codespell
exclude: ^(docs/source/_static|docs/_build|pyproject.toml)
additional_dependencies:
- tomli
- repo: https://github.com/python/black
rev: 24.8.0
hooks:
- id: black
args: ["--line-length=120"]
exclude: ^(docs/)
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
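For a sense of what the black hook enforces, here is a minimal sketch (not part of the commit) that calls black's Python API with the same 120-character limit the hook passes via `args`; the badly formatted sample string is invented for illustration:

    import black

    # deliberately mis-formatted input (invented for this sketch)
    source = 'def add(  a,b ):\n    return {  "sum":a+b  }\n'

    # mirror the hook's --line-length=120 argument
    formatted = black.format_str(source, mode=black.Mode(line_length=120))
    print(formatted)
    # def add(a, b):
    #     return {"sum": a + b}
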
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Make flattened tensor storage in memory the default option (revert changes introduced in version 1.3.0)
- Drop support for PyTorch versions prior to 1.10 (the previous supported version was 1.9).

### Changed (breaking changes: style)
- Format code using Black code formatter (it's ugly, yes, but it does its job)

### Fixed
- Moved the batch sampling inside gradient step loop for DQN, DDQN, DDPG (RNN), TD3 (RNN), SAC and SAC (RNN)

2 changes: 1 addition & 1 deletion docs/source/api/agents/sarsa.rst
@@ -3,7 +3,7 @@ State Action Reward State Action (SARSA)

SARSA is a **model-free** **on-policy** algorithm that uses a **tabular** Q-function to handle **discrete** observations and action spaces

-Paper: `On-Line Q-Learning Using Connectionist Systems <https://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.17.2539>`_
+Paper: `On-Line Q-Learning Using Connectionist Systems <https://scholar.google.com/scholar?q=On-line+Q-learning+using+connectionist+system>`_

.. raw:: html

@@ -141,7 +141,7 @@ def _get_observation_reward_done(self):
return self.obs_buf, reward, done

def reset(self):
print("Reseting...")
print("Resetting...")

# end current motion
if self.motion is not None:
@@ -84,7 +84,7 @@ def _get_observation_reward_done(self):
return self.obs_buf, reward, done

def reset(self):
print("Reseting...")
print("Resetting...")

# go to 1) safe position, 2) random position
self.robot.command_joint_position(self.robot_default_dof_pos)
@@ -113,8 +113,8 @@ def _callback_joint_states(self, msg):
self.robot_state["joint_velocity"] = np.array(msg.velocity)

def _callback_end_effector_pose(self, msg):
-        positon = msg.position
-        self.robot_state["cartesian_position"] = np.array([positon.x, positon.y, positon.z])
+        position = msg.position
+        self.robot_state["cartesian_position"] = np.array([position.x, position.y, position.z])

def _get_observation_reward_done(self):
# observation
@@ -146,7 +146,7 @@ def _get_observation_reward_done(self):
return self.obs_buf, reward, done

def reset(self):
print("Reseting...")
print("Resetting...")

# go to 1) safe position, 2) random position
msg = sensor_msgs.msg.JointState()
@@ -90,8 +90,8 @@ def _callback_joint_states(self, msg):
self.robot_state["joint_velocity"] = np.array(msg.velocity)

def _callback_end_effector_pose(self, msg):
-        positon = msg.position
-        self.robot_state["cartesian_position"] = np.array([positon.x, positon.y, positon.z])
+        position = msg.position
+        self.robot_state["cartesian_position"] = np.array([position.x, position.y, position.z])

def _get_observation_reward_done(self):
# observation
@@ -123,7 +123,7 @@ def _get_observation_reward_done(self):
return self.obs_buf, reward, done

def reset(self):
print("Reseting...")
print("Resetting...")

# go to 1) safe position, 2) random position
msg = sensor_msgs.msg.JointState()
2 changes: 1 addition & 1 deletion docs/source/examples/utils/tensorboard_file_iterator.py
@@ -19,7 +19,7 @@
mean = np.mean(rewards[:,:,1], axis=0)
std = np.std(rewards[:,:,1], axis=0)

-# creae two subplots (one for each reward and one for the mean)
+# create two subplots (one for each reward and one for the mean)
fig, ax = plt.subplots(1, 2, figsize=(15, 5))

# plot the rewards for each experiment
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -60,6 +60,11 @@ tests = [

[tool.black]
line-length = 120
extend-exclude = """
(
^/docs
)
"""


[tool.codespell]
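The `tomli` entry under the codespell hook's `additional_dependencies` is presumably there so codespell can read its `[tool.codespell]` table from `pyproject.toml` on interpreters older than Python 3.11. A minimal sketch of that parsing pattern, applied to the `[tool.black]` table added above:

    try:
        import tomllib  # standard library on Python >= 3.11
    except ModuleNotFoundError:
        import tomli as tomllib  # third-party backport with the same API

    with open("pyproject.toml", "rb") as f:
        data = tomllib.load(f)

    print(data["tool"]["black"]["line-length"])  # 120
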
57 changes: 35 additions & 22 deletions skrl/__init__.py
@@ -13,6 +13,7 @@
# read library version from metadata
try:
import importlib.metadata

__version__ = importlib.metadata.version("skrl")
except ImportError:
__version__ = "unknown"
@@ -21,15 +22,18 @@
# logger with format
class _Formatter(logging.Formatter):
_format = "[%(name)s:%(levelname)s] %(message)s"
-    _formats = {logging.DEBUG: f"\x1b[38;20m{_format}\x1b[0m",
-                logging.INFO: f"\x1b[38;20m{_format}\x1b[0m",
-                logging.WARNING: f"\x1b[33;20m{_format}\x1b[0m",
-                logging.ERROR: f"\x1b[31;20m{_format}\x1b[0m",
-                logging.CRITICAL: f"\x1b[31;1m{_format}\x1b[0m"}
+    _formats = {
+        logging.DEBUG: f"\x1b[38;20m{_format}\x1b[0m",
+        logging.INFO: f"\x1b[38;20m{_format}\x1b[0m",
+        logging.WARNING: f"\x1b[33;20m{_format}\x1b[0m",
+        logging.ERROR: f"\x1b[31;20m{_format}\x1b[0m",
+        logging.CRITICAL: f"\x1b[31;1m{_format}\x1b[0m",
+    }

def format(self, record):
return logging.Formatter(self._formats.get(record.levelno)).format(record)


_handler = logging.StreamHandler()
_handler.setLevel(logging.DEBUG)
_handler.setFormatter(_Formatter())
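The formatter above wraps each record in an ANSI escape sequence chosen by level: grey (38;20) for debug and info, yellow (33;20) for warnings, red (31;20) for errors and bold red (31;1) for critical messages. A brief usage sketch with the module-level logger this file sets up:

    from skrl import logger

    logger.info("environment wrapped")           # printed in grey
    logger.warning("observation space clipped")  # printed in yellow
    logger.error("checkpoint not found")         # printed in red
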
@@ -42,13 +46,11 @@ def format(self, record):
# machine learning framework configuration
class _Config(object):
def __init__(self) -> None:
"""Machine learning framework specific configuration
"""
"""Machine learning framework specific configuration"""

class PyTorch(object):
def __init__(self) -> None:
"""PyTorch configuration
"""
"""PyTorch configuration"""
self._device = None
# torch.distributed config
self._local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -59,7 +61,10 @@ def __init__(self) -> None:
# set up distributed runs
if self._is_distributed:
import torch
logger.info(f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})")

logger.info(
f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})"
)
torch.distributed.init_process_group("nccl", rank=self._rank, world_size=self._world_size)
torch.cuda.set_device(self._local_rank)

@@ -72,6 +77,7 @@ def device(self) -> "torch.device":
"""
try:
import torch

if self._device is None:
return torch.device(f"cuda:{self._local_rank}" if torch.cuda.is_available() else "cpu")
return torch.device(self._device)
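Downstream code reaches this property through the library's singleton `config` object (instantiated at the bottom of this file). A short usage sketch, assuming PyTorch is installed:

    import torch

    from skrl import config

    device = config.torch.device  # cuda:<local rank> when available, otherwise cpu
    buffer = torch.zeros((256, 4), device=device)
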
@@ -116,8 +122,7 @@ def is_distributed(self) -> bool:

class JAX(object):
def __init__(self) -> None:
"""JAX configuration
"""
"""JAX configuration"""
self._backend = "numpy"
self._key = np.array([0, 0], dtype=np.uint32)
# distributed config (based on torch.distributed, since JAX doesn't implement it)
@@ -126,19 +131,26 @@ def __init__(self) -> None:
self._local_rank = int(os.getenv("JAX_LOCAL_RANK", "0"))
self._rank = int(os.getenv("JAX_RANK", "0"))
self._world_size = int(os.getenv("JAX_WORLD_SIZE", "1"))
-                self._coordinator_address = os.getenv("JAX_COORDINATOR_ADDR", "127.0.0.1") + ":" + os.getenv("JAX_COORDINATOR_PORT", "1234")
+                self._coordinator_address = (
+                    os.getenv("JAX_COORDINATOR_ADDR", "127.0.0.1") + ":" + os.getenv("JAX_COORDINATOR_PORT", "1234")
+                )
self._is_distributed = self._world_size > 1
# device
self._device = f"cuda:{self._local_rank}"

# set up distributed runs
if self._is_distributed:
import jax
logger.info(f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})")
jax.distributed.initialize(coordinator_address=self._coordinator_address,
num_processes=self._world_size,
process_id=self._rank,
local_device_ids=self._local_rank)

logger.info(
f"Distributed (rank: {self._rank}, local rank: {self._local_rank}, world size: {self._world_size})"
)
jax.distributed.initialize(
coordinator_address=self._coordinator_address,
num_processes=self._world_size,
process_id=self._rank,
local_device_ids=self._local_rank,
)

@staticmethod
def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
@@ -148,7 +160,7 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
This function supports the PyTorch-like ``"type:ordinal"`` string specification (e.g.: ``"cuda:0"``).
-                :param device: Device specification. If the specified device is ``None`` ot it cannot be resolved,
+                :param device: Device specification. If the specified device is ``None`` or it cannot be resolved,
the default available device will be returned instead.
:return: JAX Device.
@@ -158,7 +170,7 @@ def parse_device(device: Union[str, "jax.Device", None]) -> "jax.Device":
if isinstance(device, jax.Device):
return device
elif isinstance(device, str):
device_type, device_index = f"{device}:0".split(':')[:2]
device_type, device_index = f"{device}:0".split(":")[:2]
try:
return jax.devices(device_type)[int(device_index)]
except (RuntimeError, IndexError) as e:
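A sketch of the PyTorch-style device strings the method accepts, assuming JAX is installed; the ordinal defaults to 0 when omitted:

    from skrl import config

    default = config.jax.parse_device(None)  # falls back to the default available device
    cpu = config.jax.parse_device("cpu")     # same as "cpu:0"
    gpu = config.jax.parse_device("cuda:0")  # first CUDA device, if one is visible
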
@@ -196,11 +208,11 @@ def backend(self, value: str) -> None:

@property
def key(self) -> "jax.Array":
"""Pseudo-random number generator (PRNG) key
"""
"""Pseudo-random number generator (PRNG) key"""
if isinstance(self._key, np.ndarray):
try:
import jax

with jax.default_device(self.device):
self._key = jax.random.PRNGKey(self._key[1])
except ImportError:
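Because the seed is stored as a plain NumPy array ([0, 0] by default) and only promoted to a real PRNG key on first access, importing skrl never forces JAX to initialize. A usage sketch:

    import jax

    from skrl import config

    key = config.jax.key                 # lazily built as jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)  # then handled like any other JAX key
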
@@ -257,4 +269,5 @@ def is_distributed(self) -> bool:
self.jax = JAX()
self.torch = PyTorch()


config = _Config()