Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Classic Control Gymnax #90

Open
wants to merge 39 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
6802e04
Add gymnax folder
Apr 20, 2023
10b1e85
Update requirements
Apr 20, 2023
ae2caf4
add CARLJaxPendulumEnv
Arman717 May 22, 2023
694241e
add init
Arman717 May 22, 2023
f1ec7b3
add CARLJaxAcrobotEnv
Arman717 May 22, 2023
3ed010d
remove comment
Arman717 May 22, 2023
9fd4ea6
add CARLJaxCartPoleEnv
Arman717 May 22, 2023
5b6eee5
add CARLJaxMountainCarEnv
Arman717 May 22, 2023
4fa6f98
add CARLJaxMountainCarContinuousEnv
Arman717 May 22, 2023
68048f5
fix setup
Arman717 May 22, 2023
666a323
Rename file
May 23, 2023
ffc2bb9
Add imports
May 23, 2023
9a3d416
add modified version for GymnaxToGymWrapper
Arman717 May 27, 2023
5874515
make precommit
Arman717 May 27, 2023
92458e0
make precommit
Arman717 May 27, 2023
9de0e29
Fix gymnax
Jun 7, 2023
0bba185
Fix pre-commit
Jun 7, 2023
289d324
Merge branch 'development' into gymnax
benjamc Dec 11, 2023
97ae965
Move files
benjamc Dec 11, 2023
c0c9620
Rename envs and imports
benjamc Dec 11, 2023
2406616
Rename files
benjamc Dec 11, 2023
63cba37
Fix env comapts
benjamc Dec 11, 2023
689a2c8
Reformat
benjamc Dec 11, 2023
7ec0dfc
Smarter update context
benjamc Dec 11, 2023
adde740
Update cartpole
benjamc Dec 11, 2023
52bd501
Update mountaincar
benjamc Dec 11, 2023
9874e52
Fix context space
benjamc Dec 11, 2023
dd0c430
Update pendulum
benjamc Dec 11, 2023
86e1737
Format
benjamc Dec 11, 2023
72b82e1
Update changelog.md
benjamc Dec 11, 2023
1a0f985
Fix context space
benjamc Dec 11, 2023
a35e087
Fix context space
benjamc Dec 11, 2023
39782de
Fix context space
benjamc Dec 11, 2023
10aa406
Fix context space
benjamc Dec 11, 2023
db292e5
Merge branch 'development' into gymnax
benjamc Dec 11, 2023
7df9d72
Remove warnings
benjamc Dec 11, 2023
134f161
Fix pre-commit
benjamc Dec 11, 2023
778e383
Fix context space
benjamc Dec 11, 2023
7c2e21e
Fix pre-commit
benjamc Dec 11, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions carl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,13 @@
warnings.warn(
"Module 'dm_control' not found. If you want to use these environments, please follow the installation guide."
)


