-
Notifications
You must be signed in to change notification settings - Fork 62
/
test.py
98 lines (82 loc) · 3.03 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
########################## FWMAV Simulation #########################
# Version 0.3
# Fan Fei Feb 2019
# Direct motor driven flapping wing MAV simulation
#######################################################################
import gym
import flappy
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines.common import set_global_seeds
from flappy.envs.fwmav.controllers.arc_xy_arc_z import ARCController
from flappy.envs.fwmav.controllers.pid_controller import PIDController
import time
import argparse
import importlib
import numpy as np
def make_env(env_id, rank, seed=0, random_init=True, randomize_sim=True, phantom_sensor=False):
    """Return a zero-argument factory that builds one configured gym environment.

    The environment is not created here; construction is deferred until the
    returned callable is invoked (the form a vectorized-env wrapper expects).

    Args:
        env_id: Identifier passed to ``gym.make``.
        rank: Index of this environment; it is added to ``seed``, and rank 0
            additionally turns on visualization and printing.
        seed: Base seed; the environment is seeded with ``seed + rank``.
        random_init: Forwarded as the first argument of ``env.config``.
        randomize_sim: Forwarded as the second argument of ``env.config``.
        phantom_sensor: Forwarded as the third argument of ``env.config``.

    Returns:
        A callable taking no arguments that creates, configures, seeds and
        returns the environment.
    """
    def _init():
        sim_env = gym.make(env_id)
        sim_env.config(random_init, randomize_sim, phantom_sensor)

        # NOTE(review): both calls below are assumed to be rank-0-only - confirm.
        is_primary = (rank == 0)
        if is_primary:
            sim_env.enable_visualization()
            sim_env.enable_print()

        # Offset the seed by rank so each environment gets a distinct seed.
        sim_env.seed(seed + rank)
        return sim_env

    # NOTE: global seeding via set_global_seeds(seed) is disabled here.
    return _init
class LazyModel:
    """Adapter exposing a classical controller through a ``predict(obs)`` call.

    Wraps a ``PIDController`` or ``ARCController`` so it can be driven by the
    same loop as a learned model: observations are rescaled by the
    environment's observation bound before being handed to the controller, and
    the controller's action is mapped into ``[-1, 1]``.
    """

    def __init__(self, env, model_type):
        """Capture the environment's bounds and build the requested controller.

        Args:
            env: Environment providing ``action_lb``, ``action_ub``,
                ``observation_bound`` and ``sim.dt_c``.
            model_type: ``'PID'`` or ``'ARC'``.

        Raises:
            ValueError: If ``model_type`` is neither ``'PID'`` nor ``'ARC'``.
        """
        self.action_lb = env.action_lb
        self.action_ub = env.action_ub
        self.observation_bound = env.observation_bound
        if model_type == 'PID':
            self.policy = PIDController(env.sim.dt_c)
        elif model_type == 'ARC':
            self.policy = ARCController(env.sim.dt_c)
        else:
            # ValueError is still an Exception, so existing broad handlers keep
            # working, but the message now says what was wrong.
            raise ValueError(
                "Unsupported model_type {!r}: expected 'PID' or 'ARC'".format(model_type))

    def predict(self, obs):
        """Compute the controller action for a batched observation.

        Args:
            obs: Batched observation; only the first entry ``obs[0]`` is used.
                It is multiplied by ``observation_bound`` before being passed
                to the controller.

        Returns:
            A ``(action, None)`` tuple, where ``action`` is a NumPy array
            holding the single normalized action (leading batch axis of 1).
        """
        action = self.policy.get_action(obs[0]*self.observation_bound)
        # scale action from [action_lb, action_ub] to [-1,1]
        # since baseline does not support asymmetric action space
        normalized_action = (action-self.action_lb)/(self.action_ub - self.action_lb)*2 - 1
        action = np.array([normalized_action])
        return action, None
def main(args):
    """Run the hover environment in an endless loop under the chosen controller.

    Args:
        args: Parsed command-line namespace. Reads ``model_type``,
            ``model_path``, ``rand_init``, ``rand_dynamics`` and
            ``phantom_sensor``.

    Returns:
        None. Returns early, after printing a message, when the model class
        cannot be found in ``stable_baselines`` or the model cannot be loaded
        from ``args.model_path``. Otherwise the simulation loop never exits.
    """
    env_id = 'fwmav_hover-v0'
    env = DummyVecEnv([make_env(
        env_id, 0,
        random_init=args.rand_init,
        randomize_sim=args.rand_dynamics,
        phantom_sensor=args.phantom_sensor)])
    sim_env = env.envs[0]

    if args.model_type in ('PID', 'ARC'):
        # Classical controllers are wrapped so they share the predict() API.
        model = LazyModel(sim_env, args.model_type)
    else:
        try:
            model_cls = getattr(
                importlib.import_module('stable_baselines'), args.model_type)
        except AttributeError:
            print(args.model_type, "Error: wrong model type")
            return
        try:
            model = model_cls.load(args.model_path)
        except Exception:
            # Without this return, `model` would be unbound and the loop below
            # would crash on its first predict() call. Catching Exception
            # (not a bare except) also lets Ctrl-C / SystemExit propagate.
            print(args.model_path, "Error: wrong model path")
            return

    obs = env.reset()

    while True:
        if sim_env.is_sim_on:
            action, _ = model.predict(obs)
            # Only the next observation is used; reward/done/info are ignored
            # and no explicit reset is issued here when an episode ends.
            obs, _, _, _ = env.step(action)
        else:
            # Simulation is paused: block on the GUI's `cv` object
            # (presumably a condition variable - confirm) instead of spinning.
            sim_env.gui.cv.wait()
if __name__ == '__main__':
    # Command-line entry point: collect the options and hand them to main().
    parser = argparse.ArgumentParser()

    # Which controller/model to run, and where a saved model lives.
    parser.add_argument('--model_type', required=True)
    parser.add_argument('--model_path')
    parser.add_argument(
        '--policy_type', nargs='?', const='MlpPolicy', default='MlpPolicy')

    # Boolean switches, all off unless given on the command line.
    for switch in ('--rand_init', '--rand_dynamics', '--phantom_sensor'):
        parser.add_argument(switch, action='store_true', default=False)

    args = parser.parse_args()
    main(args)