Merge pull request #91 from salesforce/classic_control
Classic control
Showing 7 changed files with 423 additions and 3 deletions.
109 changes: 109 additions & 0 deletions
example_envs/single_agent/classic_control/acrobot/acrobot.py
@@ -0,0 +1,109 @@
import numpy as np
from warp_drive.utils.constants import Constants
from warp_drive.utils.data_feed import DataFeed
from warp_drive.utils.gpu_environment_context import CUDAEnvironmentContext

from example_envs.single_agent.base import SingleAgentEnv, map_to_single_agent, get_action_for_single_agent
from gym.envs.classic_control.acrobot import AcrobotEnv

_OBSERVATIONS = Constants.OBSERVATIONS
_ACTIONS = Constants.ACTIONS
_REWARDS = Constants.REWARDS


class ClassicControlAcrobotEnv(SingleAgentEnv):

    name = "ClassicControlAcrobotEnv"

    def __init__(self, episode_length, env_backend="cpu", reset_pool_size=0, seed=None):
        super().__init__(episode_length, env_backend, reset_pool_size, seed=seed)

        self.gym_env = AcrobotEnv()

        self.action_space = map_to_single_agent(self.gym_env.action_space)
        self.observation_space = map_to_single_agent(self.gym_env.observation_space)

    def step(self, action=None):
        self.timestep += 1
        action = get_action_for_single_agent(action)
        observation, reward, terminated, _, _ = self.gym_env.step(action)

        obs = map_to_single_agent(observation)
        rew = map_to_single_agent(reward)
        done = {"__all__": self.timestep >= self.episode_length or terminated}
        info = {}

        return obs, rew, done, info

    def reset(self):
        self.timestep = 0
        if self.reset_pool_size < 2:
            # Use a fixed initial state every time.
            initial_obs, _ = self.gym_env.reset(seed=self.seed)
        else:
            initial_obs, _ = self.gym_env.reset(seed=None)
        obs = map_to_single_agent(initial_obs)

        return obs


class CUDAClassicControlAcrobotEnv(ClassicControlAcrobotEnv, CUDAEnvironmentContext):

    def get_data_dictionary(self):
        data_dict = DataFeed()
        # reset() returns the initial observation, which is a tuple processed
        # from the state, so we call env.state to access the raw initial state.
        self.gym_env.reset(seed=self.seed)
        initial_state = self.gym_env.state

        if self.reset_pool_size < 2:
            data_dict.add_data(
                name="state",
                data=np.atleast_2d(initial_state),
                save_copy_and_apply_at_reset=True,
            )
        else:
            data_dict.add_data(
                name="state",
                data=np.atleast_2d(initial_state),
                save_copy_and_apply_at_reset=False,
            )
        return data_dict

    def get_tensor_dictionary(self):
        tensor_dict = DataFeed()
        return tensor_dict

    def get_reset_pool_dictionary(self):
        reset_pool_dict = DataFeed()
        if self.reset_pool_size >= 2:
            state_reset_pool = []
            for _ in range(self.reset_pool_size):
                self.gym_env.reset(seed=None)
                initial_state = self.gym_env.state
                state_reset_pool.append(np.atleast_2d(initial_state))
            state_reset_pool = np.stack(state_reset_pool, axis=0)
            assert len(state_reset_pool.shape) == 3 and state_reset_pool.shape[2] == 4

            reset_pool_dict.add_pool_for_reset(
                name="state_reset_pool",
                data=state_reset_pool,
                reset_target="state",
            )
        return reset_pool_dict

    def step(self, actions=None):
        self.timestep += 1
        args = [
            "state",
            _ACTIONS,
            "_done_",
            _REWARDS,
            _OBSERVATIONS,
            "_timestep_",
            ("episode_length", "meta"),
        ]
        if self.env_backend == "numba":
            self.cuda_step[
                self.cuda_function_manager.grid, self.cuda_function_manager.block
            ](*self.cuda_step_function_feed(args))
        else:
            raise Exception("CUDAClassicControlAcrobotEnv expects env_backend = 'numba'")
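
For orientation, here is a minimal sketch of a CPU rollout with the env above. The {0: ...} action format is an assumption based on how map_to_single_agent and get_action_for_single_agent are used in this file, not part of this commit:

from example_envs.single_agent.classic_control.acrobot.acrobot import (
    ClassicControlAcrobotEnv,
)

# Hypothetical usage sketch: drive the CPU env with random actions.
env = ClassicControlAcrobotEnv(episode_length=200, seed=54231)
obs = env.reset()  # dict keyed by agent id, per map_to_single_agent
done = {"__all__": False}
while not done["__all__"]:
    # Sample a torque index in {0, 1, 2}; the dict wrapping is assumed.
    action = {0: env.gym_env.action_space.sample()}
    obs, rew, done, info = env.step(action)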
168 changes: 168 additions & 0 deletions
example_envs/single_agent/classic_control/acrobot/acrobot_step_numba.py
@@ -0,0 +1,168 @@
import numba
import numba.cuda as numba_driver
import numpy as np
import math

AVAIL_TORQUE = np.array([-1.0, 0.0, 1.0])

LINK_LENGTH_1 = 1.0  # [m]
LINK_LENGTH_2 = 1.0  # [m]
LINK_MASS_1 = 1.0  #: [kg] mass of link 1
LINK_MASS_2 = 1.0  #: [kg] mass of link 2
LINK_COM_POS_1 = 0.5  #: [m] position of the center of mass of link 1
LINK_COM_POS_2 = 0.5  #: [m] position of the center of mass of link 2
LINK_MOI = 1.0  #: moment of inertia for both links

pi = 3.1415926535897932384626433
MAX_VEL_1 = 12.566370614359172  # 4 * pi
MAX_VEL_2 = 28.274333882308138  # 9 * pi


@numba_driver.jit
def NumbaClassicControlAcrobotEnvStep(
        state_arr,
        action_arr,
        done_arr,
        reward_arr,
        observation_arr,
        env_timestep_arr,
        episode_length):

    # One CUDA block per environment, one thread per agent.
    kEnvId = numba_driver.blockIdx.x
    kThisAgentId = numba_driver.threadIdx.x

    TORQUE = numba_driver.const.array_like(AVAIL_TORQUE)

    assert kThisAgentId == 0, "We only have one agent per environment"

    env_timestep_arr[kEnvId] += 1

    assert 0 < env_timestep_arr[kEnvId] <= episode_length

    # Reward is -1 per step until the goal height is reached.
    reward_arr[kEnvId, kThisAgentId] = -1.0

    action = action_arr[kEnvId, kThisAgentId, 0]

    torque = TORQUE[action]

    # Integrate the dynamics one step, then wrap angles and clip velocities.
    ns = numba_driver.local.array(shape=4, dtype=numba.float32)
    rk4(state_arr[kEnvId, kThisAgentId], torque, ns)

    ns[0] = wrap(ns[0], -pi, pi)
    ns[1] = wrap(ns[1], -pi, pi)
    ns[2] = bound(ns[2], -MAX_VEL_1, MAX_VEL_1)
    ns[3] = bound(ns[3], -MAX_VEL_2, MAX_VEL_2)

    for i in range(4):
        state_arr[kEnvId, kThisAgentId, i] = ns[i]

    terminated = _terminal(state_arr, kEnvId, kThisAgentId)
    if terminated:
        reward_arr[kEnvId, kThisAgentId] = 0.0

    _get_ob(state_arr, observation_arr, kEnvId, kThisAgentId)

    if env_timestep_arr[kEnvId] == episode_length or terminated:
        done_arr[kEnvId] = 1


@numba_driver.jit(device=True)
def _dsdt(state, torque, derivatives):
    # Equations of motion for the two-link acrobot, as in gym's AcrobotEnv.
    m1 = LINK_MASS_1
    m2 = LINK_MASS_2
    l1 = LINK_LENGTH_1
    lc1 = LINK_COM_POS_1
    lc2 = LINK_COM_POS_2
    I1 = LINK_MOI
    I2 = LINK_MOI
    g = 9.8
    a = torque
    theta1 = state[0]
    theta2 = state[1]
    dtheta1 = state[2]
    dtheta2 = state[3]
    d1 = (
        m1 * lc1 ** 2
        + m2 * (l1 ** 2 + lc2 ** 2 + 2 * l1 * lc2 * math.cos(theta2))
        + I1
        + I2
    )
    d2 = m2 * (lc2 ** 2 + l1 * lc2 * math.cos(theta2)) + I2
    phi2 = m2 * lc2 * g * math.cos(theta1 + theta2 - pi / 2)
    phi1 = (
        -m2 * l1 * lc2 * dtheta2 ** 2 * math.sin(theta2)
        - 2 * m2 * l1 * lc2 * dtheta2 * dtheta1 * math.sin(theta2)
        + (m1 * lc1 + m2 * l1) * g * math.cos(theta1 - pi / 2)
        + phi2
    )

    ddtheta2 = (
        a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * math.sin(theta2) - phi2
    ) / (m2 * lc2 ** 2 + I2 - d2 ** 2 / d1)
    ddtheta1 = -(d2 * ddtheta2 + phi1) / d1

    derivatives[0] = dtheta1
    derivatives[1] = dtheta2
    derivatives[2] = ddtheta1
    derivatives[3] = ddtheta2


@numba_driver.jit(device=True)
def rk4(state, torque, yout):
    # Classic fourth-order Runge-Kutta with a fixed step of dt = 0.2.
    dt = 0.2
    dt2 = 0.1
    k1 = numba_driver.local.array(shape=4, dtype=numba.float32)
    _dsdt(state, torque, k1)
    k1_update = numba_driver.local.array(shape=4, dtype=numba.float32)
    for i in range(4):
        k1_update[i] = state[i] + k1[i] * dt2
    k2 = numba_driver.local.array(shape=4, dtype=numba.float32)
    _dsdt(k1_update, torque, k2)
    k2_update = numba_driver.local.array(shape=4, dtype=numba.float32)
    for i in range(4):
        k2_update[i] = state[i] + k2[i] * dt2
    k3 = numba_driver.local.array(shape=4, dtype=numba.float32)
    _dsdt(k2_update, torque, k3)
    k3_update = numba_driver.local.array(shape=4, dtype=numba.float32)
    for i in range(4):
        k3_update[i] = state[i] + k3[i] * dt
    k4 = numba_driver.local.array(shape=4, dtype=numba.float32)
    _dsdt(k3_update, torque, k4)

    for i in range(4):
        yout[i] = state[i] + dt / 6.0 * (k1[i] + 2 * k2[i] + 2 * k3[i] + k4[i])


@numba_driver.jit(device=True)
def wrap(x, m, M):
    # Wrap x into the interval [m, M]; used to keep angles in [-pi, pi].
    diff = M - m
    while x > M:
        x = x - diff
    while x < m:
        x = x + diff
    return x


@numba_driver.jit(device=True)
def bound(x, m, M):
    # Clip x to the interval [m, M].
    return min(max(x, m), M)


@numba_driver.jit(device=True)
def _terminal(state_arr, kEnvId, kThisAgentId):
    # The episode terminates once the free end swings above the goal height.
    state = state_arr[kEnvId, kThisAgentId]
    return bool(-math.cos(state[0]) - math.cos(state[1] + state[0]) > 1.0)


@numba_driver.jit(device=True)
def _get_ob(
        state_arr,
        observation_arr,
        kEnvId,
        kThisAgentId):
    state = state_arr[kEnvId, kThisAgentId]
    observation_arr[kEnvId, kThisAgentId, 0] = math.cos(state[0])
    observation_arr[kEnvId, kThisAgentId, 1] = math.sin(state[0])
    observation_arr[kEnvId, kThisAgentId, 2] = math.cos(state[1])
    observation_arr[kEnvId, kThisAgentId, 3] = math.sin(state[1])
    observation_arr[kEnvId, kThisAgentId, 4] = state[2]
    observation_arr[kEnvId, kThisAgentId, 5] = state[3]
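
For reference, rk4 above is the classic fourth-order Runge-Kutta update with a fixed step dt = 0.2 (mirroring gym's AcrobotEnv integrator), and _terminal encodes the standard goal condition that the free end of the acrobot rises above one link length. In LaTeX form, with f the _dsdt dynamics and y = (theta_1, theta_2, dtheta_1, dtheta_2):

\[
k_1 = f(y_n),\quad
k_2 = f\big(y_n + \tfrac{dt}{2}\,k_1\big),\quad
k_3 = f\big(y_n + \tfrac{dt}{2}\,k_2\big),\quad
k_4 = f\big(y_n + dt\,k_3\big),
\]
\[
y_{n+1} = y_n + \frac{dt}{6}\big(k_1 + 2k_2 + 2k_3 + k_4\big),
\qquad
\text{terminate when } -\cos\theta_1 - \cos(\theta_1 + \theta_2) > 1.
\]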
setup.py
@@ -14,7 +14,7 @@
 setup(
     name="rl-warp-drive",
-    version="2.6.1",
+    version="2.6.2",
     author="Tian Lan, Sunil Srinivasa, Brenton Chu, Stephan Zheng",
     author_email="[email protected]",
     description="Framework for fast end-to-end "
88 changes: 88 additions & 0 deletions
tests/example_envs/numba_tests/single_agent/classic_control/test_acrobot.py
@@ -0,0 +1,88 @@
import unittest
import numpy as np
import torch

from warp_drive.env_cpu_gpu_consistency_checker import EnvironmentCPUvsGPU
from example_envs.single_agent.classic_control.acrobot.acrobot import \
    ClassicControlAcrobotEnv, CUDAClassicControlAcrobotEnv
from warp_drive.env_wrapper import EnvWrapper


env_configs = {
    "test1": {
        "episode_length": 200,
        "reset_pool_size": 0,
        "seed": 54231,
    },
}


class MyTestCase(unittest.TestCase):
    """
    CPU vs. GPU consistency unit tests
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.testing_class = EnvironmentCPUvsGPU(
            cpu_env_class=ClassicControlAcrobotEnv,
            cuda_env_class=CUDAClassicControlAcrobotEnv,
            env_configs=env_configs,
            gpu_env_backend="numba",
            num_envs=5,
            num_episodes=2,
        )

    def test_env_consistency(self):
        try:
            self.testing_class.test_env_reset_and_step()
        except AssertionError:
            self.fail("ClassicControlAcrobotEnv environment consistency tests failed")

    def test_reset_pool(self):
        env_wrapper = EnvWrapper(
            env_obj=CUDAClassicControlAcrobotEnv(episode_length=100, reset_pool_size=8),
            num_envs=3,
            env_backend="numba",
        )
        env_wrapper.reset_all_envs()
        env_wrapper.env_resetter.init_reset_pool(env_wrapper.cuda_data_manager, seed=12345)
        self.assertTrue(
            env_wrapper.cuda_data_manager.reset_target_to_pool["state"] == "state_reset_pool"
        )

        # squeeze() the agent dimension, which is always 1
        state_after_initial_reset = env_wrapper.cuda_data_manager.pull_data_from_device("state").squeeze()

        reset_pool = env_wrapper.cuda_data_manager.pull_data_from_device(
            env_wrapper.cuda_data_manager.get_reset_pool("state"))
        reset_pool_mean = reset_pool.mean(axis=0).squeeze()

        # The pool should contain distinct initial states.
        self.assertTrue(reset_pool.std(axis=0).mean() > 1e-4)

        # Mark envs 0 and 1 as done so they draw new states from the reset pool;
        # env 2 stays as-is and keeps its deterministic reset state.
        env_wrapper.cuda_data_manager.data_on_device_via_torch("_done_")[:] = torch.from_numpy(
            np.array([1, 1, 0])
        ).cuda()

        state_values = {0: [], 1: [], 2: []}
        for _ in range(10000):
            env_wrapper.env_resetter.reset_when_done(
                env_wrapper.cuda_data_manager, mode="if_done", undo_done_after_reset=False
            )
            res = env_wrapper.cuda_data_manager.pull_data_from_device("state")
            state_values[0].append(res[0])
            state_values[1].append(res[1])
            state_values[2].append(res[2])

        state_values_env0_mean = np.stack(state_values[0]).mean(axis=0).squeeze()
        state_values_env1_mean = np.stack(state_values[1]).mean(axis=0).squeeze()
        state_values_env2_mean = np.stack(state_values[2]).mean(axis=0).squeeze()

        # Envs 0 and 1 should converge to the pool mean; env 2 should stay at
        # its initial reset state.
        for i in range(len(reset_pool_mean)):
            self.assertTrue(
                np.absolute(state_values_env0_mean[i] - reset_pool_mean[i])
                < 0.1 * abs(reset_pool_mean[i]),
                f"sampled mean: {state_values_env0_mean[i]}, expected mean: {reset_pool_mean[i]}",
            )
            self.assertTrue(
                np.absolute(state_values_env1_mean[i] - reset_pool_mean[i])
                < 0.1 * abs(reset_pool_mean[i]),
                f"sampled mean: {state_values_env1_mean[i]}, expected mean: {reset_pool_mean[i]}",
            )
            self.assertTrue(
                np.absolute(
                    state_values_env2_mean[i] - state_after_initial_reset[0][i]
                ) < 0.001 * abs(state_after_initial_reset[0][i])
            )