diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 00000000..6d2b0a56 Binary files /dev/null and b/.DS_Store differ diff --git a/.gitignore b/.gitignore index a2d526c2..ea1b5576 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,12 @@ venv .env __pycache__ +yolov8m-pose.pt +*.mp4 +ball/lib/python3.11/site-packages/torch/lib/libtorch_cpu.dylib +venv/lib/python3.11/site-packages/torch/lib/libtorch_cpu.dylib +ball/ +tmp/ +tmp/*.json .DS_Store .vscode/ -tmp/ diff --git a/data/training_data.mp4 b/data/training_data.mp4 index 280d9d5b..1428678f 100644 Binary files a/data/training_data.mp4 and b/data/training_data.mp4 differ diff --git a/requirements.txt b/requirements.txt index a947937e..97f15bde 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,8 +25,8 @@ opencv-python==4.7.0.72 matplotlib>=3.2.2 Pillow>=7.1.2 PyYAML>=5.3.1 -torch==2.0.1 -torchvision==0.15.2 +torch==2.0.1 # Check if this version exists, if not, use the latest stable version +torchvision==0.15.2 # Same check as torch tqdm>=4.41.0 seaborn scipy @@ -49,7 +49,13 @@ scikit-learn # View streamlit>=1.18.1 -hydralit_components>= 1.0.10 +hydralit_components>=1.0.10 # Misc -pylint \ No newline at end of file +pylint + +# Additional Dependencies for Pose Estimation +json5 +ultralytics +imageio==2.9.0 +imageio-ffmpeg>=0.4.3 \ No newline at end of file diff --git a/src/main.py b/src/main.py index e0a07db6..25f52b30 100644 --- a/src/main.py +++ b/src/main.py @@ -38,7 +38,8 @@ def main(video_path): modelrunner = ModelRunner(video_path, model_vars) modelrunner.run() - people_output, ball_output = modelrunner.fetch_output() + modelrunner.pose() + people_output, ball_output, pose_output = modelrunner.fetch_output() output_video_path = "tmp/court_video.mp4" output_video_path_reenc = "tmp/court_video_reenc.mp4" @@ -49,10 +50,11 @@ def main(video_path): output_video_path, output_video_path_reenc, ) + processrunner.run() results = processrunner.get_results() return results - + if __name__ == "__main__": import sys diff --git a/src/modelrunner.py b/src/modelrunner.py index 1432f83d..88a6f9c0 100644 --- a/src/modelrunner.py +++ b/src/modelrunner.py @@ -6,6 +6,8 @@ import pickle import subprocess from typing import Tuple +from pose_estimation.pose_estimate import PoseEstimator +from ultralytics import YOLO class ModelRunner: """ @@ -15,6 +17,7 @@ class ModelRunner: def __init__(self, video_path, model_vars) -> None: self.video_path = video_path self.frame_reduction_factor = model_vars['frame_reduction_factor'] + self.pose_estimator = PoseEstimator(video_path=video_path) def drop_frames(self, input_path) -> str: @@ -51,15 +54,24 @@ def run(self): with open('tmp/output.pickle', 'rb') as f: self.output_dict = pickle.load(f) - - def fetch_output(self) -> Tuple[str, str]: + def pose(self): + model = YOLO('src/pose_estimation/best.pt') + results = model( + source = self.video_path, + show=False, + conf=0.3, + verbose = False + ) + self.pose_estimator.estimate_pose(results = results) + + def fetch_output(self) -> Tuple[str, str, str]: """ Converts the people and ball model output in self.output.dict into txt files. Returns a tuple of the people and ball txt output paths. """ - ball_list = [tuple(round(num) for num in tup) + ball_list = [tuple(round(num) for num in tup) for tup in self.output_dict['basketball_data'][0]] - people_list = [tuple(round(num) for num in tup) + people_list = [tuple(round(num) for num in tup) for tup in self.output_dict['person_data'][0]] ball_data = [(' '.join(map(str, ball[0:7])) + ' -1 -1 -1 -1') for ball in ball_list] @@ -72,4 +84,4 @@ def fetch_output(self) -> Tuple[str, str]: with open('tmp/people.txt', 'w') as f: f.write('\n'.join(people_data)) - return 'tmp/people.txt', 'tmp/ball.txt' + return 'tmp/people.txt', 'tmp/ball.txt', 'tmp/pose.txt' diff --git a/src/pose_estimation/best.pt b/src/pose_estimation/best.pt new file mode 100644 index 00000000..33342b1f Binary files /dev/null and b/src/pose_estimation/best.pt differ diff --git a/src/pose_estimation/pose_estimate.py b/src/pose_estimation/pose_estimate.py new file mode 100644 index 00000000..0a5673e8 --- /dev/null +++ b/src/pose_estimation/pose_estimate.py @@ -0,0 +1,100 @@ +import torch +import math +import json +from ultralytics import YOLO + +class PoseEstimator: + def __init__(self, model_path='src/pose_estimation/best.pt', video_path='res/pose_results/test_multiple_people.mp4', combinations=None): + # Initialize paths, model, and combinations of keypoints to calculate angles + self.model_path = model_path + self.video_path = video_path + self.model = YOLO(model_path) # Load the YOLO model + + # Combinations of points to calculate 8 angles + self.combinations = combinations if combinations is not None else [ + (5, 7, 9), (6, 8, 10), (11, 13, 15), (12, 14, 16), + (5, 6, 8), (6, 5, 7), (11, 12, 14), (12, 11, 13) + ] + + # Names corresponding to the adjusted 8 angle types + self.angle_names = [ + "left_elbow", "right_elbow", "left_knee", "right_knee", + "right_shoulder", "left_shoulder", + "right_hip", "left_hip" + ] + + @staticmethod + def compute_angle(p1, p2, p3): + # Calculate angle given 3 points using the dot product and arc cosine + vector_a = p1 - p2 + vector_b = p3 - p2 + + # Normalize the vectors (to make them unit vectors) + vector_a = vector_a / torch.norm(vector_a) + vector_b = vector_b / torch.norm(vector_b) + + # Compute the angle + cosine_angle = torch.sum(vector_a * vector_b) + angle_radians = torch.acos(cosine_angle) + angle_degrees = angle_radians * 180 / math.pi + + return angle_degrees + + def estimate_pose(self, results): + model = YOLO(self.model_path) + + # Initialize an empty list to store pose data + pose_data = [] + + # empty list for shots + shots = [] + + for frame_idx, result in enumerate(results): + keypoints = result.keypoints.data[:, :, :2].numpy() # Extracting the (x, y) coordinates + confidences = result.keypoints.conf.numpy().tolist() # Extracting the confidences + boxes = result.boxes.xyxy.numpy().tolist() # Extracting bounding boxes + frame_pose_data = { + 'frame': frame_idx, + 'persons': [], + 'boxes': boxes, + 'keypoints': keypoints.tolist(), + 'confidences': confidences + } + + for person_idx, (person_keypoints, person_confidences, box) in enumerate(zip(keypoints, confidences, boxes)): + person_data = { + 'keypoints': person_keypoints.tolist(), + 'confidences': person_confidences, + 'box': box, + 'angles': {} + } + + for idx, combination in enumerate(self.combinations): + if all(idx < len(person_keypoints) for idx in combination): + p1, p2, p3 = (person_keypoints[i] for i in combination) + angle_degrees = self.compute_angle(torch.tensor(p1), torch.tensor(p2), torch.tensor(p3)) + person_data['angles'][self.angle_names[idx]] = angle_degrees.item() + + frame_pose_data['persons'].append(person_data) + + # naive check shot: if wrists above shoulders + left_wrist_y = person_keypoints[9][1] + right_wrist_y = person_keypoints[10][1] + left_shoulder_y = person_keypoints[5][1] + right_shoulder_y = person_keypoints[6][1] + + if left_wrist_y < left_shoulder_y and right_wrist_y < right_shoulder_y: + shots.append("SHOT TAKEN: person " + str(person_idx) + ", frame " + str(frame_idx)) + # end naive check shot + + pose_data.append(frame_pose_data) + + + with open("tmp/pose_data.json", "w") as f: + json.dump(pose_data, f) + + # write to shots file + with open('tmp/shots.txt', 'w') as f: + for line in shots: + f.write(line) + f.write('\n') \ No newline at end of file diff --git a/src/view/app.py b/src/view/app.py index e948244b..3565eca3 100644 --- a/src/view/app.py +++ b/src/view/app.py @@ -20,7 +20,7 @@ if "state" not in st.session_state: st.session_state.state = 0 st.session_state.logo = "src/view/static/basketball.png" - with open("data/training_data.mp4", "rb") as file: + with open("data/short_new_1.mp4", "rb") as file: st.session_state.video_file = io.BytesIO(file.read()) st.session_state.processed_video = None st.session_state.result_string = None @@ -30,11 +30,11 @@ def process_video(video_file): - ''' + """ Takes in a mp4 file at video_file and uploads it to the backend, then stores the processed video name into session state Temporarily: stores the processed video into tmp/user_upload.mp4 - ''' + """ if video_file is None: return False response = requests.post( @@ -112,14 +112,13 @@ def results_page(): st.download_button( label="Download Results", use_container_width=True, - data=st.session_state.result_string, + data=st.session_state.result_string, file_name="results.txt", ) st.button(label="Back to Home", on_click=change_state, args=(0,), type="primary") - def tips_page(): """ Loads tips page diff --git a/tmp/test_video.mp4 b/tmp/test_video.mp4 deleted file mode 100644 index bad1cdee..00000000 Binary files a/tmp/test_video.mp4 and /dev/null differ