diff --git a/plugins/wasi_nn/wasinnenv.cpp b/plugins/wasi_nn/wasinnenv.cpp
index 96cd471af103..5e0211cf8fa5 100644
--- a/plugins/wasi_nn/wasinnenv.cpp
+++ b/plugins/wasi_nn/wasinnenv.cpp
@@ -28,7 +28,7 @@ std::map<std::string_view, Backend> BackendMap = {
     {"tensorflowlite"sv, Backend::TensorflowLite},
     {"autodetect"sv, Backend::Autodetect},
     {"ggml"sv, Backend::GGML},
-    {"whisper"sv, Backend::GGML}};
+    {"whisper"sv, Backend::WHISPER}};
 
 std::map<std::string_view, Device> DeviceMap = {{"cpu"sv, Device::CPU},
                                                 {"gpu"sv, Device::GPU},
diff --git a/plugins/wasi_nn/whispercpp.cpp b/plugins/wasi_nn/whispercpp.cpp
index b8f3769a4b9b..2c6e2fd7aeae 100644
--- a/plugins/wasi_nn/whispercpp.cpp
+++ b/plugins/wasi_nn/whispercpp.cpp
@@ -20,7 +20,51 @@ Expect<ErrNo> load([[maybe_unused]] WasiNNEnvironment &Env,
                    [[maybe_unused]] Span<const Span<uint8_t>> Builders,
                    [[maybe_unused]] Device Device,
                    [[maybe_unused]] uint32_t &GraphId) noexcept {
-  // Env.NNGraph.emplace_back(Backend::WHISPER);
+  Env.NNGraph.emplace_back(Backend::WHISPER);
+  auto &GraphRef = Env.NNGraph.back().get();
+  struct whisper_context_params ContextDefault =
+      whisper_context_default_params(); // from whisper.cpp
+  GraphRef.EnableLog = false;
+  GraphRef.EnableDebugLog = false;
+  GraphRef.StreamStdout = false;
+  auto Weight = Builders[0];
+  const std::string BinModel(reinterpret_cast<const char *>(Weight.data()),
+                             Weight.size());
+
+  // Handle the model path.
+  std::string ModelFilePath;
+  if (BinModel.substr(0, 8) == "preload:") {
+    ModelFilePath = BinModel.substr(8);
+  } else {
+    // Write the whisper model to a temporary file.
+    ModelFilePath = "models/ggml-base.en.bin"sv;
+    std::ofstream TempFile(ModelFilePath, std::ios::binary);
+    if (!TempFile) {
+      spdlog::error(
+          "[WASI-NN] Whisper backend: Failed to create the temporary file.");
+      Env.NNGraph.pop_back();
+      return ErrNo::InvalidArgument;
+    }
+    TempFile << BinModel;
+    TempFile.close();
+  }
+
+  // Initialize the whisper context from the model file. whisper_context is an
+  // opaque type whose definition lives in whisper.cpp, not in the header, so
+  // only the returned pointer is kept here.
+  GraphRef.WhisperModel = whisper_init_from_file_with_params(
+      ModelFilePath.c_str(), ContextDefault);
+  if (GraphRef.WhisperModel == nullptr) {
+    spdlog::error("[WASI-NN] Whisper backend: Error: unable to init model."sv);
+    Env.NNGraph.pop_back();
+    return ErrNo::InvalidArgument;
+  }
+  if (GraphRef.EnableDebugLog) {
+    spdlog::info(
+        "[WASI-NN][Debug] Whisper backend: Initialize whisper model with given parameters...Done"sv);
+  }
+  // Store the loaded graph.
+  GraphId = Env.NNGraph.size() - 1;
   return ErrNo::Success;
 }
 
@@ -68,7 +112,7 @@ finiSingle([[maybe_unused]] WASINN::WasiNNEnvironment &Env,
 namespace {
 Expect<ErrNo> reportBackendNotSupported() noexcept {
   spdlog::error("[WASI-NN] whisper backend is not built. use "
-                "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"whisper\" to build it."sv);
+                "-WASMEDGE_PLUGIN_WASI_NN_BACKEND=\"WHISPER\" to build it."sv);
   return ErrNo::InvalidArgument;
 }
 } // namespace
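For context, a minimal sketch of how a caller might prepare the single builder that the `load()` implementation above consumes: either a `preload:`-prefixed path to a model already on disk, or the raw GGML model bytes that the backend writes to a temporary file. The helper names (`makePreloadBuilder`, `makeInlineBuilder`) are illustrative assumptions, not part of the WasmEdge API.

```cpp
#include <cstdint>
#include <fstream>
#include <iterator>
#include <string>
#include <vector>

// Option 1: reference a model file already on disk. The "preload:" prefix
// makes the whisper backend skip writing a temporary copy.
std::vector<uint8_t> makePreloadBuilder(const std::string &Path) {
  const std::string Tagged = "preload:" + Path;
  return std::vector<uint8_t>(Tagged.begin(), Tagged.end());
}

// Option 2: pass the raw model bytes; the backend writes them out and then
// calls whisper_init_from_file_with_params() on the temporary file.
std::vector<uint8_t> makeInlineBuilder(const std::string &Path) {
  std::ifstream File(Path, std::ios::binary);
  return std::vector<uint8_t>(std::istreambuf_iterator<char>(File),
                              std::istreambuf_iterator<char>());
}
```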