Skip to content

Commit

Permalink
Merge branch 'ball-202' into yolov7
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikonooooo committed Oct 25, 2023
2 parents be394e9 + 8e4e174 commit 2817818
Show file tree
Hide file tree
Showing 28 changed files with 1,587 additions and 1,182 deletions.
Binary file added .DS_Store
Binary file not shown.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ jobs:
run: echo $DATA >> /home/runner/work/Ball-101/Ball-101/.env
- name: Run CI tests
run: |
python src/processing/shot_detect.py
python src/main.py data/training_data.mp4
9 changes: 9 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
venv
.env
__pycache__
yolov8m-pose.pt
*.mp4
ball/lib/python3.11/site-packages/torch/lib/libtorch_cpu.dylib
venv/lib/python3.11/site-packages/torch/lib/libtorch_cpu.dylib
ball/
tmp/
tmp/*.json
.DS_Store
.vscode/
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ Enable AWS connection by pasting the .env file into the repo.

Start the server backend by running
```
cd src/api
uvicorn backend:app --reload
uvicorn src.api/.backend:app --reload
```

Open a new bash terminal and start the frontend by running
Expand Down
Binary file added data/stable_jerry.mp4
Binary file not shown.
Binary file modified data/training_data.mp4
Binary file not shown.
Binary file modified data/true_map.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 13 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ opencv-python==4.7.0.72
matplotlib>=3.2.2
Pillow>=7.1.2
PyYAML>=5.3.1
torch==2.0.1
torchvision==0.15.2
torch==2.0.1 # Check if this version exists, if not, use the latest stable version
torchvision==0.15.2 # Same check as torch
tqdm>=4.41.0
seaborn
scipy
Expand All @@ -44,9 +44,18 @@ yapf
isort==4.3.21
imageio

# Processing
scikit-learn

# View
streamlit>=1.18.1
hydralit_components>= 1.0.10
hydralit_components>=1.0.10

# Misc
pylint
pylint

# Additional Dependencies for Pose Estimation
json5
ultralytics
imageio==2.9.0
imageio-ffmpeg>=0.4.3
2 changes: 1 addition & 1 deletion src/StrongSORT-YOLO/run_tracker.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
cd src/StrongSORT-YOLO
python3 tracker.py
python3 tracker.py $1
cd ..; cd ..;
71 changes: 43 additions & 28 deletions src/StrongSORT-YOLO/tracker.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,51 @@
from pathlib import Path
import track_v7
import sys
import os
import pickle
import multiprocessing as mp
import threading as th
import time

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]
WEIGHTS = ROOT / 'weights'

def track_person(res, source_mov:str, idx:int):
""" tracks persons in video and puts data in out_queue"""
out_array_pr, vid_path = track_v7.run(source = source_mov, classes= [1, 2], yolo_weights =
WEIGHTS / 'best.pt', save_vid=False, ret=True)

ROOT = FILE.parents[0]
WEIGHTS = ROOT / "weights"


def track_person(res, source_mov: str, idx: int):
"""tracks persons in video and puts data in out_queue"""
out_array_pr, vid_path = track_v7.run(
source=source_mov,
classes=[1, 2],
yolo_weights=WEIGHTS / "best.pt",
save_vid=False,
ret=True,
)

res[idx] = (out_array_pr, vid_path)

print("==============Put data from tracking person and rim============")
return

def track_basketball(res, source_mov:str, idx:int):
""" tracks basketball in video and puts data in out_queue"""
out_array_bb, bb_vid_path = track_v7.run(source = source_mov,
yolo_weights = WEIGHTS / 'best_basketball.pt',
save_vid=False, ret=True, skip_big=True)


def track_basketball(res, source_mov: str, idx: int):
"""tracks basketball in video and puts data in out_queue"""
out_array_bb, bb_vid_path = track_v7.run(
source=source_mov,
yolo_weights=WEIGHTS / "best_basketball.pt",
save_vid=False,
ret=True,
skip_big=True,
)

res[idx] = (out_array_bb, bb_vid_path)

print("==============Put data from tracking basketball============")
return

def get_data(source_mov:str):
""" returns dict as: {

def get_data(source_mov: str):
"""returns dict as: {
'basketball_data': (basketball_bounding_boxes, video_path),
'person_data': (persons_and_rim_bounding_boxes, vide_path)
}.
Expand All @@ -40,7 +54,6 @@ def get_data(source_mov:str):
source_mov (string path): path to video file
"""


# ---------THREADING APPROACH---------------

# res = [None] * 2
Expand Down Expand Up @@ -79,18 +92,20 @@ def get_data(source_mov:str):

end = time.time()

print(f'=============time elapsed: {end-start}=================')

return {'basketball_data': (out_array_bb, bb_vid_path), 'person_data': (out_array_pr, vid_path)}
print(f"=============time elapsed: {end-start}=================")

return {
"basketball_data": (out_array_bb, bb_vid_path),
"person_data": (out_array_pr, vid_path),
}

def test():
print('import worked')

if __name__ == '__main__':
# TODO change to actual video path
output = get_data('../../data/training_data.mp4')
print(output['basketball_data'])
with open('../../tmp/test_output.pickle', 'wb') as f:
if __name__ == "__main__":
video_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "..", sys.argv[1]
)
output = get_data(video_path)
print(output["basketball_data"])
print("MODEL RUN DONE")
with open("../../tmp/test_output.pickle", "wb") as f:
pickle.dump(output, f)

34 changes: 24 additions & 10 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import yaml
from modelrunner import ModelRunner
from processrunner import ProcessRunner

# before main is called:
# frontend boots up, awaits user to upload a video
# upload button triggers backend call to upload video to s3
# fetch the video from cloud and download to tmp/uploaded_video.mp4
# calls main


# load in configs from config.yaml
# initialise modelrunner and processrunner
# feed video into the yolo model, pass into modelrunner
Expand All @@ -20,7 +22,7 @@ def load_config(path):
"""
TODO Loads the config yaml file to read in parameters and settings.
"""
with open(path, 'r') as file:
with open(path, "r") as file:
config = yaml.safe_load(file)
return config

Expand All @@ -31,21 +33,33 @@ def main(video_path):
Input: Path of the user uploaded video.
Returns: Results of the processing, in string format TODO change to csv/json?
"""
config = load_config('config.yaml')
model_vars = config['model_vars']
config = load_config("config.yaml")
model_vars = config["model_vars"]

modelrunner = ModelRunner(video_path, model_vars)
modelrunner.run()
people_output, ball_output = modelrunner.fetch_output()
output_video_path = 'tmp/court_video.mp4'
output_video_path_reenc = 'tmp/court_video_reenc.mp4'
modelrunner.pose()
people_output, ball_output, pose_output = modelrunner.fetch_output()
output_video_path = "tmp/court_video.mp4"
output_video_path_reenc = "tmp/court_video_reenc.mp4"

processrunner = ProcessRunner(
video_path,
people_output,
ball_output,
output_video_path,
output_video_path_reenc,
)

processrunner = ProcessRunner(video_path, people_output, ball_output, output_video_path,
output_video_path_reenc)
processrunner.run()
results = processrunner.get_results()
return results


if __name__ == "__main__":
import sys

if __name__ == '__main__':
main('data/training_data.mp4')
if len(sys.argv) <= 1:
main("tmp/training_data.mp4")
else:
main(sys.argv[1]) # Pass the first command-line argument to the main function
24 changes: 18 additions & 6 deletions src/modelrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import pickle
import subprocess
from typing import Tuple
from pose_estimation.pose_estimate import PoseEstimator
from ultralytics import YOLO

class ModelRunner:
"""
Expand All @@ -15,6 +17,7 @@ class ModelRunner:
def __init__(self, video_path, model_vars) -> None:
self.video_path = video_path
self.frame_reduction_factor = model_vars['frame_reduction_factor']
self.pose_estimator = PoseEstimator(video_path=video_path)


def drop_frames(self, input_path) -> str:
Expand Down Expand Up @@ -47,19 +50,28 @@ def run(self):
"""
# comment first two lines out to exclude running the model
# self.drop_frames(self.video_path)
# subprocess.run(['bash', 'src/StrongSORT-YOLO/run_tracker.sh'])
subprocess.run(['bash', 'src/StrongSORT-YOLO/run_tracker.sh', self.video_path])
with open('tmp/output.pickle', 'rb') as f:
self.output_dict = pickle.load(f)


def fetch_output(self) -> Tuple[str, str]:
def pose(self):
model = YOLO('src/pose_estimation/best.pt')
results = model(
source = self.video_path,
show=False,
conf=0.3,
verbose = False
)
self.pose_estimator.estimate_pose(results = results)

def fetch_output(self) -> Tuple[str, str, str]:
"""
Converts the people and ball model output in self.output.dict into txt files.
Returns a tuple of the people and ball txt output paths.
"""
ball_list = [tuple(round(num) for num in tup)
ball_list = [tuple(round(num) for num in tup)
for tup in self.output_dict['basketball_data'][0]]
people_list = [tuple(round(num) for num in tup)
people_list = [tuple(round(num) for num in tup)
for tup in self.output_dict['person_data'][0]]
ball_data = [(' '.join(map(str, ball[0:7])) + ' -1 -1 -1 -1')
for ball in ball_list]
Expand All @@ -72,4 +84,4 @@ def fetch_output(self) -> Tuple[str, str]:
with open('tmp/people.txt', 'w') as f:
f.write('\n'.join(people_data))

return 'tmp/people.txt', 'tmp/ball.txt'
return 'tmp/people.txt', 'tmp/ball.txt', 'tmp/pose.txt'
Binary file added src/pose_estimation/best.pt
Binary file not shown.
Loading

0 comments on commit 2817818

Please sign in to comment.