Skip to content

Commit

Permalink
fix: gravity limits inverted
Browse files Browse the repository at this point in the history
  • Loading branch information
TheEimer committed Jan 9, 2024
1 parent 897f373 commit afbd479
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 12 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ install-dev:
$(PIP) install -e ".[dev, docs]"
pre-commit install

install:
$(PIP) install -e .

check-black:
$(BLACK) carl test --check || :

Expand Down
35 changes: 28 additions & 7 deletions carl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
import warnings

# Classic control is in gym and thus necessary for the base version to run
from carl import envs
from carl.envs.gymnasium import *

__all__ = envs.gymnasium.__all__
__all__ = [
"CARLAcrobot",
"CARLCartPole",
"CARLMountainCar",
"CARLMountainCarContinuous",
"CARLPendulum",
]


def check_spec(spec_name: str) -> bool:
Expand Down Expand Up @@ -39,28 +44,44 @@ def check_spec(spec_name: str) -> bool:
if found:
from carl.envs.gymnasium.box2d import *

__all__ += envs.gymnasium.box2d.__all__
__all__ += ["CARLBipedalWalker", "CARLLunarLander", "CARLVehicleRacing"]

found = check_spec("brax")
if found:
from carl.envs.brax import *

__all__ += envs.brax.__all__
__all__ += [
"CARLBraxAnt",
"CARLBraxHalfcheetah",
"CARLBraxHopper",
"CARLBraxHumanoid",
"CARLBraxHumanoidStandup",
"CARLBraxInvertedDoublePendulum",
"CARLBraxInvertedPendulum",
"CARLBraxPusher",
"CARLBraxReacher",
"CARLBraxWalker2d",
]

found = check_spec("py4j")
if found:
from carl.envs.mario import *

__all__ += envs.mario.__all__
__all__ += ["CARLMarioEnv"]

found = check_spec("dm_control")
if found:
from carl.envs.dmc import *

__all__ += envs.dmc.__all__
__all__ += [
"CARLDmcFingerEnv",
"CARLDmcFishEnv",
"CARLDmcQuadrupedEnv",
"CARLDmcWalkerEnv",
]

found = check_spec("distance")
if found:
from carl.envs.rna import *

__all__ += envs.rna.__all__
__all__ += ["CARLRnaDesignEnv"]
2 changes: 1 addition & 1 deletion carl/envs/dmc/carl_dm_finger.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class CARLDmcFingerEnv(CARLDmcEnv):
def get_context_features() -> dict[str, ContextFeature]:
return {
"gravity": UniformFloatContextFeature(
"gravity", lower=-np.inf, upper=-0.1, default_value=-9.81
"gravity", lower=0.1, upper=np.inf, default_value=9.81
),
"friction_torsional": UniformFloatContextFeature(
"friction_torsional", lower=0, upper=np.inf, default_value=1.0
Expand Down
2 changes: 1 addition & 1 deletion carl/envs/dmc/carl_dm_fish.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class CARLDmcFishEnv(CARLDmcEnv):
def get_context_features() -> dict[str, ContextFeature]:
return {
"gravity": UniformFloatContextFeature(
"gravity", lower=-np.inf, upper=-0.1, default_value=-9.81
"gravity", lower=0.1, upper=np.inf, default_value=9.81
),
"friction_torsional": UniformFloatContextFeature(
"friction_torsional", lower=0, upper=np.inf, default_value=1.0
Expand Down
2 changes: 1 addition & 1 deletion carl/envs/dmc/carl_dm_quadruped.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class CARLDmcQuadrupedEnv(CARLDmcEnv):
def get_context_features() -> dict[str, ContextFeature]:
return {
"gravity": UniformFloatContextFeature(
"gravity", lower=-np.inf, upper=-0.1, default_value=-9.81
"gravity", lower=0.1, upper=np.inf, default_value=9.81
),
"friction_torsional": UniformFloatContextFeature(
"friction_torsional", lower=0, upper=np.inf, default_value=1.0
Expand Down
2 changes: 1 addition & 1 deletion carl/envs/dmc/carl_dm_walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class CARLDmcWalkerEnv(CARLDmcEnv):
def get_context_features() -> dict[str, ContextFeature]:
return {
"gravity": UniformFloatContextFeature(
"gravity", lower=-np.inf, upper=-0.1, default_value=-9.81
"gravity", lower=0.1, upper=np.inf, default_value=9.81
),
"friction_torsional": UniformFloatContextFeature(
"friction_torsional", lower=0, upper=np.inf, default_value=1.0
Expand Down
6 changes: 5 additions & 1 deletion carl/envs/dmc/dmc_tasks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,21 @@ def adapt_context(xml_string: bytes, context: Context) -> bytes:

# find option settings and override them if they exist, otherwise create new option
option = mjcf.find(".//option")
import logging

if option is None:
option = etree.Element("option")
mjcf.append(option)

if "gravity" in context:
gravity = option.get("gravity")
logging.info(gravity)
if gravity is not None:
g = gravity.split(" ")
gravity = " ".join([g[0], g[1], str(-context["gravity"])])
else:
gravity = " ".join(["0", "0", str(-context["gravity"])])
gravity = " ".join(["0", "0", f"{str(-context['gravity'])}"])
logging.info(gravity)
option.set("gravity", gravity)

if "wind_x" in context and "wind_y" in context and "wind_z" in context:
Expand Down
37 changes: 37 additions & 0 deletions carl/envs/gymnasium/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# flake8: noqa: F401
# Modular imports
import importlib.util as iutil
import warnings

from carl.envs.gymnasium.classic_control import (
CARLAcrobot,
CARLCartPole,
Expand All @@ -13,3 +18,35 @@
"CARLMountainCarContinuous",
"CARLPendulum",
]


def check_spec(spec_name: str) -> bool:
"""Check if the spec is installed
Parameters
----------
spec_name : str
Name of package that is necessary for the environment suite.
Returns
-------
bool
Whether the spec was found.
"""
spec = iutil.find_spec(spec_name)
found = spec is not None
if not found:
with warnings.catch_warnings():
warnings.simplefilter("once")
warnings.warn(
f"Module {spec_name} not found. If you want to use these environments, please follow the installation guide."
)
return found


# Environment loading
found = check_spec("Box2D")
if found:
from carl.envs.gymnasium.box2d import *

__all__ += ["CARLBipedalWalker", "CARLLunarLander", "CARLVehicleRacing"]

0 comments on commit afbd479

Please sign in to comment.