-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
60065c1
commit a859f50
Showing
5 changed files
with
212 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Copyright (c) Meta Platforms, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import os | ||
from typing import Dict, List | ||
|
||
import numpy as np | ||
import pytest | ||
from spot_rl.envs.place_env import PlaceController, construct_config_for_place | ||
from spot_rl.utils.utils import get_waypoint_yaml, place_target_from_waypoint | ||
from spot_wrapper.spot import Spot | ||
|
||
hardware_tests_dir = os.path.dirname(os.path.abspath(__file__)) | ||
test_configs_dir = os.path.join(hardware_tests_dir, "configs") | ||
test_data_dir = os.path.join(hardware_tests_dir, "data") | ||
test_nav_trajectories_dir = os.path.join(test_data_dir, "nav_trajectories") | ||
test_square_nav_trajectories_dir = os.path.join( | ||
test_nav_trajectories_dir, "square_of_side_200cm" | ||
) | ||
TEST_WAYPOINTS_YAML = os.path.join(test_configs_dir, "waypoints.yaml") | ||
TEST_CONFIGS_YAML = os.path.join(test_configs_dir, "config.yaml") | ||
|
||
|
||
def init_config(): | ||
""" | ||
Initialize config object for Nav test | ||
Returns: | ||
config: Config object | ||
""" | ||
config = construct_config_for_place(file_path=TEST_CONFIGS_YAML, opts=[]) | ||
|
||
return config | ||
|
||
|
||
def test_place(): | ||
config = init_config() | ||
test_waypoints_yaml_dict = get_waypoint_yaml(waypoint_file=TEST_WAYPOINTS_YAML) | ||
|
||
test_waypoints = [ | ||
"test_cup", | ||
"test_plush_lion", | ||
"test_plush_ball", | ||
] | ||
test_place_targets_list = [ | ||
place_target_from_waypoint(test_waypoint, test_waypoints_yaml_dict) | ||
for test_waypoint in test_waypoints | ||
] | ||
|
||
test_spot = Spot("PlaceEnvHardwareTest") | ||
test_DATA = None # TODO: Rename to something useful after defining return type of the method | ||
with test_spot.get_lease(hijack=True): | ||
place_controller = PlaceController(config=config, spot=test_spot) | ||
|
||
try: | ||
test_DATA = place_controller.execute( | ||
place_target_list=test_place_targets_list | ||
) | ||
except Exception: | ||
pytest.fails( | ||
"Pytest raised an error while executing PlaceController.execute from test_place_env.py" | ||
) | ||
finally: | ||
place_controller.shutdown(should_dock=True) | ||
|
||
assert test_DATA is not [] | ||
assert len(test_DATA) == len(test_waypoints) | ||
|
||
# ref_traj_set = load_json_files(test_square_nav_trajectories_dir) | ||
|
||
# avg_time_list_traj, std_time_list_traj = compute_avg_and_std_time(ref_traj_set) | ||
# avg_steps_list_traj, std_steps_list_traj = compute_avg_and_std_steps(ref_traj_set) | ||
# ( | ||
# test_pose_list, | ||
# test_time_list, | ||
# test_steps_list, | ||
# ) = extract_goal_poses_timestamps_steps_from_traj(test_traj) | ||
|
||
# print(f"Dataset: Avg. time to reach each waypoint - {avg_time_list_traj}") | ||
# print(f"Dataset: Std.Dev in time reach each waypoint - {std_time_list_traj}") | ||
# print(f"Test-Nav: Time taken to reach each waypoint - {test_time_list}\n") | ||
|
||
# print(f"Dataset: Avg. steps to reach each waypoint - {avg_steps_list_traj}") | ||
# print(f"Dataset: Std.Dev in steps to reach each waypoint - {std_steps_list_traj}") | ||
# print(f"Test-Nav: Steps taken to reach each waypoint - {test_steps_list}\n") | ||
|
||
# print(f"Test-Nav: Pose at each of the goal waypoint - {test_pose_list}\n") | ||
|
||
# allowable_std_dev_in_time = 3.0 | ||
# allowable_std_dev_in_steps = 3.0 | ||
# for wp_idx in range(len(test_waypoints)): | ||
# # Capture target pose for each waypoint | ||
# target_pose = list( | ||
# nav_target_from_waypoint(test_waypoints[wp_idx], test_waypoints_yaml_dict) | ||
# ) | ||
# target_pose[-1] = np.rad2deg(target_pose[-1]) | ||
# # Test that robot reached its goal successfully spatially | ||
# assert ( | ||
# is_pose_within_bounds( | ||
# test_pose_list[wp_idx], | ||
# target_pose, | ||
# config.SUCCESS_DISTANCE, | ||
# config.SUCCESS_ANGLE_DIST, | ||
# ) | ||
# is True | ||
# ) | ||
|
||
# # Test that robot reached its goal successfully temporally (within 1 std dev of mean) | ||
# assert ( | ||
# abs(test_time_list[wp_idx] - avg_time_list_traj[wp_idx]) | ||
# < allowable_std_dev_in_time * std_time_list_traj[wp_idx] | ||
# ) | ||
|
||
# # Test that test trajectory took similar amount of steps to finish execution | ||
# assert ( | ||
# abs(test_steps_list[wp_idx] - avg_steps_list_traj[wp_idx]) | ||
# < allowable_std_dev_in_steps * std_steps_list_traj[wp_idx] | ||
# ) | ||
|
||
# # Report DTW scores | ||
# dtw_score_list = compute_dtw_scores(test_traj, ref_traj_set) | ||
# print(f"DTW scores: {dtw_score_list}") |