gymnax_spec = iutil.find_spec("gymnax")
found = gymnax_spec is not None
if found:
from carl.envs.gymnax import *
else:
warnings.warn(
"Module 'gymnax' not found. If you want to use these environments, please follow the installation guide."
)
36 changes: 36 additions & 0 deletions carl/envs/gymnax/__init__.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why call them Jax and not Gymnax? I bet there are other attempts at Jaxing them, Gymnax would be more explicit imo

Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# flake8: noqa: F401
from carl.envs.gymnax.carl_jax_acrobot import CONTEXT_BOUNDS as CARLJaxAcrobotEnv_bounds
from carl.envs.gymnax.carl_jax_acrobot import (
DEFAULT_CONTEXT as CARLJaxAcrobotEnv_defaults,
)
from carl.envs.gymnax.carl_jax_acrobot import CARLJaxAcrobotEnv
from carl.envs.gymnax.carl_jax_cartpole import (
CONTEXT_BOUNDS as CARLJaxCartPoleEnv_bounds,
)
from carl.envs.gymnax.carl_jax_cartpole import (
DEFAULT_CONTEXT as CARLJaxCartPoleEnv_defaults,
)
from carl.envs.gymnax.carl_jax_cartpole import CARLJaxCartPoleEnv
from carl.envs.gymnax.carl_jax_mountaincar import (
CONTEXT_BOUNDS as CARLJaxMountainCarContinuousEnv_bounds,
)
from carl.envs.gymnax.carl_jax_mountaincar import (
CONTEXT_BOUNDS as CARLJaxMountainCarEnv_bounds,
)
from carl.envs.gymnax.carl_jax_mountaincar import (
DEFAULT_CONTEXT as CARLJaxMountainCarContinuousEnv_defaults,
)
from carl.envs.gymnax.carl_jax_mountaincar import (
DEFAULT_CONTEXT as CARLJaxMountainCarEnv_defaults,
)
from carl.envs.gymnax.carl_jax_mountaincar import (
CARLJaxMountainCarContinuousEnv,
CARLJaxMountainCarEnv,
)
from carl.envs.gymnax.carl_jax_pendulum import (
CONTEXT_BOUNDS as CARLJaxPendulumEnv_bounds,
)
from carl.envs.gymnax.carl_jax_pendulum import (
DEFAULT_CONTEXT as CARLJaxPendulumEnv_defaults,
)
from carl.envs.gymnax.carl_jax_pendulum import CARLJaxPendulumEnv
85 changes: 85 additions & 0 deletions carl/envs/gymnax/carl_gymnax_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from __future__ import annotations

from typing import Any, Dict, List, Optional, Union

import gymnasium

from CARL.carl.envs.gymnax.utils import make_gymnax_env
from carl.context.selection import AbstractSelector
from carl.envs.carl_env import CARLEnv
from carl.utils.trial_logger import TrialLogger
from carl.utils.types import Context, Contexts


class CARLGymnaxEnv(CARLEnv):
env_name: str
DEFAULT_CONTEXT: Context
max_episode_steps: int

def __init__(
self,
env: gymnasium.Env | None = None,
contexts: Contexts = {},
hide_context: bool = True,
add_gaussian_noise_to_context: bool = False,
gaussian_noise_std_percentage: float = 0.01,
logger: Optional[TrialLogger] = None,
scale_context_features: str = "no",
default_context: Optional[Context] = None,
state_context_features: Optional[List[str]] = None,
context_mask: Optional[List[str]] = None,
dict_observation_space: bool = False,
context_selector: Optional[
Union[AbstractSelector, type[AbstractSelector]]
] = None,
context_selector_kwargs: Optional[Dict] = None,
):
"""
Max torque is not a context feature because it changes the action space.

Parameters
----------
env
contexts
instance_mode
hide_context
add_gaussian_noise_to_context
gaussian_noise_std_percentage
"""
if env is None:
env = make_gymnax_env(env_name=self.env_name)

if not contexts:
contexts = {0: self.DEFAULT_CONTEXT}

if not default_context:
default_context = self.DEFAULT_CONTEXT

super().__init__(
env=env,
contexts=contexts,
hide_context=hide_context,
add_gaussian_noise_to_context=add_gaussian_noise_to_context,
gaussian_noise_std_percentage=gaussian_noise_std_percentage,
logger=logger,
scale_context_features=scale_context_features,
default_context=default_context,
max_episode_length=self.max_episode_steps,
state_context_features=state_context_features,
dict_observation_space=dict_observation_space,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs,
context_mask=context_mask,
)
self.whitelist_gaussian_noise = list(
self.DEFAULT_CONTEXT.keys()
) # allow to augment all values

def _update_context(self) -> None:
raise NotImplementedError

def __getattr__(self, name: str) -> Any:
if name in ["sys", "__getstate__"]:
return getattr(self.env._environment, name)
else:
return getattr(self, name)
88 changes: 88 additions & 0 deletions carl/envs/gymnax/carl_jax_acrobot.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Env params is frozen for a reason, but I guess this runs, right? Pretty sure we can't jit it, but we have this issue with brax as well... at some point we might want to brainstorm if we can somehown avoid replacing env parts to change context. For now it's alright, though.

Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from __future__ import annotations

