Skip to content

Commit

Permalink
refactor: typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Orbital-Web committed Dec 6, 2023
1 parent 673b3d2 commit 8384fed
Show file tree
Hide file tree
Showing 20 changed files with 55 additions and 86 deletions.
4 changes: 2 additions & 2 deletions armour/ArmourGoal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import numpy as np
from trimesh import Trimesh
from typing import OrderedDict
from nptyping import NDArray
from math import pi
from rtd.util.mixins.Typings import Matnp

# define top level module logger
import logging
Expand Down Expand Up @@ -79,7 +79,7 @@ def reset(self, **options):
def create_plot_data(self, time: float = None) -> list[Actor]:
# generate mesh
config = self.goal_position
fk: OrderedDict[Trimesh, NDArray] = self.arm_agent.info.robot.visual_trimesh_fk(cfg=config)
fk: OrderedDict[Trimesh, Matnp] = self.arm_agent.info.robot.visual_trimesh_fk(cfg=config)
meshes = [mesh.copy().apply_transform(transform) for mesh, transform in fk.items()]

self.plot_data: list[Actor] = list()
Expand Down
6 changes: 3 additions & 3 deletions armour/agent/ArmourAgentCollision.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from armour.agent import ArmourAgentInfo, ArmourAgentState
import numpy as np
from collections import OrderedDict
from nptyping import NDArray
from rtd.util.mixins.Typings import Matnp



Expand All @@ -25,7 +25,7 @@ def reset(self):
pass


def getCollisionObject(self, q: NDArray = None, time: float = None) -> CollisionObject:
def getCollisionObject(self, q: Matnp = None, time: float = None) -> CollisionObject:
'''
Generates a CollisionObject for a given time `time` or
configuration `q` (only one or none must be provided)
Expand All @@ -36,7 +36,7 @@ def getCollisionObject(self, q: NDArray = None, time: float = None) -> Collision
elif time is not None and q is None:
config = self.arm_state.get_state(np.array([time])).position # position at given time

fk: OrderedDict[Trimesh, NDArray] = self.arm_info.robot.collision_trimesh_fk(cfg=config)
fk: OrderedDict[Trimesh, Matnp] = self.arm_info.robot.collision_trimesh_fk(cfg=config)
meshes = [mesh.copy().apply_transform(transform) for mesh, transform in fk.items()]
return CollisionObject(meshes, id(self.arm_info))

Expand Down
7 changes: 3 additions & 4 deletions armour/agent/ArmourAgentVisual.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import numpy as np
from trimesh import Trimesh
from typing import OrderedDict
from nptyping import NDArray

from rtd.util.mixins.Typings import Matnp


class ArmourAgentVisual(PyvistaVisualObject, Options):
Expand Down Expand Up @@ -56,7 +55,7 @@ def create_plot_data(self, time: float = None) -> list[Actor]:

# generate mesh
config = self.arm_state.get_state(np.array([time])).position
fk: OrderedDict[Trimesh, NDArray] = self.arm_info.robot.visual_trimesh_fk(cfg=config)
fk: OrderedDict[Trimesh, Matnp] = self.arm_info.robot.visual_trimesh_fk(cfg=config)
meshes = [mesh.copy().apply_transform(transform) for mesh, transform in fk.items()]

self.plot_data: list[Actor] = list()
Expand Down Expand Up @@ -87,7 +86,7 @@ def plot(self, time: float = None):

# generate mesh
config = self.arm_state.get_state(np.array([time])).position
fk: OrderedDict[Trimesh, NDArray] = self.arm_info.robot.visual_trimesh_fk(cfg=config)
fk: OrderedDict[Trimesh, Matnp] = self.arm_info.robot.visual_trimesh_fk(cfg=config)
meshes = [mesh.copy().apply_transform(transform) for mesh, transform in fk.items()]

for actor, mesh in zip(self.plot_data, meshes):
Expand Down
1 change: 0 additions & 1 deletion armour/agent/ArmourIdealController.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from armour.agent import ArmourAgentInfo, ArmourAgentState
from armour.trajectory import ZeroHoldArmTrajectory
import numpy as np
from nptyping import NDArray

# define top level module logger
import logging
Expand Down
8 changes: 4 additions & 4 deletions armour/legacy/StraightLineHLP.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
from armour.agent import ArmourAgentInfo
from nptyping import NDArray
import numpy as np
from rtd.util.mixins.Typings import Vecnp



class StraightLineHLP():
def __init__(self):
self.default_lookahead_distance: float = 1
self.goal: NDArray = None
self.joint_state_indices: NDArray = None
self.goal: Vecnp = None
self.joint_state_indices: Vecnp = None


def setup(self, agent_info: ArmourAgentInfo, world_info: dict):
self.goal = world_info["goal"]
self.arm_joint_state_indices = agent_info.joint_state_indices


def get_waypoint(self, state: NDArray, lookahead_distance: float = None):
def get_waypoint(self, state: Vecnp, lookahead_distance: float = None) -> Vecnp:
if lookahead_distance is None:
lookahead_distance = self.default_lookahead_distance
q_cur = state[self.arm_joint_state_indices]
Expand Down
3 changes: 2 additions & 1 deletion armour/legacy/bernstein_to_poly.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from math import comb
import numpy as np
from rtd.util.mixins.Typings import Vec, Vecnp



def bernstein_to_poly(beta: list[float], n: int):
def bernstein_to_poly(beta: Vec, n: int) -> Vecnp:
'''
converts bernstein polynomial coefficients to
monomial coefficients
Expand Down
6 changes: 5 additions & 1 deletion armour/legacy/match_deg5_bernstein_coefficients.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
def match_deg5_bernstein_coefficients(traj_constraints: list[float], T: float = 1):
from rtd.util.mixins.Typings import Vec



