diff --git a/Dockerfile b/Dockerfile index c069426..aed7ce0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04 +FROM nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04 RUN \ apt update && \ diff --git a/lib/main.py b/lib/main.py index 0a38d9d..275b0ac 100644 --- a/lib/main.py +++ b/lib/main.py @@ -10,9 +10,16 @@ import os from fastapi import Depends, FastAPI, UploadFile, responses -from nc_py_api import AsyncNextcloudApp, NextcloudApp -from nc_py_api.ex_app import LogLvl, anc_app, run_app, set_handlers, persistent_storage from faster_whisper import WhisperModel +from nc_py_api import AsyncNextcloudApp, NextcloudApp +from nc_py_api.ex_app import ( + anc_app, + get_computation_device, + LogLvl, + persistent_storage, + run_app, + set_handlers, +) def load_models(): @@ -28,7 +35,11 @@ def load_models(): return models def create_model_loader(file_path): - return lambda: WhisperModel(file_path, device="cpu") + device = get_computation_device().lower() + if device != "cuda": # other GPUs are currently not supported by Whisper + device = "cpu" + + return lambda: WhisperModel(file_path, device=device) models = load_models()