From 6d701cde6f4f150b751d0209f30596a96883db28 Mon Sep 17 00:00:00 2001 From: yakhyo Date: Mon, 2 Sep 2024 09:07:44 +0000 Subject: [PATCH] feat: Add param for function --- detect.py | 2 +- evaluate.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/detect.py b/detect.py index 01efab5..5211aaf 100644 --- a/detect.py +++ b/detect.py @@ -77,7 +77,7 @@ def main(params): logging.info(f"Exception occured while loading pre-trained weights of face detection model. Exception: {e}") try: - gaze_detector = get_model(params.arch, params.bins) + gaze_detector = get_model(params.arch, params.bins, inference_mode=True) state_dict = torch.load(params.gaze_weights, map_location=device) gaze_detector.load_state_dict(state_dict) logging.info("Gaze Estimation model weights loaded.") diff --git a/evaluate.py b/evaluate.py index abe12e3..064832a 100644 --- a/evaluate.py +++ b/evaluate.py @@ -101,7 +101,7 @@ def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch.backends.cudnn.benchmark = True - model = get_model(params.arch, params.bins) + model = get_model(params.arch, params.bins, inference_mode=True) if os.path.exists(params.weights): model.load_state_dict(torch.load(params.weights, map_location=device, weights_only=True))