From 98ae56bf2c63e6ce6ed73972c3035d92bea67b00 Mon Sep 17 00:00:00 2001 From: Michael Ngo Date: Mon, 23 Oct 2023 04:49:32 -0400 Subject: [PATCH] Adjusted main to accept command-line arg --- src/main.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/src/main.py b/src/main.py index 67fbd10d..e0a07db6 100644 --- a/src/main.py +++ b/src/main.py @@ -4,12 +4,14 @@ import yaml from modelrunner import ModelRunner from processrunner import ProcessRunner + # before main is called: # frontend boots up, awaits user to upload a video # upload button triggers backend call to upload video to s3 # fetch the video from cloud and download to tmp/uploaded_video.mp4 # calls main + # load in configs from config.yaml # initialise modelrunner and processrunner # feed video into the yolo model, pass into modelrunner @@ -20,7 +22,7 @@ def load_config(path): """ TODO Loads the config yaml file to read in parameters and settings. """ - with open(path, 'r') as file: + with open(path, "r") as file: config = yaml.safe_load(file) return config @@ -31,21 +33,31 @@ def main(video_path): Input: Path of the user uploaded video. Returns: Results of the processing, in string format TODO change to csv/json? """ - config = load_config('config.yaml') - model_vars = config['model_vars'] + config = load_config("config.yaml") + model_vars = config["model_vars"] modelrunner = ModelRunner(video_path, model_vars) modelrunner.run() people_output, ball_output = modelrunner.fetch_output() - output_video_path = 'tmp/court_video.mp4' - output_video_path_reenc = 'tmp/court_video_reenc.mp4' + output_video_path = "tmp/court_video.mp4" + output_video_path_reenc = "tmp/court_video_reenc.mp4" - processrunner = ProcessRunner(video_path, people_output, ball_output, output_video_path, - output_video_path_reenc) + processrunner = ProcessRunner( + video_path, + people_output, + ball_output, + output_video_path, + output_video_path_reenc, + ) processrunner.run() results = processrunner.get_results() return results -if __name__ == '__main__': - main('data/training_data.mp4') +if __name__ == "__main__": + import sys + + if len(sys.argv) <= 1: + main("tmp/training_data.mp4") + else: + main(sys.argv[1]) # Pass the first command-line argument to the main function