Update kinfer module #134

Merged
44 commits merged on Jan 3, 2025
Changes from 27 commits
Commits (44)
bafb56f
kinfer module
WT-MM Dec 30, 2024
faa9119
ppo schema
WT-MM Dec 30, 2024
6e4988d
fix typo
WT-MM Dec 30, 2024
983e010
get jit
WT-MM Dec 30, 2024
b41d3df
save
WT-MM Dec 30, 2024
f627fb3
lint
WT-MM Dec 30, 2024
d363437
move schema definition
WT-MM Dec 30, 2024
9942d79
fmt
WT-MM Dec 30, 2024
64d42c1
fix schema
WT-MM Dec 30, 2024
3f04f24
minor fixes
WT-MM Dec 30, 2024
8f904ca
fix names
WT-MM Dec 30, 2024
8f120bb
tensors
WT-MM Dec 30, 2024
31dd379
squeeze
WT-MM Dec 30, 2024
7a0a1c3
fix imports
WT-MM Dec 30, 2024
a289ee8
pin 3.8 compatible kinfer
WT-MM Dec 30, 2024
26466b8
lint
WT-MM Dec 30, 2024
d041f50
joints
WT-MM Dec 31, 2024
4a608b0
update local kinfer
WT-MM Dec 31, 2024
07a9b24
update metadata + path
WT-MM Dec 31, 2024
262198d
clean and rename
WT-MM Dec 31, 2024
0c583db
lint
WT-MM Dec 31, 2024
ce4d729
rename
WT-MM Dec 31, 2024
9b0738a
push kinfer
WT-MM Jan 1, 2025
aff4073
new names
WT-MM Jan 1, 2025
ac6a76c
policy runs
WT-MM Jan 1, 2025
30791e6
joint name ordering
WT-MM Jan 1, 2025
5283db1
stands
WT-MM Jan 1, 2025
01fdc36
clean
WT-MM Jan 1, 2025
f02bd84
lint
WT-MM Jan 1, 2025
87f0253
cycle time
WT-MM Jan 2, 2025
9cfde7f
noise to default
WT-MM Jan 2, 2025
bc56f7e
change noise behavior
WT-MM Jan 2, 2025
3574ef1
less noise
WT-MM Jan 2, 2025
63ed694
fix noise
WT-MM Jan 2, 2025
d537864
lint
WT-MM Jan 2, 2025
20add09
get name from schema
WT-MM Jan 2, 2025
07b9a9d
fix
WT-MM Jan 2, 2025
d37961d
lint
WT-MM Jan 2, 2025
34783c0
more info from schema
WT-MM Jan 2, 2025
d3ca3d0
imu
WT-MM Jan 2, 2025
dfd2263
Merge branch 'kinfer-schema' of https://github.com/kscalelabs/sim int…
WT-MM Jan 2, 2025
5705418
imu
WT-MM Jan 2, 2025
cfb2513
add imu missing warning
WT-MM Jan 2, 2025
d37834d
remove default randomization
WT-MM Jan 3, 2025
3 changes: 3 additions & 0 deletions .gitignore
@@ -30,3 +30,6 @@ wandb/
runs/
isaacgym/
*.h5

# dev
ref/
5 changes: 4 additions & 1 deletion .gitmodules
@@ -1,3 +1,6 @@
[submodule "third_party/isaacgym"]
path = third_party/isaacgym
url = ../../kscalelabs/isaacgym.git
[submodule "third_party/kinfer"]
path = third_party/kinfer
url = https://github.com/kscalelabs/kinfer.git
2 changes: 2 additions & 0 deletions sim/envs/base/legged_robot_config.py
@@ -29,6 +29,8 @@
#
# Copyright (c) 2024 Beijing RobotEra TECHNOLOGY CO.,LTD. All rights reserved.

import os
import sys
from enum import Enum

from sim.envs.base.base_config import BaseConfig
89 changes: 89 additions & 0 deletions sim/envs/humanoids/gpr_config.py
@@ -26,6 +26,95 @@ class env(LeggedRobotCfg.env):
episode_length_s = 24 # episode length in seconds
use_ref_actions = False

from kinfer import proto as P

input_schema = P.IOSchema(
values=[
P.ValueSchema(
value_name="vector_command",
vector_command=P.VectorCommandSchema(
dimensions=3, # x_vel, y_vel, rot
),
),
P.ValueSchema(
value_name="timestamp",
timestamp=P.TimestampSchema(
start_seconds=0,
),
),
P.ValueSchema(
value_name="dof_pos",
joint_positions=P.JointPositionsSchema(
joint_names=Robot.joint_names(),
unit=P.JointPositionUnit.RADIANS,
),
),
P.ValueSchema(
value_name="dof_vel",
joint_velocities=P.JointVelocitiesSchema(
joint_names=Robot.joint_names(),
unit=P.JointVelocityUnit.RADIANS_PER_SECOND,
),
),
P.ValueSchema(
value_name="prev_actions",
joint_positions=P.JointPositionsSchema(
joint_names=Robot.joint_names(), unit=P.JointPositionUnit.RADIANS
),
),
# Abusing the IMU schema to pass in euler and angular velocity instead of raw sensor data
P.ValueSchema(
value_name="imu_ang_vel",
imu=P.ImuSchema(
use_accelerometer=False,
use_gyroscope=True,
use_magnetometer=False,
),
),
P.ValueSchema(
value_name="imu_euler_xyz",
imu=P.ImuSchema(
use_accelerometer=True,
use_gyroscope=False,
use_magnetometer=False,
),
),
P.ValueSchema(
value_name="hist_obs",
state_tensor=P.StateTensorSchema(
# 11 is the number of single observation features - 6 from IMU, 5 from command input
# 3 comes from the number of times num_actions is repeated in the observation (dof_pos, dof_vel, prev_actions)
shape=[frame_stack * (11 + NUM_JOINTS * 3)],
dtype=P.DType.FP32,
),
),
]
)

output_schema = P.IOSchema(
values=[
P.ValueSchema(
value_name="actions",
joint_positions=P.JointPositionsSchema(
joint_names=Robot.joint_names(), unit=P.JointPositionUnit.RADIANS
),
),
P.ValueSchema(
value_name="actions_raw",
joint_positions=P.JointPositionsSchema(
joint_names=Robot.joint_names(), unit=P.JointPositionUnit.RADIANS
),
),
P.ValueSchema(
value_name="new_x",
state_tensor=P.StateTensorSchema(
shape=[frame_stack * (11 + NUM_JOINTS * 3)],
dtype=P.DType.FP32,
),
),
]
)

class safety(LeggedRobotCfg.safety):
# safety factors
pos_limit = 1.0
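For orientation, a minimal sketch of the buffer-size arithmetic encoded by the hist_obs and new_x StateTensorSchema shapes above. The frame_stack and joint-count values here are hypothetical placeholders rather than the repository's actual configuration, and the breakdown of the 5 command features is an assumption (the two gait-phase terms plus the 3-dimensional vector command).

# Hypothetical sizes, for illustration only
FRAME_STACK = 15
NUM_JOINTS = 20

# Per-frame features: 5 command inputs + 6 IMU values (3 angular velocities + 3 euler angles),
# plus dof_pos, dof_vel, and prev_actions (NUM_JOINTS values each)
num_single_obs = 11 + NUM_JOINTS * 3  # 71 with the placeholder joint count

# Flattened history length, matching shape=[frame_stack * (11 + NUM_JOINTS * 3)]
buffer_len = FRAME_STACK * num_single_obs  # 1065 with the placeholder values
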
74 changes: 44 additions & 30 deletions sim/model_export.py
@@ -3,7 +3,7 @@
import re
from dataclasses import dataclass, fields
from io import BytesIO
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import onnx
import onnxruntime as ort
@@ -125,24 +125,20 @@ def get_init_buffer(self) -> Tensor:

def forward(
self,
x_vel: Tensor, # x-coordinate of the target velocity
y_vel: Tensor, # y-coordinate of the target velocity
rot: Tensor, # target angular velocity
t: Tensor, # current policy time (sec)
vector_command: Tensor, # (x_vel, y_vel, rot)
timestamp: Tensor, # current policy time (sec)
dof_pos: Tensor, # current angular position of the DoFs relative to default
dof_vel: Tensor, # current angular velocity of the DoFs
prev_actions: Tensor, # previous actions taken by the model
imu_ang_vel: Tensor, # angular velocity of the IMU
imu_euler_xyz: Tensor, # euler angles of the IMU
buffer: Tensor, # buffer of previous observations
) -> Tuple[Tensor, Tensor, Tensor]:
hist_obs: Tensor, # buffer of previous observations
) -> Dict[str, Tensor]:
"""Runs the actor model forward pass.

Args:
x_vel: The x-coordinate of the target velocity, with shape (1).
y_vel: The y-coordinate of the target velocity, with shape (1).
rot: The target angular velocity, with shape (1).
t: The current policy time step, with shape (1).
vector_command: The target velocity vector, with shape (3). It consists of x_vel, y_vel, and rot.
timestamp: The current policy time step, with shape (1).
dof_pos: The current angular position of the DoFs relative to default, with shape (num_actions).
dof_vel: The current angular velocity of the DoFs, with shape (num_actions).
prev_actions: The previous actions taken by the model, with shape (num_actions).
@@ -151,7 +147,7 @@ def forward(
imu_euler_xyz: The euler angles of the IMU, with shape (3),
in radians. "XYZ" means (roll, pitch, yaw). If IMU is not used,
can be all zeros.
buffer: The buffer of previous actions, with shape (frame_stack * num_single_obs). This is
hist_obs: The buffer of previous observations, with shape (frame_stack * num_single_obs). This is
the return value of the previous forward pass. On the first
pass, it should be all zeros.

@@ -160,8 +156,10 @@
actions: The actions to take, with shape (num_actions).
x: The new buffer of observations, with shape (frame_stack * num_single_obs).
"""
sin_pos = torch.sin(2 * torch.pi * t / self.cycle_time)
cos_pos = torch.cos(2 * torch.pi * t / self.cycle_time)
sin_pos = torch.sin(2 * torch.pi * timestamp / self.cycle_time)
cos_pos = torch.cos(2 * torch.pi * timestamp / self.cycle_time)

x_vel, y_vel, rot = vector_command.split(1)

# Construct command input
command_input = torch.cat(
@@ -186,8 +184,8 @@ def forward(
q,
dq,
prev_actions,
imu_ang_vel * self.ang_vel_scale,
imu_euler_xyz * self.quat_scale,
imu_ang_vel.squeeze(0) * self.ang_vel_scale,
imu_euler_xyz.squeeze(0) * self.quat_scale,
),
dim=0,
)
@@ -196,7 +194,7 @@ def forward(
new_x = torch.clamp(new_x, -self.clip_observations, self.clip_observations)

# Add the new frame to the buffer
x = torch.cat((buffer, new_x), dim=0)
x = torch.cat((hist_obs, new_x), dim=0)
# Pop the oldest frame
x = x[self.num_single_obs :]

@@ -206,7 +204,7 @@ def forward(
actions = self.policy(policy_input).squeeze(0)
actions_scaled = actions * self.action_scale

return actions_scaled, actions, x
return {"actions": actions_scaled, "actions_raw": actions, "new_x": x}


def get_actor_policy(model_path: str, cfg: ActorCfg) -> Tuple[nn.Module, dict, Tuple[Tensor, ...]]:
@@ -239,24 +237,40 @@ def get_actor_policy(model_path: str, cfg: ActorCfg) -> Tuple[nn.Module, dict, T
input_tensors = (x_vel, y_vel, rot, t, dof_pos, dof_vel, prev_actions, imu_ang_vel, imu_euler_xyz, buffer)

# Add sim2sim metadata
robot_effort = list(a_model.robot.effort().values())
robot_stiffness = list(a_model.robot.stiffness().values())
robot_damping = list(a_model.robot.damping().values())
robot = a_model.robot
robot_effort = robot.effort_mapping()
robot_stiffness = robot.stiffness_mapping()
robot_damping = robot.damping_mapping()
num_actions = a_model.num_actions
num_observations = a_model.num_observations

default_standing = list(a_model.robot.default_standing().values())
default_standing = robot.default_standing()

metadata = {
"num_actions": num_actions,
"num_observations": num_observations,
"robot_effort": robot_effort,
"robot_stiffness": robot_stiffness,
"robot_damping": robot_damping,
"default_standing": default_standing,
"sim_dt": cfg.sim_dt,
"sim_decimation": cfg.sim_decimation,
"tau_factor": cfg.tau_factor,
"action_scale": cfg.action_scale,
"lin_vel_scale": cfg.lin_vel_scale,
"ang_vel_scale": cfg.ang_vel_scale,
"quat_scale": cfg.quat_scale,
"dof_pos_scale": cfg.dof_pos_scale,
"dof_vel_scale": cfg.dof_vel_scale,
"frame_stack": cfg.frame_stack,
"clip_observations": cfg.clip_observations,
"clip_actions": cfg.clip_actions,
"joint_names": robot.joint_names(),
}

return (
a_model,
{
"robot_effort": robot_effort,
"robot_stiffness": robot_stiffness,
"robot_damping": robot_damping,
"num_actions": num_actions,
"num_observations": num_observations,
"default_standing": default_standing,
},
metadata,
input_tensors,
)

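Because forward() now returns a dictionary keyed by the output_schema value names, a caller rolls state by feeding new_x back in as hist_obs on the next step. A minimal sketch of that loop, where policy is an assumed handle to the loaded Actor module and the sizes are hypothetical:

import torch

NUM_JOINTS = 20                      # hypothetical joint count
FRAME_STACK = 15                     # hypothetical frame stack depth
NUM_SINGLE_OBS = 11 + NUM_JOINTS * 3

hist_obs = torch.zeros(FRAME_STACK * NUM_SINGLE_OBS)  # all zeros on the first pass
prev_actions = torch.zeros(NUM_JOINTS)

out = policy(                                      # `policy` stands in for the Actor module
    vector_command=torch.tensor([0.3, 0.0, 0.0]),  # x_vel, y_vel, rot
    timestamp=torch.tensor([0.02]),
    dof_pos=torch.zeros(NUM_JOINTS),
    dof_vel=torch.zeros(NUM_JOINTS),
    prev_actions=prev_actions,
    imu_ang_vel=torch.zeros(1, 3),
    imu_euler_xyz=torch.zeros(1, 3),
    hist_obs=hist_obs,
)
joint_targets = out["actions"]     # scaled joint position targets
prev_actions = out["actions_raw"]  # unscaled actions, presumably re-fed on the next step
hist_obs = out["new_x"]            # rolled observation history
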
42 changes: 25 additions & 17 deletions sim/play.py
@@ -9,16 +9,16 @@
import logging
import math
import os
import time
import uuid
from datetime import datetime
from typing import Any, Union

import cv2
import h5py
import numpy as np
import onnx
from isaacgym import gymapi
from kinfer.export.pytorch import export_to_onnx
from kinfer import proto as P
from kinfer.export.pytorch import export_model
from tqdm import tqdm

from sim.env import run_dir # noqa: E402
@@ -28,6 +28,7 @@
from sim.model_export import ActorCfg, get_actor_policy
from sim.utils.helpers import get_args # noqa: E402
from sim.utils.logger import Logger # noqa: E402
from sim.utils.resources import load_embodiment

import torch # special case with isort: skip comment

@@ -37,9 +38,14 @@
def export_policy_as_jit(actor_critic: Any, path: Union[str, os.PathLike]) -> None:
os.makedirs(path, exist_ok=True)
path = os.path.join(path, "policy_1.pt")
model = get_actor_jit(actor_critic)
model.save(path)


def get_actor_jit(actor_critic: Any) -> Any:
model = copy.deepcopy(actor_critic.actor).to("cpu")
traced_script_module = torch.jit.script(model)
traced_script_module.save(path)
return traced_script_module


def play(args: argparse.Namespace) -> None:
@@ -84,12 +90,15 @@ def play(args: argparse.Namespace) -> None:
export_policy_as_jit(ppo_runner.alg.actor_critic, path)
print("Exported policy as jit script to: ", path)

# export policy as an onnx module (used to run it on the web)
if env_cfg.env.input_schema is None or env_cfg.env.output_schema is None:
raise ValueError("Input or output schema is None")

# Create the full model schema
model_schema = P.ModelSchema(input_schema=env_cfg.env.input_schema, output_schema=env_cfg.env.output_schema)

if args.export_onnx:
path = ppo_runner.load_path
embodiment = ppo_runner.cfg["experiment_name"].lower()
policy_cfg = ActorCfg(
embodiment=embodiment,
actor_cfg = ActorCfg(
embodiment=ppo_runner.cfg["experiment_name"].lower(),
cycle_time=env_cfg.rewards.cycle_time,
sim_dt=env_cfg.sim.dt,
sim_decimation=env_cfg.control.decimation,
@@ -104,14 +113,13 @@ def play(args: argparse.Namespace) -> None:
clip_observations=env_cfg.normalization.clip_observations,
clip_actions=env_cfg.normalization.clip_actions,
)
actor_model, sim2sim_info, input_tensors = get_actor_policy(path, policy_cfg)

# Merge policy_cfg and sim2sim_info into a single config object
export_config = {**vars(policy_cfg), **sim2sim_info}

export_to_onnx(actor_model, input_tensors=input_tensors, config=export_config, save_path="kinfer_policy.onnx")
print("Exported policy as kinfer-compatible onnx to: ", path)

jit_policy, metadata, _ = get_actor_policy(ppo_runner.load_path, actor_cfg)
kinfer_policy = export_model(
model=jit_policy,
schema=model_schema,
metadata=metadata,
)
onnx.save(kinfer_policy, "policy.kinfer")
# Prepare for logging
env_logger = Logger(env.dt)
robot_index = 0
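Since the exported artifact is written with onnx.save, policy.kinfer can be opened as an ordinary ONNX file. A quick sanity check, sketched under the assumption that the graph's input and output names mirror the schema value names (the exact names depend on the kinfer export):

import onnxruntime as ort

session = ort.InferenceSession("policy.kinfer")
print([i.name for i in session.get_inputs()])   # expected to mirror the input_schema value names
print([o.name for o in session.get_outputs()])  # expected: actions, actions_raw, new_x
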
3 changes: 2 additions & 1 deletion sim/requirements.txt
@@ -10,6 +10,7 @@ numpy==1.21.6
wandb
tensorboard==2.14.0
onnxscript
onnx
mujoco==2.3.6
kinfer==0.0.5
opencv-python