diff --git a/src/main.py b/src/main.py index a590d6ac..4d4c5eba 100644 --- a/src/main.py +++ b/src/main.py @@ -4,12 +4,14 @@ import yaml from modelrunner import ModelRunner from processrunner import ProcessRunner + # before main is called: # frontend boots up, awaits user to upload a video # upload button triggers backend call to upload video to s3 # fetch the video from cloud and download to tmp/uploaded_video.mp4 # calls main + # load in configs from config.yaml # initialise modelrunner and processrunner # feed video into the yolo model, pass into modelrunner @@ -20,7 +22,7 @@ def load_config(path): """ TODO Loads the config yaml file to read in parameters and settings. """ - with open(path, 'r') as file: + with open(path, "r") as file: config = yaml.safe_load(file) return config @@ -31,8 +33,8 @@ def main(video_path): Input: Path of the user uploaded video. Returns: Results of the processing, in string format TODO change to csv/json? """ - config = load_config('config.yaml') - model_vars = config['model_vars'] + config = load_config("config.yaml") + model_vars = config["model_vars"] modelrunner = ModelRunner(video_path, model_vars) modelrunner.run() @@ -48,5 +50,10 @@ def main(video_path): return results -if __name__ == '__main__': - main('data/training_data.mp4') +if __name__ == "__main__": + import sys + + if len(sys.argv) <= 1: + main("tmp/training_data.mp4") + else: + main(sys.argv[1]) # Pass the first command-line argument to the main function