def match_deg5_bernstein_coefficients(traj_constraints: Vec, T: float = 1) -> Vec:
'''
match coefficients to initial position, velocity, acceleration (t=0)
and final position, velocity, and acceleration (t=1)
Expand Down
7 changes: 2 additions & 5 deletions armour/trajectory/ArmTrajectoryFactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
from rtd.planner.trajopt import TrajOptProps
from armour.reachsets import JRSInstance
from armour.trajectory import PiecewiseArmTrajectory, ZeroHoldArmTrajectory, BernsteinArmTrajectory
from nptyping import NDArray, Shape, Float64

# type hinting
RowVec = NDArray[Shape['N'], Float64]
from rtd.util.mixins.Typings import Vecnp



Expand All @@ -21,7 +18,7 @@ def __init__(self, trajOptProps: TrajOptProps, traj_type: str = "piecewise"):


def createTrajectory(self, robotState: EntityState, rsInstances: dict[str, ReachSetInstance] = None,
trajectoryParams: RowVec = None, jrsInstance: JRSInstance = None,
trajectoryParams: Vecnp = None, jrsInstance: JRSInstance = None,
traj_type: str = None) -> ZeroHoldArmTrajectory | PiecewiseArmTrajectory | BernsteinArmTrajectory:
'''
Create a new trajectory object for the given state
Expand Down
9 changes: 3 additions & 6 deletions armour/trajectory/BernsteinArmTrajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
from armour.reachsets import JRSInstance
from armour.legacy import bernstein_to_poly, match_deg5_bernstein_coefficients
import numpy as np
from nptyping import NDArray, Shape, Float64

# type hinting
RowVec = NDArray[Shape['N'], Float64]
from rtd.util.mixins.Typings import Vecnp