import gymnax
import jax.numpy as jnp
import numpy as np

from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv
from carl.utils.types import Context

DEFAULT_CONTEXT = {
"link_length_1": 1,
"link_length_2": 1,
"link_mass_1": 1,
"link_mass_2": 1,
"link_com_pos_1": 0.5,
"link_com_pos_2": 0.5,
"link_moi": 1,
"max_vel_1": 4 * jnp.pi,
"max_vel_2": 9 * jnp.pi,
"torque_noise_max": 0.0,
"max_steps_in_episode": 500,
}

CONTEXT_BOUNDS = {
"link_length_1": (
0.1,
10,
float,
), # Links can be shrunken and grown by a factor of 10
"link_length_2": (0.1, 10, float),
"link_mass_1": (
0.1,
10,
float,
), # Link mass can be shrunken and grown by a factor of 10
"link_mass_2": (0.1, 10, float),
"link_com_pos_1": (
0,
1,
float,
), # Center of mass can move from one end to the other
"link_com_pos_2": (0, 1, float),
"link_moi": (
0.1,
10,
float,
), # Moments on inertia can be shrunken and grown by a factor of 10
"max_vel_1": (
0.4 * np.pi,
40 * np.pi,
float,
), # Velocity can vary by a factor of 10 in either direction
"max_vel_2": (0.9 * np.pi, 90 * np.pi, float),
"torque_noise_max": (
-1.0,
1.0,
float,
), # torque is either {-1., 0., 1}. Applying noise of 1. would be quite extreme
"max_steps_in_episode": (1, jnp.inf, int),
}


class CARLJaxAcrobotEnv(CARLGymnaxEnv):
env_name: str = "Acrobot-v1"
max_episode_steps: int = int(DEFAULT_CONTEXT["max_steps_in_episode"])
DEFAULT_CONTEXT: Context = DEFAULT_CONTEXT

def _update_context(self) -> None:
content = self.env.env.env_params.__dict__
content.update(self.context)
# We cannot directly set attributes of env_params because it is a frozen dataclass
self.env.env.env_params = gymnax.environments.classic_control.acrobot.EnvParams(
**content
)

high = jnp.array(
[
1.0,
1.0,
1.0,
1.0,
self.env.env.env_params.max_vel_1,
self.env.env.env_params.max_vel_2,
],
dtype=jnp.float32,
)
low = -high
self.build_observation_space(low, high, CONTEXT_BOUNDS)
59 changes: 59 additions & 0 deletions carl/envs/gymnax/carl_jax_cartpole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

import gymnax
import jax.numpy as jnp

from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv
from carl.utils.types import Context

DEFAULT_CONTEXT = {
"gravity": 9.8,
"masscart": 1.0,
"masspole": 0.1,
"length": 0.5,
"force_mag": 10.0,
"tau": 0.02,
"polemass_length": None,
"total_mass": None,
"max_steps_in_episode": 500,
}

CONTEXT_BOUNDS = {
"gravity": (5.0, 15.0, float),
"masscart": (0.5, 2.0, float),
"masspole": (0.05, 0.2, float),
"length": (0.25, 1.0, float),
"force_mag": (5.0, 15.0, float),
"tau": (0.01, 0.05, float),
"polemass_length": (0, jnp.inf, float),
"total_mass": (0, jnp.inf, float),
"max_steps_in_episode": (1, jnp.inf, int),
}


class CARLJaxCartPoleEnv(CARLGymnaxEnv):
env_name: str = "CartPole-v1"
max_episode_steps: int = int(DEFAULT_CONTEXT["max_steps_in_episode"]) # type: ignore[arg-type]
DEFAULT_CONTEXT: Context = DEFAULT_CONTEXT

def _update_context(self) -> None:
self.context["polemass_length"] = (
self.context["masspole"] * self.context["length"]
)
self.context["total_mass"] = self.context["masscart"] + self.context["masspole"]

