Skip to content

Commit

Permalink
Adjusted main to accept command-line arg
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikonooooo committed Oct 23, 2023
1 parent df676ed commit 98ae56b
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import yaml
from modelrunner import ModelRunner
from processrunner import ProcessRunner

# before main is called:
# frontend boots up, awaits user to upload a video
# upload button triggers backend call to upload video to s3
# fetch the video from cloud and download to tmp/uploaded_video.mp4
# calls main


# load in configs from config.yaml
# initialise modelrunner and processrunner
# feed video into the yolo model, pass into modelrunner
Expand All @@ -20,7 +22,7 @@ def load_config(path):
"""
TODO Loads the config yaml file to read in parameters and settings.
"""
with open(path, 'r') as file:
with open(path, "r") as file:
config = yaml.safe_load(file)
return config

Expand All @@ -31,21 +33,31 @@ def main(video_path):
Input: Path of the user uploaded video.
Returns: Results of the processing, in string format TODO change to csv/json?
"""
config = load_config('config.yaml')
model_vars = config['model_vars']
config = load_config("config.yaml")
model_vars = config["model_vars"]

modelrunner = ModelRunner(video_path, model_vars)
modelrunner.run()
people_output, ball_output = modelrunner.fetch_output()
output_video_path = 'tmp/court_video.mp4'
output_video_path_reenc = 'tmp/court_video_reenc.mp4'
output_video_path = "tmp/court_video.mp4"
output_video_path_reenc = "tmp/court_video_reenc.mp4"

processrunner = ProcessRunner(video_path, people_output, ball_output, output_video_path,
output_video_path_reenc)
processrunner = ProcessRunner(
video_path,
people_output,
ball_output,
output_video_path,
output_video_path_reenc,
)
processrunner.run()
results = processrunner.get_results()
return results


if __name__ == '__main__':
main('data/training_data.mp4')
if __name__ == "__main__":
import sys

if len(sys.argv) <= 1:
main("tmp/training_data.mp4")
else:
main(sys.argv[1]) # Pass the first command-line argument to the main function

0 comments on commit 98ae56b

Please sign in to comment.