Skip to content

Commit

Permalink
Merge pull request #15 from CornellDataScience/pose-estimate
Browse files Browse the repository at this point in the history
Pose estimate
  • Loading branch information
Mikonooooo authored Oct 25, 2023
2 parents 1df5d4a + e5e1445 commit 8e4e174
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 17 deletions.
Binary file added .DS_Store
Binary file not shown.
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
venv
.env
__pycache__
yolov8m-pose.pt
*.mp4
ball/lib/python3.11/site-packages/torch/lib/libtorch_cpu.dylib
venv/lib/python3.11/site-packages/torch/lib/libtorch_cpu.dylib
ball/
tmp/
tmp/*.json
.DS_Store
.vscode/
tmp/
Binary file modified data/training_data.mp4
Binary file not shown.
14 changes: 10 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ opencv-python==4.7.0.72
matplotlib>=3.2.2
Pillow>=7.1.2
PyYAML>=5.3.1
torch==2.0.1
torchvision==0.15.2
torch==2.0.1 # Check if this version exists, if not, use the latest stable version
torchvision==0.15.2 # Same check as torch
tqdm>=4.41.0
seaborn
scipy
Expand All @@ -49,7 +49,13 @@ scikit-learn

# View
streamlit>=1.18.1
hydralit_components>= 1.0.10
hydralit_components>=1.0.10

# Misc
pylint
pylint

# Additional Dependencies for Pose Estimation
json5
ultralytics
imageio==2.9.0
imageio-ffmpeg>=0.4.3
6 changes: 4 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def main(video_path):

modelrunner = ModelRunner(video_path, model_vars)
modelrunner.run()
people_output, ball_output = modelrunner.fetch_output()
modelrunner.pose()
people_output, ball_output, pose_output = modelrunner.fetch_output()
output_video_path = "tmp/court_video.mp4"
output_video_path_reenc = "tmp/court_video_reenc.mp4"

Expand All @@ -49,10 +50,11 @@ def main(video_path):
output_video_path,
output_video_path_reenc,
)

processrunner.run()
results = processrunner.get_results()
return results


if __name__ == "__main__":
import sys
Expand Down
22 changes: 17 additions & 5 deletions src/modelrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import pickle
import subprocess
from typing import Tuple
from pose_estimation.pose_estimate import PoseEstimator
from ultralytics import YOLO

class ModelRunner:
"""
Expand All @@ -15,6 +17,7 @@ class ModelRunner:
def __init__(self, video_path, model_vars) -> None:
# Set up the runner for one input video.
# video_path: path of the video to analyse; stored as-is and also handed to the pose estimator.
# model_vars: mapping of model settings; only 'frame_reduction_factor' is read here.
self.video_path = video_path
self.frame_reduction_factor = model_vars['frame_reduction_factor']
# PoseEstimator loads its YOLO weights in its constructor (default model path),
# so building a ModelRunner already pays the model-loading cost.
self.pose_estimator = PoseEstimator(video_path=video_path)


def drop_frames(self, input_path) -> str:
Expand Down Expand Up @@ -51,15 +54,24 @@ def run(self):
with open('tmp/output.pickle', 'rb') as f:
self.output_dict = pickle.load(f)


def fetch_output(self) -> Tuple[str, str]:
def pose(self):
    """
    Runs pose inference on self.video_path and passes the per-frame results
    to the pose estimator, which writes its output files under tmp/.
    Returns nothing.
    """
    # Reuse the YOLO model that PoseEstimator loaded in its constructor: it is
    # built from the same 'src/pose_estimation/best.pt' default, so loading the
    # weights a second time here was redundant work.
    results = self.pose_estimator.model(
        source=self.video_path,
        show=False,
        conf=0.3,  # detection confidence threshold passed to the model
        verbose=False,
    )
    self.pose_estimator.estimate_pose(results=results)

def fetch_output(self) -> Tuple[str, str, str]:
"""
Converts the people and ball model output in self.output_dict into txt files.
Returns a tuple of the people, ball, and pose txt output paths.
"""
ball_list = [tuple(round(num) for num in tup)
ball_list = [tuple(round(num) for num in tup)
for tup in self.output_dict['basketball_data'][0]]
people_list = [tuple(round(num) for num in tup)
people_list = [tuple(round(num) for num in tup)
for tup in self.output_dict['person_data'][0]]
ball_data = [(' '.join(map(str, ball[0:7])) + ' -1 -1 -1 -1')
for ball in ball_list]
Expand All @@ -72,4 +84,4 @@ def fetch_output(self) -> Tuple[str, str]:
with open('tmp/people.txt', 'w') as f:
f.write('\n'.join(people_data))

return 'tmp/people.txt', 'tmp/ball.txt'
return 'tmp/people.txt', 'tmp/ball.txt', 'tmp/pose.txt'
Binary file added src/pose_estimation/best.pt
Binary file not shown.
100 changes: 100 additions & 0 deletions src/pose_estimation/pose_estimate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import torch
import math
import json
from ultralytics import YOLO

class PoseEstimator:
    """Compute joint angles and naive shot events from YOLO pose results.

    Each angle is defined by a triple of keypoint indices ``(p1, p2, p3)`` and
    is measured at ``p2``. Results are written to ``tmp/pose_data.json`` and
    ``tmp/shots.txt`` by :meth:`estimate_pose`.
    """

    def __init__(self, model_path='src/pose_estimation/best.pt',
                 video_path='res/pose_results/test_multiple_people.mp4',
                 combinations=None):
        """Store the paths, load the YOLO model and set up the angle definitions.

        Args:
            model_path: Path of the YOLO pose weights to load.
            video_path: Path of the video this estimator is associated with.
            combinations: Optional list of ``(p1, p2, p3)`` keypoint index
                triples; defaults to the eight built-in angles below.
        """
        self.model_path = model_path
        self.video_path = video_path
        self.model = YOLO(model_path)  # Load the YOLO model once, up front

        # Combinations of points to calculate 8 angles (vertex is the middle index)
        self.combinations = combinations if combinations is not None else [
            (5, 7, 9), (6, 8, 10), (11, 13, 15), (12, 14, 16),
            (5, 6, 8), (6, 5, 7), (11, 12, 14), (12, 11, 13),
        ]

        # Names corresponding, position by position, to the 8 default combinations
        self.angle_names = [
            "left_elbow", "right_elbow", "left_knee", "right_knee",
            "right_shoulder", "left_shoulder",
            "right_hip", "left_hip",
        ]

    @staticmethod
    def compute_angle(p1, p2, p3):
        """Return the angle at ``p2`` formed by ``p1`` and ``p3``, in degrees.

        Args:
            p1, p2, p3: Point tensors of matching shape (floating point).

        Returns:
            A scalar tensor holding the angle in degrees. If ``p1`` or ``p3``
            coincides with ``p2`` the normalization divides by zero; that case
            is not special-cased and does not produce a meaningful angle.
        """
        vector_a = p1 - p2
        vector_b = p3 - p2

        # Normalize the vectors (to make them unit vectors)
        vector_a = vector_a / torch.norm(vector_a)
        vector_b = vector_b / torch.norm(vector_b)

        # Rounding can push the dot product of two unit vectors slightly
        # outside [-1, 1] for (anti)parallel vectors, which makes acos return
        # NaN; clamp so collinear points give 0 or 180 degrees instead.
        cosine_angle = torch.clamp(torch.sum(vector_a * vector_b), -1.0, 1.0)
        angle_radians = torch.acos(cosine_angle)
        return angle_radians * 180 / math.pi

    def estimate_pose(self, results):
        """Extract per-frame pose data and naive shot events from ``results``.

        Args:
            results: Iterable of per-frame YOLO pose results, each exposing
                ``keypoints.data``, ``keypoints.conf`` and ``boxes.xyxy``.

        Writes ``tmp/pose_data.json`` (one entry per frame) and
        ``tmp/shots.txt`` (one line per detected shot). Returns nothing.
        """
        # Pose data for every frame, and the shot messages found along the way
        pose_data = []
        shots = []

        for frame_idx, result in enumerate(results):
            # .cpu() first: .numpy() raises on tensors that live on a GPU.
            keypoints = result.keypoints.data[:, :, :2].cpu().numpy()  # (x, y) only
            raw_confidences = result.keypoints.conf
            if raw_confidences is None:
                # No per-keypoint confidences available; keep one placeholder
                # per person so the per-person zip below stays aligned.
                confidences = [None] * len(keypoints)
            else:
                confidences = raw_confidences.cpu().numpy().tolist()
            boxes = result.boxes.xyxy.cpu().numpy().tolist()

            frame_pose_data = {
                'frame': frame_idx,
                'persons': [],
                'boxes': boxes,
                'keypoints': keypoints.tolist(),
                'confidences': confidences
            }

            for person_idx, (person_keypoints, person_confidences, box) in enumerate(
                    zip(keypoints, confidences, boxes)):
                num_keypoints = len(person_keypoints)
                person_data = {
                    'keypoints': person_keypoints.tolist(),
                    'confidences': person_confidences,
                    'box': box,
                    'angles': {}
                }

                for angle_idx, combination in enumerate(self.combinations):
                    # Skip angles whose keypoints this detection does not have
                    if not all(i < num_keypoints for i in combination):
                        continue
                    p1, p2, p3 = (torch.tensor(person_keypoints[i]) for i in combination)
                    angle_degrees = self.compute_angle(p1, p2, p3)
                    # Custom combinations may outnumber the built-in names;
                    # fall back to a positional name instead of raising.
                    if angle_idx < len(self.angle_names):
                        angle_name = self.angle_names[angle_idx]
                    else:
                        angle_name = f"angle_{angle_idx}"
                    person_data['angles'][angle_name] = angle_degrees.item()

                frame_pose_data['persons'].append(person_data)

                # naive check shot: if wrists above shoulders (smaller y is
                # treated as higher). Needs keypoints 5, 6, 9 and 10.
                if num_keypoints > 10:
                    left_wrist_y = person_keypoints[9][1]
                    right_wrist_y = person_keypoints[10][1]
                    left_shoulder_y = person_keypoints[5][1]
                    right_shoulder_y = person_keypoints[6][1]

                    if left_wrist_y < left_shoulder_y and right_wrist_y < right_shoulder_y:
                        shots.append(f"SHOT TAKEN: person {person_idx}, frame {frame_idx}")
                # end naive check shot

            pose_data.append(frame_pose_data)

        with open("tmp/pose_data.json", "w") as f:
            json.dump(pose_data, f)

        # write to shots file, one message per line
        with open('tmp/shots.txt', 'w') as f:
            f.writelines(f"{line}\n" for line in shots)
9 changes: 4 additions & 5 deletions src/view/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
if "state" not in st.session_state:
st.session_state.state = 0
st.session_state.logo = "src/view/static/basketball.png"
with open("data/training_data.mp4", "rb") as file:
with open("data/short_new_1.mp4", "rb") as file:
st.session_state.video_file = io.BytesIO(file.read())
st.session_state.processed_video = None
st.session_state.result_string = None
Expand All @@ -30,11 +30,11 @@


def process_video(video_file):
'''
"""
Takes in a mp4 file at video_file and uploads it to the backend, then stores
the processed video name into session state
Temporarily: stores the processed video into tmp/user_upload.mp4
'''
"""
if video_file is None:
return False
response = requests.post(
Expand Down Expand Up @@ -112,14 +112,13 @@ def results_page():
st.download_button(
label="Download Results",
use_container_width=True,
data=st.session_state.result_string,
data=st.session_state.result_string,
file_name="results.txt",
)

st.button(label="Back to Home", on_click=change_state, args=(0,), type="primary")



def tips_page():
"""
Loads tips page
Expand Down
Binary file removed tmp/test_video.mp4
Binary file not shown.

0 comments on commit 8e4e174

Please sign in to comment.