Skip to content

Commit

Permalink
add draft of the real policy
Browse files Browse the repository at this point in the history
  • Loading branch information
budzianowski committed Jul 12, 2024
1 parent e3ff6f5 commit 8ef567f
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions sim/deploy/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Run example:
mjpython sim/deploy/run.py --load_model sim/deploy/tests/walking_policy.pt --world MUJOCO
python sim/deploy/run.py --load_model MODEL_WEIGHTS --world REAL
"""

import argparse
Expand All @@ -23,6 +25,11 @@
from sim.env import stompy_mjcf_path
from sim.stompy.joints import StompyFixed

import time
from firmware.cpp.imu.imu import IMU
from firmware.scripts.robot_controller import Robot # TODO:(ved) move this to a more appropriate location
import torch


class Worlds(Enum):
MUJOCO = "SIM"
Expand Down Expand Up @@ -50,6 +57,41 @@ def get_observation(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarra
pass


class Real(World):
def __init__(self, cfg: RobotConfig):
self.robot = Robot("legs")
self.robot.zero_out() # TODO: (Ved - zero out the legs)
self.model = torch.load(cfg.robot_model_path) # TODO: (Allen/Isaac - load the model)
self.imu = IMU(1) # TODO: (Weasley -load the imu)
self.state = None

def step(self, observation: np.ndarray) -> None:
"""Performs a simulation in the real world."""
tau = self.model(observation) # TODO: (Allen/Isaac - run the model)
self.robot.set_position(tau) # TODO: (Ved - set the position of the robot)

def get_observation(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Extracts an observation from the world state.
Returns:
A tuple containing the following:
- dof_pos: The joint positions.
- dof_vel: The joint velocities.
- orientation: The orientation of the robot.
- ang_vel: The angular velocity of the robot.
"""
ang_vel, orientation = self.imu.step()
dof_pos = self.robot.get_position()
dof_vel = self.robot.get_velocity()
return (dof_pos, dof_vel, orientation, ang_vel)

def simulate(self, policy=None) -> None:
for step in tqdm(range(int(cfg.duration / self.cfg.dt)), desc="Simulating..."):
obs = self.get_observation()
action = self.step(obs)
time.sleep(self.cfg.dt)


class MujocoWorld(World):
"""Simulated world using MuJoCo.
Expand Down

0 comments on commit 8ef567f

Please sign in to comment.