Add Classic Control Gymnax #90
Open
benjamc wants to merge 39 commits into development from gymnax
Changes from 17 commits
Commits
39 commits
6802e04
Add gymnax folder
10b1e85
Update requirements
ae2caf4
add CARLJaxPendulumEnv
Arman717 694241e
add init
Arman717 f1ec7b3
add CARLJaxAcrobotEnv
Arman717 3ed010d
remove comment
Arman717 9fd4ea6
add CARLJaxCartPoleEnv
Arman717 5b6eee5
add CARLJaxMountainCarEnv
Arman717 4fa6f98
add CARLJaxMountainCarContinuousEnv
Arman717 68048f5
fix setup
Arman717 666a323
Rename file
ffc2bb9
Add imports
9a3d416
add modified version for GymnaxToGymWrapper
Arman717 5874515
make precommit
Arman717 92458e0
make precommit
Arman717 9de0e29
Fix gymnax
0bba185
Fix pre-commit
289d324
Merge branch 'development' into gymnax
benjamc 97ae965
Move files
benjamc c0c9620
Rename envs and imports
benjamc 2406616
Rename files
benjamc 63cba37
Fix env comapts
benjamc 689a2c8
Reformat
benjamc 7ec0dfc
Smarter update context
benjamc adde740
Update cartpole
benjamc 52bd501
Update mountaincar
benjamc 9874e52
Fix context space
benjamc dd0c430
Update pendulum
benjamc 86e1737
Format
benjamc 72b82e1
Update changelog.md
benjamc 1a0f985
Fix context space
benjamc a35e087
Fix context space
benjamc 39782de
Fix context space
benjamc 10aa406
Fix context space
benjamc db292e5
Merge branch 'development' into gymnax
benjamc 7df9d72
Remove warnings
benjamc 134f161
Fix pre-commit
benjamc 778e383
Fix context space
benjamc 7c2e21e
Fix pre-commit
benjamc
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# flake8: noqa: F401 | ||
from carl.envs.gymnax.carl_jax_acrobot import CONTEXT_BOUNDS as CARLJaxAcrobotEnv_bounds | ||
from carl.envs.gymnax.carl_jax_acrobot import ( | ||
DEFAULT_CONTEXT as CARLJaxAcrobotEnv_defaults, | ||
) | ||
from carl.envs.gymnax.carl_jax_acrobot import CARLJaxAcrobotEnv | ||
from carl.envs.gymnax.carl_jax_cartpole import ( | ||
CONTEXT_BOUNDS as CARLJaxCartPoleEnv_bounds, | ||
) | ||
from carl.envs.gymnax.carl_jax_cartpole import ( | ||
DEFAULT_CONTEXT as CARLJaxCartPoleEnv_defaults, | ||
) | ||
from carl.envs.gymnax.carl_jax_cartpole import CARLJaxCartPoleEnv | ||
from carl.envs.gymnax.carl_jax_mountaincar import ( | ||
CONTEXT_BOUNDS as CARLJaxMountainCarContinuousEnv_bounds, | ||
) | ||
from carl.envs.gymnax.carl_jax_mountaincar import ( | ||
CONTEXT_BOUNDS as CARLJaxMountainCarEnv_bounds, | ||
) | ||
from carl.envs.gymnax.carl_jax_mountaincar import ( | ||
DEFAULT_CONTEXT as CARLJaxMountainCarContinuousEnv_defaults, | ||
) | ||
from carl.envs.gymnax.carl_jax_mountaincar import ( | ||
DEFAULT_CONTEXT as CARLJaxMountainCarEnv_defaults, | ||
) | ||
from carl.envs.gymnax.carl_jax_mountaincar import ( | ||
CARLJaxMountainCarContinuousEnv, | ||
CARLJaxMountainCarEnv, | ||
) | ||
from carl.envs.gymnax.carl_jax_pendulum import ( | ||
CONTEXT_BOUNDS as CARLJaxPendulumEnv_bounds, | ||
) | ||
from carl.envs.gymnax.carl_jax_pendulum import ( | ||
DEFAULT_CONTEXT as CARLJaxPendulumEnv_defaults, | ||
) | ||
from carl.envs.gymnax.carl_jax_pendulum import CARLJaxPendulumEnv |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any, Dict, List, Optional, Union | ||
|
||
import gymnasium | ||
|
||
from carl.envs.gymnax.utils import make_gymnax_env | ||
from carl.context.selection import AbstractSelector | ||
from carl.envs.carl_env import CARLEnv | ||
from carl.utils.trial_logger import TrialLogger | ||
from carl.utils.types import Context, Contexts | ||
|
||
|
||
class CARLGymnaxEnv(CARLEnv): | ||
env_name: str | ||
DEFAULT_CONTEXT: Context | ||
max_episode_steps: int | ||
|
||
def __init__( | ||
self, | ||
env: gymnasium.Env | None = None, | ||
contexts: Contexts | None = None, | ||
hide_context: bool = True, | ||
add_gaussian_noise_to_context: bool = False, | ||
gaussian_noise_std_percentage: float = 0.01, | ||
logger: Optional[TrialLogger] = None, | ||
scale_context_features: str = "no", | ||
default_context: Optional[Context] = None, | ||
state_context_features: Optional[List[str]] = None, | ||
context_mask: Optional[List[str]] = None, | ||
dict_observation_space: bool = False, | ||
context_selector: Optional[ | ||
Union[AbstractSelector, type[AbstractSelector]] | ||
] = None, | ||
context_selector_kwargs: Optional[Dict] = None, | ||
): | ||
""" | ||
Base class for CARL environments backed by a gymnax environment. | ||
|
||
Parameters | ||
---------- | ||
env | ||
contexts | ||
hide_context | ||
add_gaussian_noise_to_context | ||
gaussian_noise_std_percentage | ||
""" | ||
if env is None: | ||
env = make_gymnax_env(env_name=self.env_name) | ||
|
||
if not contexts: | ||
contexts = {0: self.DEFAULT_CONTEXT} | ||
|
||
if not default_context: | ||
default_context = self.DEFAULT_CONTEXT | ||
|
||
super().__init__( | ||
env=env, | ||
contexts=contexts, | ||
hide_context=hide_context, | ||
add_gaussian_noise_to_context=add_gaussian_noise_to_context, | ||
gaussian_noise_std_percentage=gaussian_noise_std_percentage, | ||
logger=logger, | ||
scale_context_features=scale_context_features, | ||
default_context=default_context, | ||
max_episode_length=self.max_episode_steps, | ||
state_context_features=state_context_features, | ||
dict_observation_space=dict_observation_space, | ||
context_selector=context_selector, | ||
context_selector_kwargs=context_selector_kwargs, | ||
context_mask=context_mask, | ||
) | ||
self.whitelist_gaussian_noise = list( | ||
self.DEFAULT_CONTEXT.keys() | ||
) # allow to augment all values | ||
|
||
def _update_context(self) -> None: | ||
raise NotImplementedError | ||
|
||
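def __getattr__(self, name: str) -> Any: | ||
if name in ["sys", "__getstate__"]: | ||
return getattr(self.env._environment, name) | ||
# __getattr__ only runs after normal lookup has failed, so calling | ||
# getattr(self, name) here would recurse until RecursionError. | ||
raise AttributeError( | ||
f"{type(self).__name__!r} object has no attribute {name!r}" | ||
) |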
Env params is frozen for a reason, but I guess this runs, right? Pretty sure we can't jit it, but we have this issue with brax as well... At some point we might want to brainstorm whether we can somehow avoid replacing env parts to change context. For now it's alright, though.
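For reference on the frozen-params point, here is a minimal sketch of changing fields on a frozen dataclass by building a new instance rather than mutating the old one. `ToyEnvParams` and `with_context` are hypothetical stand-ins, not gymnax or CARL APIs, and ignoring unknown context keys is an assumption of this sketch. It only relies on the standard library's `dataclasses.replace`.

```python
from dataclasses import dataclass, fields, replace


@dataclass(frozen=True)
class ToyEnvParams:
    """Hypothetical stand-in for a frozen environment-parameter dataclass."""

    gravity: float = 9.8
    max_steps_in_episode: int = 500


def with_context(params, context):
    """Return a new params object with the matching context keys applied.

    Keys in `context` that are not fields of `params` are ignored
    (an assumption of this sketch, not documented CARL behaviour).
    """
    known = {f.name for f in fields(params)}
    updates = {k: v for k, v in context.items() if k in known}
    # replace() constructs a fresh instance; `params` itself is left untouched.
    return replace(params, **updates)


default_params = ToyEnvParams()
new_params = with_context(default_params, {"gravity": 5.0, "unknown": 1})
```

The original object keeps its values, so anything still holding a reference to it sees no change. Whether this pattern plays well with jit in the real environments is left open here.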
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from __future__ import annotations | ||
|
||
import gymnax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv | ||
from carl.utils.types import Context | ||
|
||
DEFAULT_CONTEXT = { | ||
"link_length_1": 1, | ||
"link_length_2": 1, | ||
"link_mass_1": 1, | ||
"link_mass_2": 1, | ||
"link_com_pos_1": 0.5, | ||
"link_com_pos_2": 0.5, | ||
"link_moi": 1, | ||
"max_vel_1": 4 * jnp.pi, | ||
"max_vel_2": 9 * jnp.pi, | ||
"torque_noise_max": 0.0, | ||
"max_steps_in_episode": 500, | ||
} | ||
|
||
CONTEXT_BOUNDS = { | ||
"link_length_1": ( | ||
0.1, | ||
10, | ||
float, | ||
), # Links can be shrunken and grown by a factor of 10 | ||
"link_length_2": (0.1, 10, float), | ||
"link_mass_1": ( | ||
0.1, | ||
10, | ||
float, | ||
), # Link mass can be shrunken and grown by a factor of 10 | ||
"link_mass_2": (0.1, 10, float), | ||
"link_com_pos_1": ( | ||
0, | ||
1, | ||
float, | ||
), # Center of mass can move from one end to the other | ||
"link_com_pos_2": (0, 1, float), | ||
"link_moi": ( | ||
0.1, | ||
10, | ||
float, | ||
), # Moments of inertia can be shrunken and grown by a factor of 10 | ||
"max_vel_1": ( | ||
0.4 * np.pi, | ||
40 * np.pi, | ||
float, | ||
), # Velocity can vary by a factor of 10 in either direction | ||
"max_vel_2": (0.9 * np.pi, 90 * np.pi, float), | ||
"torque_noise_max": ( | ||
-1.0, | ||
1.0, | ||
float, | ||
), # torque is either {-1., 0., 1}. Applying noise of 1. would be quite extreme | ||
"max_steps_in_episode": (1, jnp.inf, int), | ||
} | ||
|
||
|
||
class CARLJaxAcrobotEnv(CARLGymnaxEnv): | ||
env_name: str = "Acrobot-v1" | ||
max_episode_steps: int = int(DEFAULT_CONTEXT["max_steps_in_episode"]) | ||
DEFAULT_CONTEXT: Context = DEFAULT_CONTEXT | ||
|
||
def _update_context(self) -> None: | ||
# env_params is a frozen dataclass, so build a new instance instead of | ||
# setting attributes (or mutating its __dict__) in place. | ||
content = {**self.env.env.env_params.__dict__, **self.context} | ||
self.env.env.env_params = gymnax.environments.classic_control.acrobot.EnvParams( | ||
**content | ||
) | ||
|
||
high = jnp.array( | ||
[ | ||
1.0, | ||
1.0, | ||
1.0, | ||
1.0, | ||
self.env.env.env_params.max_vel_1, | ||
self.env.env.env_params.max_vel_2, | ||
], | ||
dtype=jnp.float32, | ||
) | ||
low = -high | ||
self.build_observation_space(low, high, CONTEXT_BOUNDS) |
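The `(lower, upper, type)` layout of `CONTEXT_BOUNDS` lends itself to a simple range check. The following is a minimal sketch under that assumption. `TOY_BOUNDS` and `out_of_bounds` are hypothetical names for illustration and are not part of CARL.

```python
import math

# Hypothetical excerpt using the same (lower, upper, type) layout as CONTEXT_BOUNDS.
TOY_BOUNDS = {
    "link_length_1": (0.1, 10, float),
    "max_steps_in_episode": (1, math.inf, int),
}


def out_of_bounds(context, bounds):
    """Return the names of context features that fall outside their bounds."""
    violations = []
    for name, value in context.items():
        lower, upper, _dtype = bounds[name]
        # Bounds are treated as inclusive on both ends in this sketch.
        if not lower <= value <= upper:
            violations.append(name)
    return violations


ok = out_of_bounds({"link_length_1": 1.0, "max_steps_in_episode": 500}, TOY_BOUNDS)
bad = out_of_bounds({"link_length_1": 0.01, "max_steps_in_episode": 500}, TOY_BOUNDS)
```

Here `ok` is empty and `bad` names only `link_length_1`, since 0.01 lies below the lower bound of 0.1.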
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from __future__ import annotations | ||
|
||
import gymnax | ||
import jax.numpy as jnp | ||
|
||
from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv | ||
from carl.utils.types import Context | ||
|
||
DEFAULT_CONTEXT = { | ||
"gravity": 9.8, | ||
"masscart": 1.0, | ||
"masspole": 0.1, | ||
"length": 0.5, | ||
"force_mag": 10.0, | ||
"tau": 0.02, | ||
"polemass_length": None, | ||
"total_mass": None, | ||
"max_steps_in_episode": 500, | ||
} | ||
|
||
CONTEXT_BOUNDS = { | ||
"gravity": (5.0, 15.0, float), | ||
"masscart": (0.5, 2.0, float), | ||
"masspole": (0.05, 0.2, float), | ||
"length": (0.25, 1.0, float), | ||
"force_mag": (5.0, 15.0, float), | ||
"tau": (0.01, 0.05, float), | ||
"polemass_length": (0, jnp.inf, float), | ||
"total_mass": (0, jnp.inf, float), | ||
"max_steps_in_episode": (1, jnp.inf, int), | ||
} | ||
|
||
|
||
class CARLJaxCartPoleEnv(CARLGymnaxEnv): | ||
env_name: str = "CartPole-v1" | ||
max_episode_steps: int = int(DEFAULT_CONTEXT["max_steps_in_episode"]) # type: ignore[arg-type] | ||
DEFAULT_CONTEXT: Context = DEFAULT_CONTEXT | ||
|
||
def _update_context(self) -> None: | ||
self.context["polemass_length"] = ( | ||
self.context["masspole"] * self.context["length"] | ||
) | ||
self.context["total_mass"] = self.context["masscart"] + self.context["masspole"] | ||
|
||
self.env.env.env_params = ( | ||
gymnax.environments.classic_control.cartpole.EnvParams(**self.context) | ||
) | ||
|
||
high = jnp.array( | ||
[ | ||
self.env.env.env_params.x_threshold * 2, | ||
jnp.finfo(jnp.float32).max, | ||
self.env.env.env_params.theta_threshold_radians * 2, | ||
jnp.finfo(jnp.float32).max, | ||
], | ||
dtype=jnp.float32, | ||
) | ||
low = -high | ||
self.build_observation_space(low, high, CONTEXT_BOUNDS) |
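The two `None` defaults, `polemass_length` and `total_mass`, are derived from the other features in `_update_context`. A minimal sketch of that derivation on a plain dictionary follows. `add_derived_features` is a hypothetical helper, and unlike the method above it returns a copy rather than writing into the context in place.

```python
def add_derived_features(context):
    """Return a copy of `context` with the two derived CartPole entries filled in."""
    derived = dict(context)
    derived["polemass_length"] = derived["masspole"] * derived["length"]
    derived["total_mass"] = derived["masscart"] + derived["masspole"]
    return derived


base = {"masscart": 1.0, "masspole": 0.1, "length": 0.5}
full = add_derived_features(base)
```

Working on a copy keeps the user-supplied context free of derived entries, which is one possible design choice and not necessarily what the PR intends.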
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from __future__ import annotations | ||
|
||
import gymnax | ||
import jax.numpy as jnp | ||
|
||
from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv | ||
from carl.utils.types import Context | ||
|
||
DEFAULT_CONTEXT = { | ||
"min_position": -1.2, | ||
"max_position": 0.6, | ||
"max_speed": 0.07, | ||
"goal_position": 0.5, | ||
"goal_velocity": 0, | ||
"force": 0.001, | ||
"gravity": 0.0025, | ||
"max_steps_in_episode": 200, | ||
} | ||
|
||
CONTEXT_BOUNDS = { | ||
"min_position": (-jnp.inf, jnp.inf, float), | ||
"max_position": (-jnp.inf, jnp.inf, float), | ||
"max_speed": (0, jnp.inf, float), | ||
"goal_position": (-jnp.inf, jnp.inf, float), | ||
"goal_velocity": (-jnp.inf, jnp.inf, float), | ||
"force": (-jnp.inf, jnp.inf, float), | ||
"gravity": (0, jnp.inf, float), | ||
"max_steps_in_episode": (1, jnp.inf, int), | ||
} | ||
|
||
|
||
class CARLJaxMountainCarEnv(CARLGymnaxEnv): | ||
env_name: str = "MountainCar-v0" | ||
max_episode_steps: int = int(DEFAULT_CONTEXT["max_steps_in_episode"]) | ||
DEFAULT_CONTEXT: Context = DEFAULT_CONTEXT | ||
|
||
def _update_context(self) -> None: | ||
self.env.env.env_params = ( | ||
gymnax.environments.classic_control.mountain_car.EnvParams(**self.context) | ||
) | ||
|
||
self.low = jnp.array( | ||
[self.env.env.env_params.min_position, -self.env.env.env_params.max_speed], | ||
dtype=jnp.float32, | ||
).squeeze() | ||
self.high = jnp.array( | ||
[self.env.env.env_params.max_position, self.env.env.env_params.max_speed], | ||
dtype=jnp.float32, | ||
).squeeze() | ||
|
||
self.build_observation_space(self.low, self.high, CONTEXT_BOUNDS) | ||
|
||
|
||
class CARLJaxMountainCarContinuousEnv(CARLJaxMountainCarEnv): | ||
env_name: str = "MountainCarContinuous-v0" | ||
max_episode_steps: int = 999 | ||
DEFAULT_CONTEXT: Context = DEFAULT_CONTEXT |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from __future__ import annotations | ||
|
||
import jax.numpy as jnp | ||
from gymnax.environments.classic_control.pendulum import EnvParams | ||
|
||
from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv | ||
from carl.utils.types import Context | ||
|
||
DEFAULT_CONTEXT = { | ||
"max_speed": 8.0, | ||
"max_torque": 2.0, | ||
"dt": 0.05, | ||
"g": 10.0, | ||
"m": 1.0, | ||
"l": 1.0, | ||
"max_steps_in_episode": 200, | ||
} | ||
|
||
CONTEXT_BOUNDS = { | ||
"max_speed": (-jnp.inf, jnp.inf, float), | ||
"max_torque": (-jnp.inf, jnp.inf, float), | ||
"dt": (0, jnp.inf, float), | ||
"g": (0, jnp.inf, float), | ||
"m": (1e-6, jnp.inf, float), | ||
"l": (1e-6, jnp.inf, float), | ||
"max_steps_in_episode": (1, jnp.inf, int), | ||
} | ||
|
||
|
||
class CARLJaxPendulumEnv(CARLGymnaxEnv): | ||
env_name: str = "Pendulum-v1" | ||
max_episode_steps: int = int(DEFAULT_CONTEXT["max_steps_in_episode"]) | ||
DEFAULT_CONTEXT: Context = DEFAULT_CONTEXT | ||
|
||
def _update_context(self) -> None: | ||
self.env.env.env_params = EnvParams(**self.context) | ||
|
||
high = jnp.array( | ||
[1.0, 1.0, self.env.env.env_params.max_speed], dtype=jnp.float32 | ||
) | ||
self.build_observation_space(-high, high, CONTEXT_BOUNDS) |
Why call them Jax and not Gymnax? I bet there are other attempts at Jaxing them; Gymnax would be more explicit, imo.