forked from PaddlePaddle/PARL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
simulator_client.py
115 lines (95 loc) · 4.25 KB
/
simulator_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import grpc
import json
import numpy as np
import simulator_pb2
import simulator_pb2_grpc
from args import get_client_args
from env_wrapper import FrameSkip, ActionScale, PelvisBasedObs, MAXTIME_LIMIT, CustomR2Env, RunFastestReward, FixedTargetSpeedReward, Round2Reward
from osim.env import ProstheticsEnv
from parl.utils import logger
# Monkey-patch the class-level episode time limit so environments created
# below run with the project-wide limit instead of the osim default.
ProstheticsEnv.time_limit = MAXTIME_LIMIT
class Worker(object):
    """gRPC client that steps a wrapped ProstheticsEnv on behalf of a server.

    The worker builds a reward-shaped, frame-skipped environment locally,
    then loops forever: it sends the latest (observation, reward, done, info)
    to the simulator server and applies whatever action the server returns.

    NOTE(review): this class reads the module-level ``args`` parsed in
    ``__main__`` (ident, reward_type, stage, debug, ...) — it cannot be
    constructed before ``get_client_args()`` has run.
    """

    def __init__(self, server_ip='localhost', server_port=5007):
        """Build the environment wrapper stack and record the server address.

        Args:
            server_ip (str): host of the simulator server.
            server_port (int): port of the simulator server.

        Raises:
            ValueError: if ``args.reward_type`` is not one of the supported
                reward shapes.
        """
        # Use the user-supplied identity when given; otherwise draw a large
        # random id so multiple anonymous workers are distinguishable.
        if args.ident is not None:
            self.worker_id = args.ident
        else:
            self.worker_id = np.random.randint(int(1e18))
        self.address = '{}:{}'.format(server_ip, server_port)
        # Derive a per-worker seed from the id so parallel workers do not
        # replay identical trajectories; fold into a smaller range for the
        # env/numpy seed APIs.
        random_seed = int(self.worker_id % int(1e9))
        np.random.seed(random_seed)
        env = ProstheticsEnv(visualize=False, seed=random_seed)
        env.change_model(
            model='3D', difficulty=1, prosthetic=True, seed=random_seed)
        env.spec.timestep_limit = MAXTIME_LIMIT
        env = CustomR2Env(env)
        if args.reward_type == 'RunFastest':
            env = RunFastestReward(env)
        elif args.reward_type == 'FixedTargetSpeed':
            env = FixedTargetSpeedReward(
                env, args.target_v, args.act_penalty_lowerbound,
                args.act_penalty_coeff, args.vel_penalty_coeff)
        elif args.reward_type == 'Round2':
            env = Round2Reward(env, args.act_penalty_lowerbound,
                               args.act_penalty_coeff, args.vel_penalty_coeff)
        else:
            # Raise instead of `assert False`: asserts are stripped under
            # `python -O`, which would let an unsupported reward type slip
            # through and fail later with a confusing error.
            raise ValueError(
                'Not supported reward type: {}'.format(args.reward_type))
        env = FrameSkip(env, 4)
        env = ActionScale(env)
        self.env = PelvisBasedObs(env)

    def run(self):
        """Blocking interaction loop with the simulator server.

        Each iteration sends the current transition to the server and
        executes the action it replies with. The server can also command a
        reset or a shutdown through the ``extra`` payload of its response.
        """
        observation = self.env.reset(project=False, stage=args.stage)
        reward = 0
        done = False
        info = {'shaping_reward': 0.0}
        # Mark the very first message so the server can tell a fresh worker
        # from an ongoing episode.
        info['first'] = True
        with grpc.insecure_channel(self.address) as channel:
            stub = simulator_pb2_grpc.SimulatorStub(channel)
            while True:
                response = stub.Send(
                    simulator_pb2.Request(
                        observation=observation,
                        reward=reward,
                        done=done,
                        info=json.dumps(info),
                        id=self.worker_id))
                extra = json.loads(response.extra)
                # Server-commanded reset: start a new episode and skip the
                # env step for this round.
                if 'reset' in extra and extra['reset']:
                    logger.info('Server require to reset!')
                    observation = self.env.reset(
                        project=False, stage=args.stage)
                    reward = 0
                    done = False
                    info = {'shaping_reward': 0.0}
                    continue
                # Server-commanded shutdown: leave the loop (and close the
                # channel via the `with` block).
                if 'shutdown' in extra and extra['shutdown']:
                    break
                action = np.array(response.action)
                next_observation, reward, done, info = self.env.step(
                    action, project=False)
                # debug info
                if args.debug:
                    logger.info("dim:{} obs:{} act:{} reward:{} done:{} info:{}".format(\
                        len(observation), np.sum(observation), np.sum(action), reward, done, info))
                observation = next_observation
                if done:
                    observation = self.env.reset(
                        project=False, stage=args.stage)
                    # the last observation should be used to compute append_value in simulator_server
                    info['last_obs'] = next_observation.tolist()
if __name__ == '__main__':
    # Parse CLI options first: Worker reads the module-level `args` inside
    # its constructor, so this assignment must happen before Worker().
    args = get_client_args()
    worker = Worker(server_ip=args.ip, server_port=args.port)
    # Blocks forever until the server commands a shutdown.
    worker.run()