Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Classic Control Gymnax #90

Open
wants to merge 39 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
6802e04
Add gymnax folder
Apr 20, 2023
10b1e85
Update requirements
Apr 20, 2023
ae2caf4
add CARLJaxPendulumEnv
Arman717 May 22, 2023
694241e
add init
Arman717 May 22, 2023
f1ec7b3
add CARLJaxAcrobotEnv
Arman717 May 22, 2023
3ed010d
remove comment
Arman717 May 22, 2023
9fd4ea6
add CARLJaxCartPoleEnv
Arman717 May 22, 2023
5b6eee5
add CARLJaxMountainCarEnv
Arman717 May 22, 2023
4fa6f98
add CARLJaxMountainCarContinuousEnv
Arman717 May 22, 2023
68048f5
fix setup
Arman717 May 22, 2023
666a323
Rename file
May 23, 2023
ffc2bb9
Add imports
May 23, 2023
9a3d416
add modified version for GymnaxToGymWrapper
Arman717 May 27, 2023
5874515
make precommit
Arman717 May 27, 2023
92458e0
make precommit
Arman717 May 27, 2023
9de0e29
Fix gymnax
Jun 7, 2023
0bba185
Fix pre-commit
Jun 7, 2023
289d324
Merge branch 'development' into gymnax
benjamc Dec 11, 2023
97ae965
Move files
benjamc Dec 11, 2023
c0c9620
Rename envs and imports
benjamc Dec 11, 2023
2406616
Rename files
benjamc Dec 11, 2023
63cba37
Fix env comapts
benjamc Dec 11, 2023
689a2c8
Reformat
benjamc Dec 11, 2023
7ec0dfc
Smarter update context
benjamc Dec 11, 2023
adde740
Update cartpole
benjamc Dec 11, 2023
52bd501
Update mountaincar
benjamc Dec 11, 2023
9874e52
Fix context space
benjamc Dec 11, 2023
dd0c430
Update pendulum
benjamc Dec 11, 2023
86e1737
Format
benjamc Dec 11, 2023
72b82e1
Update changelog.md
benjamc Dec 11, 2023
1a0f985
Fix context space
benjamc Dec 11, 2023
a35e087
Fix context space
benjamc Dec 11, 2023
39782de
Fix context space
benjamc Dec 11, 2023
10aa406
Fix context space
benjamc Dec 11, 2023
db292e5
Merge branch 'development' into gymnax
benjamc Dec 11, 2023
7df9d72
Remove warnings
benjamc Dec 11, 2023
134f161
Fix pre-commit
benjamc Dec 11, 2023
778e383
Fix context space
benjamc Dec 11, 2023
7c2e21e
Fix pre-commit
benjamc Dec 11, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions carl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,10 @@ def check_spec(spec_name: str) -> bool:
from carl.envs.rna import *

__all__ += envs.rna.__all__

gymnax_spec = iutil.find_spec("gymnax")
found = gymnax_spec is not None
if found:
from carl.envs.gymnax import *

__all__ += envs.gymnax.__all__
1 change: 0 additions & 1 deletion carl/envs/carl_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ def get_observation_space(
context_feature_names=obs_context_feature_names,
as_dict=self.obs_context_as_dict,
)

