# testing.py
import gym
import ArmEnv  # importing ArmEnv registers the custom 'PointToPoint-v0' env
import time
import matplotlib.pyplot as plt
from stable_baselines3 import PPO
""" env = gym.make('PointToPoint-v0',gui=True,mode='T')
env.reset()
rews = []
print(env.observation_space)
print(env.action_space)
for i in range(100):
action = env.action_space.sample()
_, rew, _, _ = env.step(action)
rews.append(rew)
time.sleep(0.25)
print(action,rew) """
# Roll out one episode with a PPO policy in 'P' mode and sum the rewards.
env = gym.make('PointToPoint-v0', gui=True, mode='P')
model = PPO('MlpPolicy', env, verbose=1, device='cuda')
obs = env.reset()
print('Observation:', obs)
dones = False
rews = []
count = 0
while True:
    count += 1
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    print(action, rewards)
    rews.append(rewards)
    if dones:
        break
    # print(count, end='\r')
print("Cumulative REWARD:", sum(rews))