Skip to content

Commit

Permalink
Merge branch 'ros-2-complete' of https://github.com/purdue-arc/rocket…
Browse files Browse the repository at this point in the history
…_league into ros-2-complete
  • Loading branch information
jcrm1 committed Feb 10, 2024
2 parents 989da0d + 09513e5 commit fa5fc06
Show file tree
Hide file tree
Showing 12 changed files with 121 additions and 21 deletions.
10 changes: 5 additions & 5 deletions src/rktl_autonomy/nodes/plotter
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Plotter(rclpy.Node):
#rospy.Subscriber('~log', DiagnosticStatus, self.progress_cb)
self.create_subscription(DiagnosticStatus, '~log', self.progress_cb, qos_profile=10)

self.history = None
self.history = []
self.LOG_NAME = None
self.next_plot_episode = self.PLOT_FREQ
self.init_plot()
Expand Down Expand Up @@ -110,15 +110,15 @@ class Plotter(rclpy.Node):
data[item.key] = float(item.value)

if data["episode"] is not None:
if self.history is None:
if self.history is []:
self.history = [data]
else:
self.history.append(data)

if data["episode"] >= self.next_plot_episode:
self.plot()
self.next_plot_episode += self.PLOT_FREQ
self.history = None
self.history = []
else:
#rospy.logerr("Bad progress message.")
self.get_logger().warn("Bad progress message.")
Expand Down Expand Up @@ -175,8 +175,8 @@ class Plotter(rclpy.Node):

# update file
#rospy.loginfo(f"Saving training progress to {self.LOG_DIR}{self.LOG_NAME}")
self.get_logger().info("Saving training progress to {self.LOG_DIR}{self.LOG_NAME}")
plt.savefig(self.LOG_DIR + self.LOG_NAME)
self.get_logger().info(f"Saving training progress to {self.LOG_DIR}{self.LOG_NAME}")
plt.savefig(self.LOG_DIR + str(self.LOG_NAME))

if __name__ == "__main__":
Plotter()
2 changes: 1 addition & 1 deletion src/rktl_autonomy/nodes/rocket_league_agent
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ env = RocketLeagueInterface(eval=True)

# load the model
# weights = expanduser(rospy.get_param('~weights'))
weights = expanduser(env.node.get_parameter('~weights'))
weights = expanduser(env.node.get_parameter('~weights').get_parameter_value().string_value)
model = PPO.load(weights)

# evaluate in real-time
Expand Down
2 changes: 1 addition & 1 deletion src/rktl_autonomy/rktl_autonomy/_ros_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from threading import Condition
import time, uuid, socket, os

from gym import Env
from gymnasium import Env

import rclpy
from rclpy.duration import Duration
Expand Down
12 changes: 6 additions & 6 deletions src/rktl_autonomy/rktl_autonomy/rocket_league_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@

# package
from rktl_autonomy._ros_interface import ROSInterface
from gym.spaces import Box, Discrete
from gymnasium.spaces import Box, Discrete

# ROS
# import rospy
import rclpy
from rclpy import Node
from rclpy.parameter import Parameter
Expand All @@ -25,7 +24,7 @@

# System
import numpy as np
from tf.transformations import euler_from_quaternion
from transformations import euler_from_quaternion
from enum import IntEnum, unique, auto
from math import pi, tan

