
Commit

Merge pull request #8 from nextcloud/feat/compute-device
feat: detect computation device for Whisper model
marcelklehr authored Oct 30, 2024
2 parents 13f70af + 3495e5c commit 4ca8f89
Showing 2 changed files with 15 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -1,4 +1,4 @@
-FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
+FROM nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04

 RUN \
 apt update && \
17 changes: 14 additions & 3 deletions lib/main.py
@@ -10,9 +10,16 @@
 import os

 from fastapi import Depends, FastAPI, UploadFile, responses
-from nc_py_api import AsyncNextcloudApp, NextcloudApp
-from nc_py_api.ex_app import LogLvl, anc_app, run_app, set_handlers, persistent_storage
 from faster_whisper import WhisperModel
+from nc_py_api import AsyncNextcloudApp, NextcloudApp
+from nc_py_api.ex_app import (
+    anc_app,
+    get_computation_device,
+    LogLvl,
+    persistent_storage,
+    run_app,
+    set_handlers,
+)


 def load_models():
@@ -28,7 +35,11 @@ def load_models():
     return models

 def create_model_loader(file_path):
-    return lambda: WhisperModel(file_path, device="cpu")
+    device = get_computation_device().lower()
+    if device != "cuda": # other GPUs are currently not supported by Whisper
+        device = "cpu"
+
+    return lambda: WhisperModel(file_path, device=device)


 models = load_models()
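For illustration, a minimal standalone sketch of the device-selection pattern this commit introduces: ask nc_py_api's ex-app helper which computation device is configured and fall back to CPU for anything that is not a CUDA GPU, since faster-whisper supports no other accelerators. The sketch assumes it runs inside a Nextcloud ExApp environment (so get_computation_device() can report a device); the model path and the eager load at the end are hypothetical, whereas the real app builds such loaders inside load_models().

from faster_whisper import WhisperModel
from nc_py_api.ex_app import get_computation_device


def pick_device():
    # get_computation_device() reports the accelerator configured for this
    # ExApp (e.g. "CUDA"); anything that is not a CUDA GPU falls back to CPU,
    # mirroring the check in create_model_loader().
    device = get_computation_device().lower()
    return device if device == "cuda" else "cpu"


# Hypothetical model path, used only for illustration.
model_loader = lambda: WhisperModel("/models/faster-whisper-small", device=pick_device())

# The WhisperModel is only instantiated when the loader is actually called.
model = model_loader()

Deferring WhisperModel construction behind a lambda keeps app startup cheap: the model weights are only loaded onto the selected device when a loader is first invoked.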