self.env.env.env_params = (
gymnax.environments.classic_control.cartpole.EnvParams(**self.context)
)

high = jnp.array(
[
self.env.env.env_params.x_threshold * 2,
jnp.finfo(jnp.float32).max,
self.env.env.env_params.theta_threshold_radians * 2,
jnp.finfo(jnp.float32).max,
],
dtype=jnp.float32,
)
low = -high
self.build_observation_space(low, high, CONTEXT_BOUNDS)
57 changes: 57 additions & 0 deletions carl/envs/gymnax/carl_jax_mountaincar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

import gymnax
import jax.numpy as jnp

from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv
from carl.utils.types import Context

DEFAULT_CONTEXT = {
"min_position": -1.2,
"max_position": 0.6,
"max_speed": 0.07,
"goal_position": 0.5,
"goal_velocity": 0,
"force": 0.001,
"gravity": 0.0025,
"max_steps_in_episode": 200,
}

CONTEXT_BOUNDS = {
"min_position": (-jnp.inf, jnp.inf, float),
"max_position": (-jnp.inf, jnp.inf, float),
"max_speed": (0, jnp.inf, float),
"goal_position": (-jnp.inf, jnp.inf, float),
"goal_velocity": (-jnp.inf, jnp.inf, float),
"force": (-jnp.inf, jnp.inf, float),
"gravity": (0, jnp.inf, float),
"max_steps_in_episode": (1, jnp.inf, int),
}


class CARLJaxMountainCarEnv(CARLGymnaxEnv):
env_name: str = "MountainCar-v0"
max_episode_steps: int = int(DEFAULT_CONTEXT["max_steps_in_episode"])
DEFAULT_CONTEXT: Context = DEFAULT_CONTEXT

def _update_context(self) -> None:
self.env.env.env_params = (
gymnax.environments.classic_control.mountain_car.EnvParams(**self.context)
)

self.low = jnp.array(
[self.env.env.env_params.min_position, -self.env.env.env_params.max_speed],
dtype=jnp.float32,
).squeeze()
self.high = jnp.array(
[self.env.env.env_params.max_position, self.env.env.env_params.max_speed],
dtype=jnp.float32,
).squeeze()

self.build_observation_space(self.low, self.high, CONTEXT_BOUNDS)


class CARLJaxMountainCarContinuousEnv(CARLJaxMountainCarEnv):
env_name: str = "MountainCarContinuous-v0"
max_episode_steps: int = 999
DEFAULT_CONTEXT: Context = DEFAULT_CONTEXT
41 changes: 41 additions & 0 deletions carl/envs/gymnax/carl_jax_pendulum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

import jax.numpy as jnp
from gymnax.environments.classic_control.pendulum import EnvParams

from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv
from carl.utils.types import Context

DEFAULT_CONTEXT = {
"max_speed": 8.0,
"max_torque": 2.0,
"dt": 0.05,
"g": 10.0,
"m": 1.0,
"l": 1.0,
"max_steps_in_episode": 200,
}

CONTEXT_BOUNDS = {
"max_speed": (-jnp.inf, jnp.inf, float),
"max_torque": (-jnp.inf, jnp.inf, float),
"dt": (0, jnp.inf, float),
"g": (0, jnp.inf, float),
"m": (1e-6, jnp.inf, float),
"l": (1e-6, jnp.inf, float),
"max_steps_in_episode": (1, jnp.inf, int),
}


class CARLJaxPendulumEnv(CARLGymnaxEnv):
env_name: str = "Pendulum-v1"
max_episode_steps: int = int(DEFAULT_CONTEXT["max_steps_in_episode"])
DEFAULT_CONTEXT: Context = DEFAULT_CONTEXT

def _update_context(self) -> None:
self.env.env.env_params = EnvParams(**self.context)

high = jnp.array(
[1.0, 1.0, self.env.env.env_params.max_speed], dtype=jnp.float32
)
self.build_observation_space(-high, high, CONTEXT_BOUNDS)
Loading