obs_space = spaces.Dict(
{
"obs": self.base_observation_space,
Expand Down
3 changes: 0 additions & 3 deletions carl/envs/gymnasium/classic_control/carl_pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ class CARLPendulum(CARLGymnasiumEnv):
@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
return {
"gravity": UniformFloatContextFeature(
"gravity", lower=-np.inf, upper=np.inf, default_value=8.0
),
"dt": UniformFloatContextFeature(
"dt", lower=0, upper=np.inf, default_value=0.05
),
Expand Down
15 changes: 15 additions & 0 deletions carl/envs/gymnax/__init__.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why call them Jax and not Gymnax? I bet there are other attempts at Jaxing them, Gymnax would be more explicit imo

Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from carl.envs.gymnax.classic_control import (
CARLGymnaxAcrobot,
CARLGymnaxCartPole,
CARLGymnaxMountainCar,
CARLGymnaxMountainCarContinuous,
CARLGymnaxPendulum,
)

__all__ = [
"CARLGymnaxAcrobot",
"CARLGymnaxCartPole",
"CARLGymnaxMountainCar",
"CARLGymnaxMountainCarContinuous",
"CARLGymnaxPendulum",
]
84 changes: 84 additions & 0 deletions carl/envs/gymnax/carl_gymnax_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import annotations

from typing import Any

import importlib

from gymnasium.core import Env

from carl.context.selection import AbstractSelector
from carl.envs.carl_env import CARLEnv
from carl.envs.gymnax.utils import make_gymnax_env
from carl.utils.types import Contexts


class CARLGymnaxEnv(CARLEnv):
env_name: str

def __init__(
self,
env: Env | None = None,
contexts: Contexts | None = None,
obs_context_features: list[str]
| None = None, # list the context features which should be added to the state
obs_context_as_dict: bool = True,
context_selector: AbstractSelector | type[AbstractSelector] | None = None,
context_selector_kwargs: dict = None,
**kwargs,
) -> None:
"""
CARL Gymnax Environment.

Parameters
----------

env : Env | None
Gymnasium environment, the default is None.
If None, instantiate the env with gymnasium's make function and
`self.env_name` which is defined in each child class.
contexts : Contexts | None, optional
Context set, by default None. If it is None, we build the
context set with the default context.
obs_context_features : list[str] | None, optional
Context features which should be included in the observation, by default None.
If they are None, add all context features.
context_selector: AbstractSelector | type[AbstractSelector] | None, optional
The context selector (class), after each reset selects a new context to use.
If None, use a round robin selector.
context_selector_kwargs : dict, optional
Optional keyword arguments for the context selector, by default None.
Only used when `context_selector` is not None.

Attributes
----------
env_name: str
The registered gymnax environment name.
"""
if env is None:
env = make_gymnax_env(env_name=self.env_name)

super().__init__(
env=env,
contexts=contexts,
obs_context_features=obs_context_features,
obs_context_as_dict=obs_context_as_dict,
context_selector=context_selector,
context_selector_kwargs=context_selector_kwargs,
**kwargs,
)

def __getattr__(self, name: str) -> Any:
if name in ["sys", "__getstate__"]:
return getattr(self.env._environment, name)
else:
return getattr(self, name)

def _update_context(self) -> None:
content = self.env.env_params.__dict__
content.update(self.context)
# We cannot directly set attributes of env_params because it is a frozen dataclass

# TODO Make this faster by preloading module?
self.env.env.env_params = getattr(
importlib.import_module(f"gymnax.environments.{self.module}"), "EnvParams"
)(**content)
15 changes: 15 additions & 0 deletions carl/envs/gymnax/classic_control/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from carl.envs.gymnax.classic_control.carl_gymnax_acrobot import CARLGymnaxAcrobot
from carl.envs.gymnax.classic_control.carl_gymnax_cartpole import CARLGymnaxCartPole
from carl.envs.gymnax.classic_control.carl_gymnax_mountaincar import (
CARLGymnaxMountainCar,
CARLGymnaxMountainCarContinuous,
)
from carl.envs.gymnax.classic_control.carl_gymnax_pendulum import CARLGymnaxPendulum

__all__ = [
"CARLGymnaxAcrobot",
"CARLGymnaxCartPole",
"CARLGymnaxMountainCar",
"CARLGymnaxMountainCarContinuous",
"CARLGymnaxPendulum",
]
53 changes: 53 additions & 0 deletions carl/envs/gymnax/classic_control/carl_gymnax_acrobot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

import jax.numpy as jnp
import numpy as np

from carl.context.context_space import ContextFeature, UniformFloatContextFeature
from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv


class CARLGymnaxAcrobot(CARLGymnaxEnv):
env_name: str = "Acrobot-v1"
module: str = "classic_control.acrobot"

@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
return {
"link_length_1": UniformFloatContextFeature(
"link_length_1", lower=0.1, upper=10, default_value=1
), # Links can be shrunken and grown by a factor of 10
"link_length_2": UniformFloatContextFeature(
"link_length_2", lower=0.1, upper=10, default_value=1
), # Links can be shrunken and grown by a factor of 10
"link_mass_1": UniformFloatContextFeature(
"link_mass_1", lower=0.1, upper=10, default_value=1
), # Link mass can be shrunken and grown by a factor of 10
"link_mass_2": UniformFloatContextFeature(
"link_mass_2", lower=0.1, upper=10, default_value=1
), # Link mass can be shrunken and grown by a factor of 10
"link_com_pos_1": UniformFloatContextFeature(
"link_com_pos_1", lower=0, upper=1, default_value=0.5
), # Center of mass can move from one end to the other
"link_com_pos_2": UniformFloatContextFeature(
"link_com_pos_2", lower=0, upper=1, default_value=0.5
), # Center of mass can move from one end to the other
"link_moi": UniformFloatContextFeature(
"link_moi", lower=0.1, upper=10, default_value=1
), # Moments on inertia can be shrunken and grown by a factor of 10
"max_vel_1": UniformFloatContextFeature(
"max_vel_1",
lower=0.4 * jnp.pi,
upper=40 * jnp.pi,
default_value=4 * jnp.pi,
), # Velocity can vary by a factor of 10 in either direction
"max_vel_2": UniformFloatContextFeature(
"max_vel_2",
lower=0.9 * np.pi,
upper=90 * np.pi,
default_value=9 * np.pi,
), # Velocity can vary by a factor of 10 in either direction
"torque_noise_max": UniformFloatContextFeature(
"torque_noise_max", lower=-1, upper=1, default_value=0
), # torque is either {-1., 0., 1}. Applying noise of 1. would be quite extreme
}
45 changes: 45 additions & 0 deletions carl/envs/gymnax/classic_control/carl_gymnax_cartpole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations

import importlib

from carl.context.context_space import ContextFeature, UniformFloatContextFeature
from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv


class CARLGymnaxCartPole(CARLGymnaxEnv):
env_name: str = "CartPole-v1"
module: str = "classic_control.cartpole"

@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
return {
"gravity": UniformFloatContextFeature(
"gravity", lower=0.01, upper=100, default_value=9.8
),
"masscart": UniformFloatContextFeature(
"masscart", lower=0.1, upper=10, default_value=1.0
),
"masspole": UniformFloatContextFeature(
"masspole", lower=0.01, upper=1, default_value=0.1
),
"length": UniformFloatContextFeature(
"length", lower=0.05, upper=5, default_value=0.5
),
"force_mag": UniformFloatContextFeature(
"force_mag", lower=1, upper=100, default_value=10.0
),
"tau": UniformFloatContextFeature(
"tau", lower=0.002, upper=0.2, default_value=0.02
),
}

def _update_context(self) -> None:
content = self.env.env_params.__dict__
content.update(self.context)
content["total_mass"] = content["masspole"] + content["masscart"]
content["polemass_length"] = content["masspole"] * content["length"]

# TODO Make this faster by preloading module?
self.env.env.env_params = getattr(
importlib.import_module(f"gymnax.environments.{self.module}"), "EnvParams"
)(**content)
54 changes: 54 additions & 0 deletions carl/envs/gymnax/classic_control/carl_gymnax_mountaincar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

from carl.context.context_space import ContextFeature, UniformFloatContextFeature
from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv


class CARLGymnaxMountainCar(CARLGymnaxEnv):
env_name: str = "MountainCar-v0"
module: str = "classic_control.mountain_car"

@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
return {
"max_speed": UniformFloatContextFeature(
"max_speed", lower=1e-3, upper=10, default_value=0.07
),
"goal_position": UniformFloatContextFeature(
"goal_position", lower=-2, upper=2, default_value=0.45
),
"goal_velocity": UniformFloatContextFeature(
"goal_velocity", lower=-10, upper=10, default_value=0
),
"force": UniformFloatContextFeature(
"force", lower=-10, upper=10, default_value=0.001
),
"gravity": UniformFloatContextFeature(
"gravity", lower=-10, upper=10, default_value=0.0025
),
}


class CARLGymnaxMountainCarContinuous(CARLGymnaxMountainCar):
env_name: str = "MountainCarContinuous-v0"
module: str = "classic_control.continuous_mountain_car"

@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
return {
"max_speed": UniformFloatContextFeature(
"max_speed", lower=1e-3, upper=10, default_value=0.07
),
"goal_position": UniformFloatContextFeature(
"goal_position", lower=-2, upper=2, default_value=0.45
),
"goal_velocity": UniformFloatContextFeature(
"goal_velocity", lower=-10, upper=10, default_value=0
),
"power": UniformFloatContextFeature(
"power", lower=1e-6, upper=10, default_value=0.001
),
"gravity": UniformFloatContextFeature(
"gravity", lower=-10, upper=10, default_value=0.0025
),
}
32 changes: 32 additions & 0 deletions carl/envs/gymnax/classic_control/carl_gymnax_pendulum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

from carl.context.context_space import ContextFeature, UniformFloatContextFeature
from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv


class CARLGymnaxPendulum(CARLGymnaxEnv):
env_name: str = "Pendulum-v1"
module: str = "classic_control.pendulum"

@staticmethod
def get_context_features() -> dict[str, ContextFeature]:
return {
"dt": UniformFloatContextFeature(
"dt", lower=0.001, upper=10, default_value=0.05
),
"g": UniformFloatContextFeature(
"g", lower=-100, upper=100, default_value=10
),
"m": UniformFloatContextFeature(
"m", lower=1e-6, upper=100, default_value=1
),
"l": UniformFloatContextFeature(
"l", lower=1e-6, upper=100, default_value=1
),
"max_speed": UniformFloatContextFeature(
"max_speed", lower=0.08, upper=80, default_value=8
),
"max_torque": UniformFloatContextFeature(
"max_torque", lower=0.02, upper=40, default_value=2
),
}
56 changes: 56 additions & 0 deletions carl/envs/gymnax/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations

from typing import Any

import gymnasium
import gymnasium.spaces
import gymnax
from gymnax.environments.environment import Environment, EnvParams
from gymnax.environments.spaces import Space, gymnax_space_to_gym_space
from gymnax.wrappers.gym import GymnaxToGymWrapper


# Although this converts to gym, the step API already is for gymnasium
class CustomGymnaxToGymnasiumWrapper(GymnaxToGymWrapper):
def __init__(
self, env: Environment, params: EnvParams | None = None, seed: int | None = None
):
super().__init__(env, params, seed)

self._observation_space = SpaceWrapper(
gymnax_space_to_gym_space(self._env.observation_space(self.env_params))
)

@property
def env(self) -> Environment:
return self._env

@env.setter
def env(self, value: Environment) -> None:
self._env = value

@property
def observation_space(self) -> gymnasium.Space:
return self._observation_space

@observation_space.setter
def observation_space(self, value: Space) -> None:
self._observation_space = value


class SpaceWrapper(gymnasium.Space):
def __init__(self, space):
self.space = space

def __getattr__(self, __name: str) -> Any:
return self.space.__getattr__(__name=__name)


def make_gymnax_env(env_name: str) -> gymnasium.Env:
# Make gymnax env
env, env_params = gymnax.make(env_id=env_name)

# Convert gymnax to gymnasium API
env = CustomGymnaxToGymnasiumWrapper(env=env, params=env_params)

return env
3 changes: 3 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# 1.1.0
- Add gymnax classic control environments (#90)

# 1.0.0
Major overhaul of the CARL environment
- Contexts are stored in each environment's class
Expand Down
Loading
Loading