Expand All @@ -31,7 +28,7 @@ def __init__(self, trajOptProps: TrajOptProps, startState: ArmRobotState, jrsIns
self.jrsInstance = jrsInstance


def setParameters(self, trajectoryParams: RowVec, startState: ArmRobotState = None,
def setParameters(self, trajectoryParams: Vecnp, startState: ArmRobotState = None,
jrsInstance: JRSInstance = None):
'''
A validated method to set the parameters for the trajectory
Expand Down Expand Up @@ -97,7 +94,7 @@ def internalUpdate(self):
self.q_end = q_goal


def getCommand(self, time: RowVec) -> EntityState:
def getCommand(self, time: Vecnp) -> EntityState:
# Do a parameter check and time check, and throw if anything is
# invalid.
self.validate(throwOnError=True)
Expand Down
9 changes: 3 additions & 6 deletions armour/trajectory/PieceWiseArmTrajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
from armour.reachsets import JRSInstance
from rtd.functional.vectools import rescale
import numpy as np
from nptyping import NDArray, Shape, Float64

# type hinting
RowVec = NDArray[Shape['N'], Float64]
from rtd.util.mixins.Typings import Vec, Mat, Vecnp, Matnp, Bound, Bounds, Boundsnp



Expand Down Expand Up @@ -38,7 +35,7 @@ class to update all other internal parameters once fully
self.q_end: float = None


def setParameters(self, trajectoryParams: RowVec, startState: ArmRobotState = None,
def setParameters(self, trajectoryParams: Vecnp, startState: ArmRobotState = None,
jrsInstance: JRSInstance = None):
'''
Set the parameters of the trajectory, with a focus on the
Expand Down Expand Up @@ -97,7 +94,7 @@ def internalUpdate(self):
+ 0.5*self.q_ddot_to_stop*self.trajOptProps.planTime**2)


def getCommand(self, time: RowVec) -> ArmRobotState:
def getCommand(self, time: Vecnp) -> ArmRobotState:
'''
Computes the actual input commands for the given time.
throws InvalidTrajectory if the trajectory isn't set
Expand Down
9 changes: 3 additions & 6 deletions armour/trajectory/ZeroHoldArmTrajectory.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from rtd.planner.trajectory import Trajectory, InvalidTrajectory
from rtd.entity.states import ArmRobotState
import numpy as np
from nptyping import NDArray, Shape, Float64

# type hinting
RowVec = NDArray[Shape['N'], Float64]
from rtd.util.mixins.Typings import Vecnp



Expand All @@ -23,7 +20,7 @@ class to update all other internal parameters once fully
self.startState = startState


def setParameters(self, trajectoryParams: RowVec, startState: ArmRobotState = None):
def setParameters(self, trajectoryParams: Vecnp, startState: ArmRobotState = None):
'''
Set the parameters of the trajectory, with a focus on the
parameters as the state should be set from the constructor
Expand All @@ -46,7 +43,7 @@ def validate(self, throwOnError: bool = False) -> bool:
return valid


def getCommand(self, time: RowVec) -> ArmRobotState:
def getCommand(self, time: Vecnp) -> ArmRobotState:
'''
Computes the actual input commands for the given time.
throws InvalidTrajectory if the trajectory isn't set
Expand Down
7 changes: 2 additions & 5 deletions rtd/planner/reachsets/ReachSetInstance.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from abc import ABCMeta, abstractmethod
from rtd.sim.world import WorldState
from typing import Callable
from nptyping import NDArray, Shape, Float64

# type hinting
BoundsVec = NDArray[Shape['N,2'], Float64]
from rtd.util.mixins.Typings import Boundsnp



Expand All @@ -21,7 +18,7 @@ class ReachSetInstance(metaclass=ABCMeta):
def __init__(self):
# A 2-column vector denoting the input minimum and maximums for the
# reachable set on the left and right, respectively
self.input_range: BoundsVec = None
self.input_range: Boundsnp = None

# The number of main shared parameters used by this set. Generally,
# this should match the size of the final trajectory parameters
Expand Down
11 changes: 4 additions & 7 deletions rtd/planner/trajectory/Trajectory.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from abc import ABCMeta, abstractmethod
from rtd.entity.states import EntityState
from rtd.planner.trajopt import TrajOptProps
from nptyping import NDArray, Shape, Float64

# type hinting
RowVec = NDArray[Shape['N'], Float64]
from rtd.util.mixins.Typings import Vecnp



Expand All @@ -23,7 +20,7 @@ def __init__(self):
self.trajOptProps: TrajOptProps = None

# The parameters used for this trajectory
self.trajectoryParams: RowVec = None
self.trajectoryParams: Vecnp = None

# The time at which this trajectory is valid
self.startTime: float = None
Expand Down Expand Up @@ -52,7 +49,7 @@ def validate(self, throwOnError: bool = False) -> bool:


@abstractmethod
def setParameters(self, trajectoryParams: RowVec, **options):
def setParameters(self, trajectoryParams: Vecnp, **options):
'''
Set the parameters for the trajectory
Expand All @@ -67,7 +64,7 @@ def setParameters(self, trajectoryParams: RowVec, **options):


@abstractmethod
def getCommand(self, time: float | RowVec) -> EntityState:
def getCommand(self, time: float | Vecnp) -> EntityState:
'''
Computes the actual state to track for the given time
Expand Down
1 change: 0 additions & 1 deletion rtd/planner/trajectory/TrajectoryContainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from rtd.entity.states import EntityState
from rtd.functional.sequences import toSequence
import numpy as np
from nptyping import NDArray


class BadTrajectoryException(Exception):
Expand Down
7 changes: 2 additions & 5 deletions rtd/planner/trajectory/TrajectoryFactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
from rtd.planner.trajopt import TrajOptProps
from rtd.planner.trajectory import Trajectory
from rtd.planner.reachsets import ReachSetInstance
from nptyping import NDArray, Shape, Float64

# type hinting
RowVec = NDArray[Shape['N'], Float64]
from rtd.util.mixins.Typings import Vecnp



Expand All @@ -25,7 +22,7 @@ def __init__(self):

@abstractmethod
def createTrajectory(self, robotState: EntityState, rsInstances: dict[str, ReachSetInstance] = None,
trajectoryParams: RowVec = None, **options) -> Trajectory:
trajectoryParams: Vecnp = None, **options) -> Trajectory:
'''
Factory method to create the trajectory
Expand Down
7 changes: 2 additions & 5 deletions rtd/planner/trajopt/GenericArmObjective.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
from rtd.planner.trajopt import Objective, TrajOptProps
from rtd.planner.trajectory import TrajectoryFactory, Trajectory
import numpy as np
from nptyping import NDArray, Shape, Float64

# type hinting
RowVec = NDArray[Shape['N'], Float64]
from rtd.util.mixins.Typings import Vecnp



Expand Down Expand Up @@ -43,7 +40,7 @@ def genObjective(self, robotState: EntityState, waypoint, reachableSets: dict[st


@staticmethod
def evalTrajectory(trajectoryParams: RowVec, trajectoryObj: Trajectory, q_des, t_cost: float | RowVec) -> float:
def evalTrajectory(trajectoryParams: Vecnp, trajectoryObj: Trajectory, q_des, t_cost: float | Vecnp) -> float:
'''
Helper function purely accessible to this class without any class state
which a handle can be made to to evaluate the trajectory for the cost.
Expand Down
9 changes: 3 additions & 6 deletions rtd/planner/trajopt/OptimizationEngine.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from abc import ABCMeta, abstractmethod
from typing import Callable
from nptyping import NDArray, Shape, Float64

# type hinting
RowVec = NDArray[Shape['N'], Float64]
from rtd.util.mixins.Typings import Vecnp



Expand All @@ -12,8 +9,8 @@ class OptimizationEngine(metaclass=ABCMeta):
Base class for any sort of nonlinear optimizer used
'''
@abstractmethod
def performOptimization(self, initialGuess: RowVec, objectiveCallback: Callable,
constraintCallback: Callable, bounds: dict) -> tuple[bool, RowVec, float]:
def performOptimization(self, initialGuess: Vecnp, objectiveCallback: Callable,
constraintCallback: Callable, bounds: dict) -> tuple[bool, Vecnp, float]:
'''
Use the given optimizer to perform the optimization
RowVector
Expand Down
Loading

0 comments on commit 8384fed

Please sign in to comment.