Skip to content

Commit

Permalink
mahi
Browse files Browse the repository at this point in the history
  • Loading branch information
Jake0826 committed Oct 12, 2023
1 parent 0509398 commit 6d063f8
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 13,836 deletions.
Binary file removed data/training_data.mp4
Binary file not shown.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ requests
opencv-python==4.7.0.72

# StrongSORT
ultralytics

# base ----------------------------------------
matplotlib>=3.2.2
Pillow>=7.1.2
Expand Down
Binary file modified src/StrongSORT-YOLO/testing.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
60 changes: 60 additions & 0 deletions src/botsort.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,64 @@
import cv2
from ultralytics import YOLO
from pathlib import Path
import pickle

def get_data_yolov8(source_mov: str, model_path: str):
""" Returns tracking info using YOLOv8.
Args:
source_mov (str): Path to video file.
model_path (str): Path to the YOLOv8 weights/model.
Returns:
dict: Contains 'basketball_data' and 'person_data'.
"""

# Load the YOLOv8 model
model = YOLO(model_path)

cap = cv2.VideoCapture(source_mov)
basketball_data = []
person_data = []

while cap.isOpened():
success, frame = cap.read()
if success:
results = model.track(frame, persist=True, tracker="botsort.yaml")
print(type(results))
print(results)
for detection in results.pred[0]:
label = detection[4].item()
# Assuming class index 1 is 'person' and class index 2 is 'basketball'
if label == 1:
person_data.append(detection[:4].tolist())
elif label == 2:
basketball_data.append(detection[:4].tolist())
else:
break

cap.release()

return {'basketball_data': (basketball_data, source_mov), 'person_data': (person_data, source_mov)}


def sort():
# TODO: Update with actual paths
video_path = '/data/training_data.mp4'
model_path = 'yolov8n.pt'

output = get_data_yolov8(video_path, model_path)
print("b")
print(output)
print(output['basketball_data'])
return output
with open('../../tmp/test_output_yolov8.pickle', 'wb') as f:
pickle.dump(output, f)



'''import cv2
from ultralytics import YOLO
# Load the YOLOv8 model
model = YOLO('yolov8n.pt')
Expand Down Expand Up @@ -34,3 +93,4 @@
# Release the video capture object and close the display window
cap.release()
cv2.destroyAllWindows()
'''
5 changes: 4 additions & 1 deletion src/modelrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pickle
import subprocess
from typing import Tuple
from botsort import sort

class ModelRunner:
"""
Expand Down Expand Up @@ -46,10 +47,12 @@ def run(self):
"""
# comment first two lines out to exclude running the model
self.drop_frames(self.video_path)
self.output_dict = sort()
'''
subprocess.run(['bash', 'src/StrongSORT-YOLO/run_tracker.sh'])
with open('tmp/output.pickle', 'rb') as f:
self.output_dict = pickle.load(f)

'''

def fetch_output(self) -> Tuple[str, str]:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/view/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
if 'state' not in st.session_state:
st.session_state.state = 0
st.session_state.logo = 'src/view/static/basketball.png'
with open('data/training_data.mp4', 'rb') as file:
with open('training_data.mp4', 'rb') as file:
st.session_state.video_file = io.BytesIO(file.read())
st.session_state.processed_video = None
st.session_state.result_string = None
Expand Down
Loading

0 comments on commit 6d063f8

Please sign in to comment.