Skip to content

Commit

Permalink
feat: Add param for function
Browse files Browse the repository at this point in the history
  • Loading branch information
yakhyo committed Sep 2, 2024
1 parent 6310bbf commit 6d701cd
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def main(params):
logging.info(f"Exception occured while loading pre-trained weights of face detection model. Exception: {e}")

try:
gaze_detector = get_model(params.arch, params.bins)
gaze_detector = get_model(params.arch, params.bins, inference_mode=True)
state_dict = torch.load(params.gaze_weights, map_location=device)
gaze_detector.load_state_dict(state_dict)
logging.info("Gaze Estimation model weights loaded.")
Expand Down
2 changes: 1 addition & 1 deletion evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

model = get_model(params.arch, params.bins)
model = get_model(params.arch, params.bins, inference_mode=True)

if os.path.exists(params.weights):
model.load_state_dict(torch.load(params.weights, map_location=device, weights_only=True))
Expand Down

0 comments on commit 6d701cd

Please sign in to comment.