-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Classic Control Gymnax #90
Open
benjamc
wants to merge
39
commits into
development
Choose a base branch
from
gymnax
base: development
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
39 commits
Select commit
Hold shift + click to select a range
6802e04
Add gymnax folder
10b1e85
Update requirements
ae2caf4
add CARLJaxPendulumEnv
Arman717 694241e
add init
Arman717 f1ec7b3
add CARLJaxAcrobotEnv
Arman717 3ed010d
remove comment
Arman717 9fd4ea6
add CARLJaxCartPoleEnv
Arman717 5b6eee5
add CARLJaxMountainCarEnv
Arman717 4fa6f98
add CARLJaxMountainCarContinuousEnv
Arman717 68048f5
fix setup
Arman717 666a323
Rename file
ffc2bb9
Add imports
9a3d416
add modified version for GymnaxToGymWrapper
Arman717 5874515
make precommit
Arman717 92458e0
make precommit
Arman717 9de0e29
Fix gymnax
0bba185
Fix pre-commit
289d324
Merge branch 'development' into gymnax
benjamc 97ae965
Move files
benjamc c0c9620
Rename envs and imports
benjamc 2406616
Rename files
benjamc 63cba37
Fix env comapts
benjamc 689a2c8
Reformat
benjamc 7ec0dfc
Smarter update context
benjamc adde740
Update cartpole
benjamc 52bd501
Update mountaincar
benjamc 9874e52
Fix context space
benjamc dd0c430
Update pendulum
benjamc 86e1737
Format
benjamc 72b82e1
Update changelog.md
benjamc 1a0f985
Fix context space
benjamc a35e087
Fix context space
benjamc 39782de
Fix context space
benjamc 10aa406
Fix context space
benjamc db292e5
Merge branch 'development' into gymnax
benjamc 7df9d72
Remove warnings
benjamc 134f161
Fix pre-commit
benjamc 778e383
Fix context space
benjamc 7c2e21e
Fix pre-commit
benjamc File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from carl.envs.gymnax.classic_control import ( | ||
CARLGymnaxAcrobot, | ||
CARLGymnaxCartPole, | ||
CARLGymnaxMountainCar, | ||
CARLGymnaxMountainCarContinuous, | ||
CARLGymnaxPendulum, | ||
) | ||
|
||
__all__ = [ | ||
"CARLGymnaxAcrobot", | ||
"CARLGymnaxCartPole", | ||
"CARLGymnaxMountainCar", | ||
"CARLGymnaxMountainCarContinuous", | ||
"CARLGymnaxPendulum", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
import importlib | ||
|
||
from gymnasium.core import Env | ||
|
||
from carl.context.selection import AbstractSelector | ||
from carl.envs.carl_env import CARLEnv | ||
from carl.envs.gymnax.utils import make_gymnax_env | ||
from carl.utils.types import Contexts | ||
|
||
|
||
class CARLGymnaxEnv(CARLEnv): | ||
env_name: str | ||
|
||
def __init__( | ||
self, | ||
env: Env | None = None, | ||
contexts: Contexts | None = None, | ||
obs_context_features: list[str] | ||
| None = None, # list the context features which should be added to the state | ||
obs_context_as_dict: bool = True, | ||
context_selector: AbstractSelector | type[AbstractSelector] | None = None, | ||
context_selector_kwargs: dict = None, | ||
**kwargs, | ||
) -> None: | ||
""" | ||
CARL Gymnax Environment. | ||
|
||
Parameters | ||
---------- | ||
|
||
env : Env | None | ||
Gymnasium environment, the default is None. | ||
If None, instantiate the env with gymnasium's make function and | ||
`self.env_name` which is defined in each child class. | ||
contexts : Contexts | None, optional | ||
Context set, by default None. If it is None, we build the | ||
context set with the default context. | ||
obs_context_features : list[str] | None, optional | ||
Context features which should be included in the observation, by default None. | ||
If they are None, add all context features. | ||
context_selector: AbstractSelector | type[AbstractSelector] | None, optional | ||
The context selector (class), after each reset selects a new context to use. | ||
If None, use a round robin selector. | ||
context_selector_kwargs : dict, optional | ||
Optional keyword arguments for the context selector, by default None. | ||
Only used when `context_selector` is not None. | ||
|
||
Attributes | ||
---------- | ||
env_name: str | ||
The registered gymnax environment name. | ||
""" | ||
if env is None: | ||
env = make_gymnax_env(env_name=self.env_name) | ||
|
||
super().__init__( | ||
env=env, | ||
contexts=contexts, | ||
obs_context_features=obs_context_features, | ||
obs_context_as_dict=obs_context_as_dict, | ||
context_selector=context_selector, | ||
context_selector_kwargs=context_selector_kwargs, | ||
**kwargs, | ||
) | ||
|
||
def __getattr__(self, name: str) -> Any: | ||
if name in ["sys", "__getstate__"]: | ||
return getattr(self.env._environment, name) | ||
else: | ||
return getattr(self, name) | ||
|
||
def _update_context(self) -> None: | ||
content = self.env.env_params.__dict__ | ||
content.update(self.context) | ||
# We cannot directly set attributes of env_params because it is a frozen dataclass | ||
|
||
# TODO Make this faster by preloading module? | ||
self.env.env.env_params = getattr( | ||
importlib.import_module(f"gymnax.environments.{self.module}"), "EnvParams" | ||
)(**content) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from carl.envs.gymnax.classic_control.carl_gymnax_acrobot import CARLGymnaxAcrobot | ||
from carl.envs.gymnax.classic_control.carl_gymnax_cartpole import CARLGymnaxCartPole | ||
from carl.envs.gymnax.classic_control.carl_gymnax_mountaincar import ( | ||
CARLGymnaxMountainCar, | ||
CARLGymnaxMountainCarContinuous, | ||
) | ||
from carl.envs.gymnax.classic_control.carl_gymnax_pendulum import CARLGymnaxPendulum | ||
|
||
__all__ = [ | ||
"CARLGymnaxAcrobot", | ||
"CARLGymnaxCartPole", | ||
"CARLGymnaxMountainCar", | ||
"CARLGymnaxMountainCarContinuous", | ||
"CARLGymnaxPendulum", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from __future__ import annotations | ||
|
||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
from carl.context.context_space import ContextFeature, UniformFloatContextFeature | ||
from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv | ||
|
||
|
||
class CARLGymnaxAcrobot(CARLGymnaxEnv): | ||
env_name: str = "Acrobot-v1" | ||
module: str = "classic_control.acrobot" | ||
|
||
@staticmethod | ||
def get_context_features() -> dict[str, ContextFeature]: | ||
return { | ||
"link_length_1": UniformFloatContextFeature( | ||
"link_length_1", lower=0.1, upper=10, default_value=1 | ||
), # Links can be shrunken and grown by a factor of 10 | ||
"link_length_2": UniformFloatContextFeature( | ||
"link_length_2", lower=0.1, upper=10, default_value=1 | ||
), # Links can be shrunken and grown by a factor of 10 | ||
"link_mass_1": UniformFloatContextFeature( | ||
"link_mass_1", lower=0.1, upper=10, default_value=1 | ||
), # Link mass can be shrunken and grown by a factor of 10 | ||
"link_mass_2": UniformFloatContextFeature( | ||
"link_mass_2", lower=0.1, upper=10, default_value=1 | ||
), # Link mass can be shrunken and grown by a factor of 10 | ||
"link_com_pos_1": UniformFloatContextFeature( | ||
"link_com_pos_1", lower=0, upper=1, default_value=0.5 | ||
), # Center of mass can move from one end to the other | ||
"link_com_pos_2": UniformFloatContextFeature( | ||
"link_com_pos_2", lower=0, upper=1, default_value=0.5 | ||
), # Center of mass can move from one end to the other | ||
"link_moi": UniformFloatContextFeature( | ||
"link_moi", lower=0.1, upper=10, default_value=1 | ||
), # Moments on inertia can be shrunken and grown by a factor of 10 | ||
"max_vel_1": UniformFloatContextFeature( | ||
"max_vel_1", | ||
lower=0.4 * jnp.pi, | ||
upper=40 * jnp.pi, | ||
default_value=4 * jnp.pi, | ||
), # Velocity can vary by a factor of 10 in either direction | ||
"max_vel_2": UniformFloatContextFeature( | ||
"max_vel_2", | ||
lower=0.9 * np.pi, | ||
upper=90 * np.pi, | ||
default_value=9 * np.pi, | ||
), # Velocity can vary by a factor of 10 in either direction | ||
"torque_noise_max": UniformFloatContextFeature( | ||
"torque_noise_max", lower=-1, upper=1, default_value=0 | ||
), # torque is either {-1., 0., 1}. Applying noise of 1. would be quite extreme | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from __future__ import annotations | ||
|
||
import importlib | ||
|
||
from carl.context.context_space import ContextFeature, UniformFloatContextFeature | ||
from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv | ||
|
||
|
||
class CARLGymnaxCartPole(CARLGymnaxEnv): | ||
env_name: str = "CartPole-v1" | ||
module: str = "classic_control.cartpole" | ||
|
||
@staticmethod | ||
def get_context_features() -> dict[str, ContextFeature]: | ||
return { | ||
"gravity": UniformFloatContextFeature( | ||
"gravity", lower=0.01, upper=100, default_value=9.8 | ||
), | ||
"masscart": UniformFloatContextFeature( | ||
"masscart", lower=0.1, upper=10, default_value=1.0 | ||
), | ||
"masspole": UniformFloatContextFeature( | ||
"masspole", lower=0.01, upper=1, default_value=0.1 | ||
), | ||
"length": UniformFloatContextFeature( | ||
"length", lower=0.05, upper=5, default_value=0.5 | ||
), | ||
"force_mag": UniformFloatContextFeature( | ||
"force_mag", lower=1, upper=100, default_value=10.0 | ||
), | ||
"tau": UniformFloatContextFeature( | ||
"tau", lower=0.002, upper=0.2, default_value=0.02 | ||
), | ||
} | ||
|
||
def _update_context(self) -> None: | ||
content = self.env.env_params.__dict__ | ||
content.update(self.context) | ||
content["total_mass"] = content["masspole"] + content["masscart"] | ||
content["polemass_length"] = content["masspole"] * content["length"] | ||
|
||
# TODO Make this faster by preloading module? | ||
self.env.env.env_params = getattr( | ||
importlib.import_module(f"gymnax.environments.{self.module}"), "EnvParams" | ||
)(**content) |
54 changes: 54 additions & 0 deletions
54
carl/envs/gymnax/classic_control/carl_gymnax_mountaincar.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from __future__ import annotations | ||
|
||
from carl.context.context_space import ContextFeature, UniformFloatContextFeature | ||
from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv | ||
|
||
|
||
class CARLGymnaxMountainCar(CARLGymnaxEnv): | ||
env_name: str = "MountainCar-v0" | ||
module: str = "classic_control.mountain_car" | ||
|
||
@staticmethod | ||
def get_context_features() -> dict[str, ContextFeature]: | ||
return { | ||
"max_speed": UniformFloatContextFeature( | ||
"max_speed", lower=1e-3, upper=10, default_value=0.07 | ||
), | ||
"goal_position": UniformFloatContextFeature( | ||
"goal_position", lower=-2, upper=2, default_value=0.45 | ||
), | ||
"goal_velocity": UniformFloatContextFeature( | ||
"goal_velocity", lower=-10, upper=10, default_value=0 | ||
), | ||
"force": UniformFloatContextFeature( | ||
"force", lower=-10, upper=10, default_value=0.001 | ||
), | ||
"gravity": UniformFloatContextFeature( | ||
"gravity", lower=-10, upper=10, default_value=0.0025 | ||
), | ||
} | ||
|
||
|
||
class CARLGymnaxMountainCarContinuous(CARLGymnaxMountainCar): | ||
env_name: str = "MountainCarContinuous-v0" | ||
module: str = "classic_control.continuous_mountain_car" | ||
|
||
@staticmethod | ||
def get_context_features() -> dict[str, ContextFeature]: | ||
return { | ||
"max_speed": UniformFloatContextFeature( | ||
"max_speed", lower=1e-3, upper=10, default_value=0.07 | ||
), | ||
"goal_position": UniformFloatContextFeature( | ||
"goal_position", lower=-2, upper=2, default_value=0.45 | ||
), | ||
"goal_velocity": UniformFloatContextFeature( | ||
"goal_velocity", lower=-10, upper=10, default_value=0 | ||
), | ||
"power": UniformFloatContextFeature( | ||
"power", lower=1e-6, upper=10, default_value=0.001 | ||
), | ||
"gravity": UniformFloatContextFeature( | ||
"gravity", lower=-10, upper=10, default_value=0.0025 | ||
), | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from __future__ import annotations | ||
|
||
from carl.context.context_space import ContextFeature, UniformFloatContextFeature | ||
from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv | ||
|
||
|
||
class CARLGymnaxPendulum(CARLGymnaxEnv): | ||
env_name: str = "Pendulum-v1" | ||
module: str = "classic_control.pendulum" | ||
|
||
@staticmethod | ||
def get_context_features() -> dict[str, ContextFeature]: | ||
return { | ||
"dt": UniformFloatContextFeature( | ||
"dt", lower=0.001, upper=10, default_value=0.05 | ||
), | ||
"g": UniformFloatContextFeature( | ||
"g", lower=-100, upper=100, default_value=10 | ||
), | ||
"m": UniformFloatContextFeature( | ||
"m", lower=1e-6, upper=100, default_value=1 | ||
), | ||
"l": UniformFloatContextFeature( | ||
"l", lower=1e-6, upper=100, default_value=1 | ||
), | ||
"max_speed": UniformFloatContextFeature( | ||
"max_speed", lower=0.08, upper=80, default_value=8 | ||
), | ||
"max_torque": UniformFloatContextFeature( | ||
"max_torque", lower=0.02, upper=40, default_value=2 | ||
), | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
import gymnasium | ||
import gymnasium.spaces | ||
import gymnax | ||
from gymnax.environments.environment import Environment, EnvParams | ||
from gymnax.environments.spaces import Space, gymnax_space_to_gym_space | ||
from gymnax.wrappers.gym import GymnaxToGymWrapper | ||
|
||
|
||
# Although this converts to gym, the step API already is for gymnasium | ||
class CustomGymnaxToGymnasiumWrapper(GymnaxToGymWrapper): | ||
def __init__( | ||
self, env: Environment, params: EnvParams | None = None, seed: int | None = None | ||
): | ||
super().__init__(env, params, seed) | ||
|
||
self._observation_space = SpaceWrapper( | ||
gymnax_space_to_gym_space(self._env.observation_space(self.env_params)) | ||
) | ||
|
||
@property | ||
def env(self) -> Environment: | ||
return self._env | ||
|
||
@env.setter | ||
def env(self, value: Environment) -> None: | ||
self._env = value | ||
|
||
@property | ||
def observation_space(self) -> gymnasium.Space: | ||
return self._observation_space | ||
|
||
@observation_space.setter | ||
def observation_space(self, value: Space) -> None: | ||
self._observation_space = value | ||
|
||
|
||
class SpaceWrapper(gymnasium.Space): | ||
def __init__(self, space): | ||
self.space = space | ||
|
||
def __getattr__(self, __name: str) -> Any: | ||
return self.space.__getattr__(__name=__name) | ||
|
||
|
||
def make_gymnax_env(env_name: str) -> gymnasium.Env: | ||
# Make gymnax env | ||
env, env_params = gymnax.make(env_id=env_name) | ||
|
||
# Convert gymnax to gymnasium API | ||
env = CustomGymnaxToGymnasiumWrapper(env=env, params=env_params) | ||
|
||
return env |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why call them Jax and not Gymnax? I bet there are other attempts at Jaxing them, Gymnax would be more explicit imo