-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* sketch * initial draft * working onnx * working pip * embed metadata * use dataclass * fix logging * format * comment * add imu link * add visual placement and proper index, first training * code restructuring * e2e works but the model is not behavior the same * working setup * Revert "e2e works but the model is not behavior the same" This reverts commit 606d791. * Revert "Merge branch 'pawel/imu_setup' of github.com:kscalelabs/sim into onnx_setup" This reverts commit 10bdc9b, reversing changes made to 606d791. * e2e works but the model is not behavior the same * format --------- Co-authored-by: WT-MM <[email protected]>
- Loading branch information
1 parent
3097ded
commit ae8681a
Showing
7 changed files
with
435 additions
and
157 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,295 @@ | ||
"""Script to convert weights to Rust-compatible format.""" | ||
|
||
import re | ||
from dataclasses import dataclass, fields | ||
from io import BytesIO | ||
from typing import List, Optional, Tuple | ||
|
||
import onnx | ||
import onnxruntime as ort | ||
import torch | ||
from scripts.create_mjcf import load_embodiment | ||
from torch import Tensor, nn | ||
from torch.distributions import Normal | ||
|
||
|
||
@dataclass | ||
class ActorCfg: | ||
embodiment: str = "stompypro" | ||
cycle_time: float = 0.4 # Cycle time for sinusoidal command input | ||
action_scale: float = 0.25 # Scale for actions | ||
lin_vel_scale: float = 2.0 # Scale for linear velocity | ||
ang_vel_scale: float = 1.0 # Scale for angular velocity | ||
quat_scale: float = 1.0 # Scale for quaternion | ||
dof_pos_scale: float = 1.0 # Scale for joint positions | ||
dof_vel_scale: float = 0.05 # Scale for joint velocities | ||
frame_stack: int = 15 # Number of frames to stack for the policy input | ||
clip_observations: float = 18.0 # Clip observations to this value | ||
clip_actions: float = 18.0 # Clip actions to this value | ||
|
||
|
||
class ActorCritic(nn.Module): | ||
def __init__( | ||
self, | ||
num_actor_obs: int, | ||
num_critic_obs: int, | ||
num_actions: int, | ||
actor_hidden_dims: List[int] = [256, 256, 256], | ||
critic_hidden_dims: List[int] = [256, 256, 256], | ||
init_noise_std: float = 1.0, | ||
activation: nn.Module = nn.ELU(), | ||
) -> None: | ||
super(ActorCritic, self).__init__() | ||
|
||
mlp_input_dim_a = num_actor_obs | ||
mlp_input_dim_c = num_critic_obs | ||
|
||
# Policy function. | ||
actor_layers = [] | ||
actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0])) | ||
actor_layers.append(activation) | ||
for dim_i in range(len(actor_hidden_dims)): | ||
if dim_i == len(actor_hidden_dims) - 1: | ||
actor_layers.append(nn.Linear(actor_hidden_dims[dim_i], num_actions)) | ||
else: | ||
actor_layers.append(nn.Linear(actor_hidden_dims[dim_i], actor_hidden_dims[dim_i + 1])) | ||
actor_layers.append(activation) | ||
self.actor = nn.Sequential(*actor_layers) | ||
|
||
# Value function. | ||
critic_layers = [] | ||
critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0])) | ||
critic_layers.append(activation) | ||
for dim_i in range(len(critic_hidden_dims)): | ||
if dim_i == len(critic_hidden_dims) - 1: | ||
critic_layers.append(nn.Linear(critic_hidden_dims[dim_i], 1)) | ||
else: | ||
critic_layers.append(nn.Linear(critic_hidden_dims[dim_i], critic_hidden_dims[dim_i + 1])) | ||
critic_layers.append(activation) | ||
self.critic = nn.Sequential(*critic_layers) | ||
|
||
# Action noise. | ||
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) | ||
self.distribution = None | ||
|
||
# Disable args validation for speedup. | ||
Normal.set_default_validate_args = False | ||
|
||
|
||
class Actor(nn.Module): | ||
"""Actor model. | ||
Parameters: | ||
policy: The policy network. | ||
cfg: The configuration for the actor. | ||
""" | ||
|
||
def __init__(self, policy: nn.Module, cfg: ActorCfg) -> None: | ||
super().__init__() | ||
|
||
self.robot = load_embodiment(cfg.embodiment) | ||
|
||
# Policy config | ||
default_dof_pos_dict = self.robot.default_standing() | ||
self.num_actions = len(self.robot.all_joints()) | ||
self.frame_stack = cfg.frame_stack | ||
|
||
# 11 is the number of single observation features - 6 from IMU, 5 from command input | ||
# 3 comes from the number of times num_actions is repeated in the observation (dof_pos, dof_vel, prev_actions) | ||
self.num_single_obs = 11 + self.num_actions * 3 | ||
self.num_observations = int(self.frame_stack * self.num_single_obs) | ||
|
||
self.policy = policy | ||
|
||
# This is the policy reference joint positions and should be the same order as the policy and mjcf file. | ||
# CURRENTLY NOT USED IN FORWARD PASS TO MAKE MORE GENERALIZEABLE FOR REAL AND SIM | ||
self.default_dof_pos = torch.tensor(list(default_dof_pos_dict.values())) | ||
|
||
self.action_scale = cfg.action_scale | ||
self.lin_vel_scale = cfg.lin_vel_scale | ||
self.ang_vel_scale = cfg.ang_vel_scale | ||
self.quat_scale = cfg.quat_scale | ||
self.dof_pos_scale = cfg.dof_pos_scale | ||
self.dof_vel_scale = cfg.dof_vel_scale | ||
|
||
self.clip_observations = cfg.clip_observations | ||
self.clip_actions = cfg.clip_actions | ||
|
||
self.cycle_time = cfg.cycle_time | ||
|
||
def get_init_buffer(self) -> Tensor: | ||
return torch.zeros(self.num_observations) | ||
|
||
def forward( | ||
self, | ||
x_vel: Tensor, # x-coordinate of the target velocity | ||
y_vel: Tensor, # y-coordinate of the target velocity | ||
rot: Tensor, # target angular velocity | ||
t: Tensor, # current policy time (sec) | ||
dof_pos: Tensor, # current angular position of the DoFs relative to default | ||
dof_vel: Tensor, # current angular velocity of the DoFs | ||
prev_actions: Tensor, # previous actions taken by the model | ||
imu_ang_vel: Tensor, # angular velocity of the IMU | ||
imu_euler_xyz: Tensor, # euler angles of the IMU | ||
buffer: Tensor, # buffer of previous observations | ||
) -> Tuple[Tensor, Tensor, Tensor]: | ||
"""Runs the actor model forward pass. | ||
Args: | ||
x_vel: The x-coordinate of the target velocity, with shape (1). | ||
y_vel: The y-coordinate of the target velocity, with shape (1). | ||
rot: The target angular velocity, with shape (1). | ||
t: The current policy time step, with shape (1). | ||
dof_pos: The current angular position of the DoFs relative to default, with shape (num_actions). | ||
dof_vel: The current angular velocity of the DoFs, with shape (num_actions). | ||
prev_actions: The previous actions taken by the model, with shape (num_actions). | ||
imu_ang_vel: The angular velocity of the IMU, with shape (3), | ||
in radians per second. If IMU is not used, can be all zeros. | ||
imu_euler_xyz: The euler angles of the IMU, with shape (3), | ||
in radians. "XYZ" means (roll, pitch, yaw). If IMU is not used, | ||
can be all zeros. | ||
buffer: The buffer of previous actions, with shape (frame_stack * num_single_obs). This is | ||
the return value of the previous forward pass. On the first | ||
pass, it should be all zeros. | ||
Returns: | ||
actions_scaled: The actions to take, with shape (num_actions), scaled by the action_scale. | ||
actions: The actions to take, with shape (num_actions). | ||
x: The new buffer of observations, with shape (frame_stack * num_single_obs). | ||
""" | ||
sin_pos = torch.sin(2 * torch.pi * t / self.cycle_time) | ||
cos_pos = torch.cos(2 * torch.pi * t / self.cycle_time) | ||
|
||
# Construct command input | ||
command_input = torch.cat( | ||
( | ||
sin_pos, | ||
cos_pos, | ||
x_vel * self.lin_vel_scale, | ||
y_vel * self.lin_vel_scale, | ||
rot * self.ang_vel_scale, | ||
), | ||
dim=0, | ||
) | ||
|
||
# Calculate current position and velocity observations | ||
q = dof_pos * self.dof_pos_scale | ||
dq = dof_vel * self.dof_vel_scale | ||
|
||
# Construct new observation | ||
new_x = torch.cat( | ||
( | ||
command_input, | ||
q, | ||
dq, | ||
prev_actions, | ||
imu_ang_vel * self.ang_vel_scale, | ||
imu_euler_xyz * self.quat_scale, | ||
), | ||
dim=0, | ||
) | ||
|
||
# Clip the inputs | ||
new_x = torch.clamp(new_x, -self.clip_observations, self.clip_observations) | ||
|
||
# Add the new frame to the buffer | ||
x = torch.cat((buffer, new_x), dim=0) | ||
# Pop the oldest frame | ||
x = x[self.num_single_obs :] | ||
|
||
policy_input = x.unsqueeze(0) | ||
|
||
# Get actions from the policy | ||
actions = self.policy(policy_input).squeeze(0) | ||
actions_scaled = actions * self.action_scale | ||
|
||
return actions_scaled, actions, x | ||
|
||
|
||
def convert_model_to_onnx(model_path: str, cfg: ActorCfg, save_path: Optional[str] = None) -> ort.InferenceSession: | ||
"""Converts a PyTorch model to a ONNX format. | ||
Args: | ||
model_path: Path to the PyTorch model. | ||
cfg: The configuration for the actor. | ||
save_path: Path to save the ONNX model. | ||
Returns: | ||
An ONNX inference session. | ||
""" | ||
all_weights = torch.load(model_path, map_location="cpu", weights_only=True) | ||
weights = all_weights["model_state_dict"] | ||
num_actor_obs = weights["actor.0.weight"].shape[1] | ||
num_critic_obs = weights["critic.0.weight"].shape[1] | ||
num_actions = weights["std"].shape[0] | ||
actor_hidden_dims = [v.shape[0] for k, v in weights.items() if re.match(r"actor\.\d+\.weight", k)] | ||
critic_hidden_dims = [v.shape[0] for k, v in weights.items() if re.match(r"critic\.\d+\.weight", k)] | ||
actor_hidden_dims = actor_hidden_dims[:-1] | ||
critic_hidden_dims = critic_hidden_dims[:-1] | ||
|
||
ac_model = ActorCritic(num_actor_obs, num_critic_obs, num_actions, actor_hidden_dims, critic_hidden_dims) | ||
ac_model.load_state_dict(weights) | ||
|
||
a_model = Actor(ac_model.actor, cfg) | ||
|
||
# Gets the model input tensors. | ||
x_vel = torch.randn(1) | ||
y_vel = torch.randn(1) | ||
rot = torch.randn(1) | ||
t = torch.randn(1) | ||
dof_pos = torch.randn(a_model.num_actions) | ||
dof_vel = torch.randn(a_model.num_actions) | ||
prev_actions = torch.randn(a_model.num_actions) | ||
imu_ang_vel = torch.randn(3) | ||
imu_euler_xyz = torch.randn(3) | ||
buffer = a_model.get_init_buffer() | ||
input_tensors = (x_vel, y_vel, rot, t, dof_pos, dof_vel, prev_actions, imu_ang_vel, imu_euler_xyz, buffer) | ||
|
||
jit_model = torch.jit.script(a_model) | ||
|
||
# Export the model to a buffer | ||
buffer = BytesIO() | ||
torch.onnx.export(jit_model, input_tensors, buffer) | ||
buffer.seek(0) | ||
|
||
# Load the model as an onnx model | ||
model_proto = onnx.load_model(buffer) | ||
|
||
# Add sim2sim metadata | ||
robot_effort = list(a_model.robot.effort().values()) | ||
robot_stiffness = list(a_model.robot.stiffness().values()) | ||
robot_damping = list(a_model.robot.damping().values()) | ||
num_actions = a_model.num_actions | ||
num_observations = a_model.num_observations | ||
|
||
for field_name, field in [ | ||
("robot_effort", robot_effort), | ||
("robot_stiffness", robot_stiffness), | ||
("robot_damping", robot_damping), | ||
("num_actions", num_actions), | ||
("num_observations", num_observations), | ||
]: | ||
meta = model_proto.metadata_props.add() | ||
meta.key = field_name | ||
meta.value = str(field) | ||
|
||
# Add the configuration of the model | ||
for field in fields(cfg): | ||
value = getattr(cfg, field.name) | ||
meta = model_proto.metadata_props.add() | ||
meta.key = field.name | ||
meta.value = str(value) | ||
|
||
if save_path: | ||
onnx.save_model(model_proto, save_path) | ||
|
||
# Convert model to bytes | ||
buffer2 = BytesIO() | ||
onnx.save_model(model_proto, buffer2) | ||
buffer2.seek(0) | ||
|
||
return ort.InferenceSession(buffer2.read()) | ||
|
||
|
||
if __name__ == "__main__": | ||
convert_model_to_onnx("model_3000.pt", ActorCfg(), "policy.onnx") |
Oops, something went wrong.