Expand Down Expand Up @@ -143,7 +142,7 @@ def __init__(self, eval=False, launch_file=('rktl_autonomy', 'rocket_league_trai
self.node = Node('rocket_league_interface')
# Publishers
# self._command_pub = rospy.Publisher('cars/car0/command', ControlCommand, queue_size=1)
self.node.create_publisher(ControlCommand, 'cars/car0/command', 1)
self._command_pub = self.node.create_publisher(ControlCommand, 'cars/car0/command', 1)
# self._reset_srv = rospy.ServiceProxy('sim_reset', Empty)
self._reset_srv = self.node.create_client(Empty, 'sim_reset')

Expand Down Expand Up @@ -256,14 +255,15 @@ def _get_state(self):
goal_dist_sq = np.sum(np.square(ball[0:2] - np.array([self._FIELD_LENGTH/2, 0])))
reward += self._GOAL_DISTANCE_REWARD * goal_dist_sq

if self._score != 0:
if self._score is not None and self._score != 0:
done = True
if self._score > 0:
reward += self._WIN_REWARD
else:
reward += self._LOSS_REWARD

x, y, __, v, __ = self._car_odom
if self._car_odom is not None:
x, y, __, v, __ = self._car_odom

if self._prev_vel is None:
self._prev_vel = v
Expand Down
2 changes: 1 addition & 1 deletion src/rktl_autonomy/scripts/eval_rocket_league.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
All rights reserved.
"""

from rktl_autonomy import RocketLeagueInterface
from rktl_autonomy.rocket_league_interface import RocketLeagueInterface
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env
Expand Down
4 changes: 2 additions & 2 deletions src/rktl_autonomy/scripts/train_rocket_league.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
All rights reserved.
"""

from rktl_autonomy import RocketLeagueInterface
from rktl_autonomy.rocket_league_interface import RocketLeagueInterface
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env
Expand Down Expand Up @@ -39,7 +39,7 @@ def train(n_envs=24, n_saves=100, n_steps=240000000, env_number=0):

# log model weights
freq = n_steps / (n_saves * n_envs)
callback = CheckpointCallback(save_freq=freq, save_path=log_dir)
callback = CheckpointCallback(save_freq=int(freq), save_path=log_dir)

# run training
steps = n_steps
Expand Down
2 changes: 1 addition & 1 deletion src/rktl_autonomy/scripts/tune_rocket_league.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
All rights reserved.
"""

from rktl_autonomy import RocketLeagueInterface
from rktl_autonomy.rocket_league_interface import RocketLeagueInterface
import numpy as np
from stable_baselines3 import PPO

Expand Down
6 changes: 4 additions & 2 deletions src/rktl_control/launch/ball.launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
def generate_launch_description():
ld = launch.LaunchDescription([
launch_ros.actions.Node(
namespace='ball',
package='rktl_control',
executable='mean_odom_filter',
name='mean_odom_filter',
output='screen',
parameters=[
get_package_share_directory(
'rktl_control') + '/config/mean_odom_filter.yaml'
{
launch.substitutions.PathJoinSubstitution(launch_ros.substitutions.FindPackageShare('rktl_control'), '/config/mean_odom_filter.yaml')
}
]
)
])
Expand Down
49 changes: 49 additions & 0 deletions src/rktl_control/launch/car.launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,55 @@ def generate_launch_description():
launch.actions.DeclareLaunchArgument(
name='use_particle_filter',
default_value='true'
),
launch.actions.GroupAction(
actions=[
launch_ros.actions.PushRosNamespace("cars/" + launch.substitutions.LaunchConfiguration("car_name")),

launch_ros.actions.Node(
package='rktl_control',
executable='particle_odom_filter',
name='particle_odom_filter',
output='screen',
condition=launch.conditions.LaunchConfigurationEquals('use_particle_filter', True),
parameters=[
{
launch.substitutions.PathJoinSubstitution(launch_ros.substitutions.FindPackageShare('rktl_control'), '/config/particle_odom_filter.yaml')
},
{
'frame_ids/body': launch.substitutions.LaunchConfiguration('car_name')
}
]
),

launch_ros.actions.Node(
package='rktl_control',
executable='mean_odom_filter',
name='mean_odom_filter',
output='screen',
condition=launch.conditions.LaunchConfigurationNotEquals('use_particle_filter', True),
parameters=[
{
launch.substitutions.PathJoinSubstitution(launch_ros.substitutions.FindPackageShare('rktl_control'), '/config/mean_odom_filter.yaml')
},
{
'frame_ids/body': launch.substitutions.LaunchConfiguration('car_name')
}
]
),

launch_ros.actions.Node(
package='rktl_control',
executable='controller',
name='controller',
output='screen',
parameters=[
{
launch.substitutions.PathJoinSubstitution(launch_ros.substitutions.FindPackageShare('rktl_control'), '/config/controller.yaml')
},
]
)
]
)
])
return ld
Expand Down
5 changes: 3 additions & 2 deletions src/rktl_control/launch/hardware_interface.launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ def generate_launch_description():
executable='serial_node.py',
name='hardware_interface',
parameters=[
get_package_share_directory(
'rktl_control') + '/config/hardware_interface.yaml'
{
launch.substitutions.PathJoinSubstitution(launch_ros.substitutions.FindPackageShare('rktl_control'), '/config/hardware_interface.yaml')
}
]
)
])
Expand Down
1 change: 1 addition & 0 deletions src/rktl_control/launch/keyboard_control.launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def generate_launch_description():
default_value='car0'
),
launch_ros.actions.Node(
namespace=launch.substitutions.PathJoinSubstitution('cars/', launch_ros.substitutions.FindPackageShare('car_name')),
package='rktl_control',
executable='keyboard_interface',
name='keyboard_interface',
Expand Down
47 changes: 47 additions & 0 deletions src/rktl_control/launch/xbox_control.launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,54 @@ def generate_launch_description():
launch.actions.DeclareLaunchArgument(
name='delay',
default_value='0.1'
),
launch.actions.GroupAction(
actions=[
launch_ros.actions.PushRosNamespace("cars/" + launch.substitutions.LaunchConfiguration("car_name")),

launch_ros.actions.Node(
package='joy',
executable='joy_node',
name='joy_node',
output='screen',
parameters=[
{
'dev': launch.substitutions.LaunchConfiguration('device')
},
{
'default_trig_val': 'true'
}
]
),
launch_ros.actions.Node(
package='rktl_control',
executable='xbox_interface',
name='xbox_interface',
output='screen',
parameters=[
{
'base_throttle': '0.75'
},
{
'boost_throttle': '1.25'
},
{
'cooldown_ratio': '3'
},
{
'max_boost': '2'
}
],
actions=[
launch_ros.actions.SetRemap(
src="joy",
dst="joy)mux",
)
]
)
]
)

])
return ld

Expand Down

0 comments on commit fa5fc06

Please sign in to comment.