Merge pull request #91 from salesforce/classic_control
Classic control
Emerald01 authored Dec 12, 2023
2 parents 93df7a0 + cce9889 commit f7120e0
Showing 7 changed files with 423 additions and 3 deletions.
109 changes: 109 additions & 0 deletions example_envs/single_agent/classic_control/acrobot/acrobot.py
@@ -0,0 +1,109 @@
import numpy as np
from warp_drive.utils.constants import Constants
from warp_drive.utils.data_feed import DataFeed
from warp_drive.utils.gpu_environment_context import CUDAEnvironmentContext

from example_envs.single_agent.base import SingleAgentEnv, map_to_single_agent, get_action_for_single_agent
from gym.envs.classic_control.acrobot import AcrobotEnv

_OBSERVATIONS = Constants.OBSERVATIONS
_ACTIONS = Constants.ACTIONS
_REWARDS = Constants.REWARDS


class ClassicControlAcrobotEnv(SingleAgentEnv):

name = "ClassicControlAcrobotEnv"

def __init__(self, episode_length, env_backend="cpu", reset_pool_size=0, seed=None):
super().__init__(episode_length, env_backend, reset_pool_size, seed=seed)

self.gym_env = AcrobotEnv()

self.action_space = map_to_single_agent(self.gym_env.action_space)
self.observation_space = map_to_single_agent(self.gym_env.observation_space)

def step(self, action=None):
self.timestep += 1
action = get_action_for_single_agent(action)
observation, reward, terminated, _, _ = self.gym_env.step(action)

obs = map_to_single_agent(observation)
rew = map_to_single_agent(reward)
done = {"__all__": self.timestep >= self.episode_length or terminated}
info = {}

return obs, rew, done, info

def reset(self):
self.timestep = 0
if self.reset_pool_size < 2:
            # without a reset pool, reuse the same seeded initial state for every reset
initial_obs, _ = self.gym_env.reset(seed=self.seed)
else:
initial_obs, _ = self.gym_env.reset(seed=None)
obs = map_to_single_agent(initial_obs)

return obs


class CUDAClassicControlAcrobotEnv(ClassicControlAcrobotEnv, CUDAEnvironmentContext):

def get_data_dictionary(self):
data_dict = DataFeed()
        # reset() returns the initial observation, which is a transform of the state,
        # so we call env.state directly to access the underlying initial state
self.gym_env.reset(seed=self.seed)
initial_state = self.gym_env.state

if self.reset_pool_size < 2:
data_dict.add_data(
name="state",
data=np.atleast_2d(initial_state),
save_copy_and_apply_at_reset=True,
)
else:
data_dict.add_data(
name="state",
data=np.atleast_2d(initial_state),
save_copy_and_apply_at_reset=False,
)
return data_dict

def get_tensor_dictionary(self):
tensor_dict = DataFeed()
return tensor_dict

def get_reset_pool_dictionary(self):
reset_pool_dict = DataFeed()
if self.reset_pool_size >= 2:
state_reset_pool = []
for _ in range(self.reset_pool_size):
self.gym_env.reset(seed=None)
initial_state = self.gym_env.state
state_reset_pool.append(np.atleast_2d(initial_state))
state_reset_pool = np.stack(state_reset_pool, axis=0)
assert len(state_reset_pool.shape) == 3 and state_reset_pool.shape[2] == 4

reset_pool_dict.add_pool_for_reset(name="state_reset_pool",
data=state_reset_pool,
reset_target="state")
return reset_pool_dict

def step(self, actions=None):
self.timestep += 1
args = [
"state",
_ACTIONS,
"_done_",
_REWARDS,
_OBSERVATIONS,
"_timestep_",
("episode_length", "meta"),
]
if self.env_backend == "numba":
self.cuda_step[
self.cuda_function_manager.grid, self.cuda_function_manager.block
](*self.cuda_step_function_feed(args))
else:
            raise Exception("CUDAClassicControlAcrobotEnv expects env_backend = 'numba'")
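
Usage sketch (illustrative, not from this commit): the CPU class above can be driven roughly as below, assuming the single agent is keyed by index 0 as in WarpDrive's other single-agent examples; the episode length, seed, and action value are arbitrary.

from example_envs.single_agent.classic_control.acrobot.acrobot import ClassicControlAcrobotEnv

env = ClassicControlAcrobotEnv(episode_length=200, reset_pool_size=0, seed=54231)
obs = env.reset()  # assumed to return {0: initial_observation}
for _ in range(env.episode_length):
    # action 1 selects zero torque via AVAIL_TORQUE = [-1.0, 0.0, 1.0]
    obs, rew, done, info = env.step({0: 1})
    if done["__all__"]:
        break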
168 changes: 168 additions & 0 deletions (new file: the Numba GPU step kernel for the Acrobot environment)
@@ -0,0 +1,168 @@
import numba
import numba.cuda as numba_driver
import numpy as np
import math

AVAIL_TORQUE = np.array([-1.0, 0.0, 1.0])

LINK_LENGTH_1 = 1.0 # [m]
LINK_LENGTH_2 = 1.0 # [m]
LINK_MASS_1 = 1.0 #: [kg] mass of link 1
LINK_MASS_2 = 1.0 #: [kg] mass of link 2
LINK_COM_POS_1 = 0.5 #: [m] position of the center of mass of link 1
LINK_COM_POS_2 = 0.5 #: [m] position of the center of mass of link 2
LINK_MOI = 1.0 #: moments of inertia for both links

pi = 3.1415926535897932384626433
MAX_VEL_1 = 12.566370614359172 # 4 * pi
MAX_VEL_2 = 28.274333882308138 # 9 * pi


@numba_driver.jit
def NumbaClassicControlAcrobotEnvStep(
state_arr,
action_arr,
done_arr,
reward_arr,
observation_arr,
env_timestep_arr,
episode_length):

    # one CUDA block per environment; the single agent maps to thread index 0
    kEnvId = numba_driver.blockIdx.x
    kThisAgentId = numba_driver.threadIdx.x

TORQUE = numba_driver.const.array_like(AVAIL_TORQUE)

assert kThisAgentId == 0, "We only have one agent per environment"

env_timestep_arr[kEnvId] += 1

assert 0 < env_timestep_arr[kEnvId] <= episode_length

reward_arr[kEnvId, kThisAgentId] = -1.0

action = action_arr[kEnvId, kThisAgentId, 0]

torque = TORQUE[action]

ns = numba_driver.local.array(shape=4, dtype=numba.float32)
rk4(state_arr[kEnvId, kThisAgentId], torque, ns)

ns[0] = wrap(ns[0], -pi, pi)
ns[1] = wrap(ns[1], -pi, pi)
ns[2] = bound(ns[2], -MAX_VEL_1, MAX_VEL_1)
ns[3] = bound(ns[3], -MAX_VEL_2, MAX_VEL_2)

for i in range(4):
state_arr[kEnvId, kThisAgentId, i] = ns[i]

terminated = _terminal(state_arr, kEnvId, kThisAgentId)
if terminated:
reward_arr[kEnvId, kThisAgentId] = 0.0

_get_ob(state_arr, observation_arr, kEnvId, kThisAgentId)

if env_timestep_arr[kEnvId] == episode_length or terminated:
done_arr[kEnvId] = 1


@numba_driver.jit(device=True)
def _dsdt(state, torque, derivatives):
m1 = LINK_MASS_1
m2 = LINK_MASS_2
l1 = LINK_LENGTH_1
lc1 = LINK_COM_POS_1
lc2 = LINK_COM_POS_2
I1 = LINK_MOI
I2 = LINK_MOI
g = 9.8
a = torque
theta1 = state[0]
theta2 = state[1]
dtheta1 = state[2]
dtheta2 = state[3]
d1 = (
m1 * lc1 ** 2
+ m2 * (l1 ** 2 + lc2 ** 2 + 2 * l1 * lc2 * math.cos(theta2))
+ I1
+ I2
)
d2 = m2 * (lc2 ** 2 + l1 * lc2 * math.cos(theta2)) + I2
phi2 = m2 * lc2 * g * math.cos(theta1 + theta2 - pi / 2)
phi1 = (
-m2 * l1 * lc2 * dtheta2 ** 2 * math.sin(theta2)
- 2 * m2 * l1 * lc2 * dtheta2 * dtheta1 * math.sin(theta2)
+ (m1 * lc1 + m2 * l1) * g * math.cos(theta1 - pi / 2)
+ phi2
)

ddtheta2 = (a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * math.sin(theta2) - phi2
) / (m2 * lc2 ** 2 + I2 - d2 ** 2 / d1)
ddtheta1 = -(d2 * ddtheta2 + phi1) / d1

derivatives[0] = dtheta1
derivatives[1] = dtheta2
derivatives[2] = ddtheta1
derivatives[3] = ddtheta2


@numba_driver.jit(device=True)
def rk4(state, torque, yout):
    dt = 0.2  # integration step size, matching gym's AcrobotEnv
    dt2 = 0.1  # half step, dt / 2
k1 = numba_driver.local.array(shape=4, dtype=numba.float32)
_dsdt(state, torque, k1)
k1_update = numba_driver.local.array(shape=4, dtype=numba.float32)
for i in range(4):
k1_update[i] = state[i] + k1[i] * dt2
k2 = numba_driver.local.array(shape=4, dtype=numba.float32)
_dsdt(k1_update, torque, k2)
k2_update = numba_driver.local.array(shape=4, dtype=numba.float32)
for i in range(4):
k2_update[i] = state[i] + k2[i] * dt2
k3 = numba_driver.local.array(shape=4, dtype=numba.float32)
_dsdt(k2_update, torque, k3)
k3_update = numba_driver.local.array(shape=4, dtype=numba.float32)
for i in range(4):
k3_update[i] = state[i] + k3[i] * dt
k4 = numba_driver.local.array(shape=4, dtype=numba.float32)
_dsdt(k3_update, torque, k4)

for i in range(4):
yout[i] = state[i] + dt / 6.0 * (k1[i] + 2 * k2[i] + 2 * k3[i] + k4[i])


@numba_driver.jit(device=True)
def wrap(x, m, M):
diff = M - m
while x > M:
x = x - diff
while x < m:
x = x + diff
return x


@numba_driver.jit(device=True)
def bound(x, m, M):
return min(max(x, m), M)


@numba_driver.jit(device=True)
def _terminal(state_arr, kEnvId, kThisAgentId):
    # terminate when the free end swings above one link length:
    # -cos(theta1) - cos(theta2 + theta1) > 1.0, as in gym's AcrobotEnv
    state = state_arr[kEnvId, kThisAgentId]
    return bool(-math.cos(state[0]) - math.cos(state[1] + state[0]) > 1.0)


@numba_driver.jit(device=True)
def _get_ob(
    state_arr,
    observation_arr,
    kEnvId,
    kThisAgentId,
):
state = state_arr[kEnvId, kThisAgentId]
observation_arr[kEnvId, kThisAgentId, 0] = math.cos(state[0])
observation_arr[kEnvId, kThisAgentId, 1] = math.sin(state[0])
observation_arr[kEnvId, kThisAgentId, 2] = math.cos(state[1])
observation_arr[kEnvId, kThisAgentId, 3] = math.sin(state[1])
observation_arr[kEnvId, kThisAgentId, 4] = state[2]
observation_arr[kEnvId, kThisAgentId, 5] = state[3]
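
Reference sketch (illustrative, not from this commit): rk4 above implements a standard fourth-order Runge-Kutta step with dt = 0.2; a NumPy-style equivalent for CPU-side cross-checking could look like this, where f stands for the acrobot dynamics computed by _dsdt and returns the four derivatives as an array.

def rk4_reference(f, state, torque, dt=0.2):
    # classic RK4 update: state + dt/6 * (k1 + 2*k2 + 2*k3 + k4), mirroring rk4() above
    k1 = f(state, torque)
    k2 = f(state + 0.5 * dt * k1, torque)
    k3 = f(state + 0.5 * dt * k2, torque)
    k4 = f(state + dt * k3, torque)
    return state + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)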
2 changes: 1 addition & 1 deletion setup.py
@@ -14,7 +14,7 @@

setup(
name="rl-warp-drive",
version="2.6.1",
version="2.6.2",
author="Tian Lan, Sunil Srinivasa, Brenton Chu, Stephan Zheng",
author_email="[email protected]",
description="Framework for fast end-to-end "
88 changes: 88 additions & 0 deletions (new file: CPU vs GPU consistency tests for the Acrobot environment)
@@ -0,0 +1,88 @@
import unittest
import numpy as np
import torch

from warp_drive.env_cpu_gpu_consistency_checker import EnvironmentCPUvsGPU
from example_envs.single_agent.classic_control.acrobot.acrobot import \
ClassicControlAcrobotEnv, CUDAClassicControlAcrobotEnv
from warp_drive.env_wrapper import EnvWrapper


env_configs = {
"test1": {
"episode_length": 200,
"reset_pool_size": 0,
"seed": 54231,
},
}


class MyTestCase(unittest.TestCase):
"""
CPU v GPU consistency unit tests
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.testing_class = EnvironmentCPUvsGPU(
cpu_env_class=ClassicControlAcrobotEnv,
cuda_env_class=CUDAClassicControlAcrobotEnv,
env_configs=env_configs,
gpu_env_backend="numba",
num_envs=5,
num_episodes=2,
)

def test_env_consistency(self):
try:
self.testing_class.test_env_reset_and_step()
except AssertionError:
self.fail("ClassicControlAcrobotEnv environment consistency tests failed")

def test_reset_pool(self):
env_wrapper = EnvWrapper(
env_obj=CUDAClassicControlAcrobotEnv(episode_length=100, reset_pool_size=8),
num_envs=3,
env_backend="numba",
)
env_wrapper.reset_all_envs()
env_wrapper.env_resetter.init_reset_pool(env_wrapper.cuda_data_manager, seed=12345)
self.assertTrue(env_wrapper.cuda_data_manager.reset_target_to_pool["state"] == "state_reset_pool")

        # squeeze() the agent dimension, which is always 1
state_after_initial_reset = env_wrapper.cuda_data_manager.pull_data_from_device("state").squeeze()

reset_pool = env_wrapper.cuda_data_manager.pull_data_from_device(
env_wrapper.cuda_data_manager.get_reset_pool("state"))
reset_pool_mean = reset_pool.mean(axis=0).squeeze()

self.assertTrue(reset_pool.std(axis=0).mean() > 1e-4)

env_wrapper.cuda_data_manager.data_on_device_via_torch("_done_")[:] = torch.from_numpy(
np.array([1, 1, 0])
).cuda()

state_values = {0: [], 1: [], 2: []}
for _ in range(10000):
env_wrapper.env_resetter.reset_when_done(env_wrapper.cuda_data_manager, mode="if_done", undo_done_after_reset=False)
res = env_wrapper.cuda_data_manager.pull_data_from_device("state")
state_values[0].append(res[0])
state_values[1].append(res[1])
state_values[2].append(res[2])

state_values_env0_mean = np.stack(state_values[0]).mean(axis=0).squeeze()
state_values_env1_mean = np.stack(state_values[1]).mean(axis=0).squeeze()
state_values_env2_mean = np.stack(state_values[2]).mean(axis=0).squeeze()

for i in range(len(reset_pool_mean)):
self.assertTrue(np.absolute(state_values_env0_mean[i] - reset_pool_mean[i]) < 0.1 * abs(reset_pool_mean[i]),
f"sampled mean: {state_values_env0_mean[i]}, expected mean: {reset_pool_mean[i]}")
self.assertTrue(np.absolute(state_values_env1_mean[i] - reset_pool_mean[i]) < 0.1 * abs(reset_pool_mean[i]),
f"sampled mean: {state_values_env1_mean[i]}, expected mean: {reset_pool_mean[i]}")
self.assertTrue(
np.absolute(
state_values_env2_mean[i] - state_after_initial_reset[0][i]
) < 0.001 * abs(state_after_initial_reset[0][i])
)
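
Standalone sketch (illustrative, not from this commit): the same consistency check can be exercised outside unittest, mirroring MyTestCase with identical configuration values.

from warp_drive.env_cpu_gpu_consistency_checker import EnvironmentCPUvsGPU
from example_envs.single_agent.classic_control.acrobot.acrobot import \
    ClassicControlAcrobotEnv, CUDAClassicControlAcrobotEnv

checker = EnvironmentCPUvsGPU(
    cpu_env_class=ClassicControlAcrobotEnv,
    cuda_env_class=CUDAClassicControlAcrobotEnv,
    env_configs={"test1": {"episode_length": 200, "reset_pool_size": 0, "seed": 54231}},
    gpu_env_backend="numba",
    num_envs=5,
    num_episodes=2,
)
checker.test_env_reset_and_step()  # raises AssertionError on any CPU vs GPU mismatch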

