diff --git a/.github/workflows/bump_deps.yaml b/.github/workflows/bump_deps.yaml
index 4344ac2b0e18..f8fd93d8caf1 100644
--- a/.github/workflows/bump_deps.yaml
+++ b/.github/workflows/bump_deps.yaml
@@ -12,6 +12,9 @@ jobs:
       - repository: "go-skynet/go-llama.cpp"
         variable: "GOLLAMA_VERSION"
         branch: "master"
+      - repository: "ggerganov/llama.cpp"
+        variable: "CPPLLAMA_VERSION"
+        branch: "master"
       - repository: "go-skynet/go-ggml-transformers.cpp"
         variable: "GOGGMLTRANSFORMERS_VERSION"
         branch: "master"
diff --git a/Makefile b/Makefile
index 8ad4c579d562..a2858093ff21 100644
--- a/Makefile
+++ b/Makefile
@@ -8,6 +8,8 @@
 GOLLAMA_VERSION?=1676dcd7a139b6cdfbaea5fd67f46dc25d9d8bcf
 GOLLAMA_STABLE_VERSION?=50cee7712066d9e38306eccadcfbb44ea87df4b7
 
+CPPLLAMA_VERSION?=24ba3d829e31a6eda3fa1723f692608c2fa3adda
+
 # gpt4all version
 GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
 GPT4ALL_VERSION?=27a8b020c36b0df8f8b82a252d261cda47cf44b8
@@ -120,7 +122,7 @@ ifeq ($(findstring tts,$(GO_TAGS)),tts)
 	OPTIONAL_GRPC+=backend-assets/grpc/piper
 endif
 
-GRPC_BACKENDS?=backend-assets/grpc/langchain-huggingface backend-assets/grpc/falcon-ggml backend-assets/grpc/bert-embeddings backend-assets/grpc/falcon backend-assets/grpc/bloomz backend-assets/grpc/llama backend-assets/grpc/llama-stable backend-assets/grpc/gpt4all backend-assets/grpc/dolly backend-assets/grpc/gpt2 backend-assets/grpc/gptj backend-assets/grpc/gptneox backend-assets/grpc/mpt backend-assets/grpc/replit backend-assets/grpc/starcoder backend-assets/grpc/rwkv backend-assets/grpc/whisper $(OPTIONAL_GRPC)
+GRPC_BACKENDS?=backend-assets/grpc/langchain-huggingface backend-assets/grpc/falcon-ggml backend-assets/grpc/bert-embeddings backend-assets/grpc/falcon backend-assets/grpc/bloomz backend-assets/grpc/llama backend-assets/grpc/llama-cpp backend-assets/grpc/llama-stable backend-assets/grpc/gpt4all backend-assets/grpc/dolly backend-assets/grpc/gpt2 backend-assets/grpc/gptj backend-assets/grpc/gptneox backend-assets/grpc/mpt backend-assets/grpc/replit backend-assets/grpc/starcoder backend-assets/grpc/rwkv backend-assets/grpc/whisper $(OPTIONAL_GRPC)
 
 .PHONY: all test build vendor
 
@@ -280,6 +282,7 @@ clean: ## Remove build related file
 	rm -rf ./go-ggllm
 	rm -rf $(BINARY_NAME)
 	rm -rf release/
+	$(MAKE) -C backend/cpp/llama clean
 
 ## Build:
 
@@ -395,6 +398,16 @@ ifeq ($(BUILD_TYPE),metal)
 	cp go-llama/build/bin/ggml-metal.metal backend-assets/grpc/
 endif
 
+backend/cpp/llama/grpc-server:
+	LLAMA_VERSION=$(CPPLLAMA_VERSION) $(MAKE) -C backend/cpp/llama grpc-server
+
+backend-assets/grpc/llama-cpp: backend-assets/grpc backend/cpp/llama/grpc-server
+	cp -rfv backend/cpp/llama/grpc-server backend-assets/grpc/llama-cpp
+# TODO: every binary should have its own folder instead, so can have different metal implementations
+ifeq ($(BUILD_TYPE),metal)
+	cp backend/cpp/llama/llama.cpp/build/bin/ggml-metal.metal backend-assets/grpc/
+endif
+
 backend-assets/grpc/llama-stable: backend-assets/grpc go-llama-stable/libbinding.a
 	$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama-stable
 	CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-llama-stable LIBRARY_PATH=$(shell pwd)/go-llama \
diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp
index d719e29e8d41..16b63affb5bf 100644
--- a/backend/cpp/llama/grpc-server.cpp
+++ b/backend/cpp/llama/grpc-server.cpp
@@ -1,16 +1,22 @@
+// llama.cpp gRPC C++ backend server
+//
+// Ettore Di Giacinto
+//
+// This is a gRPC server for llama.cpp compatible with the LocalAI proto
+// Note: this is a re-adaptation of the original llama.cpp example/server.cpp for HTTP,
+// but modified to work with gRPC
+//
 #include <iostream>
 #include <memory>
 #include <string>
+#include <getopt.h>
 #include "common.h"
 #include "llama.h"
 #include "grammar-parser.h"
 #include "backend.pb.h"
 #include "backend.grpc.pb.h"
-// #include "absl/flags/flag.h"
-// #include "absl/flags/parse.h"
-// #include "absl/strings/str_format.h"
 
 // include std::regex
 #include <regex>
@@ -94,17 +100,21 @@ static std::string tokens_to_str(llama_context *ctx, Iter begin, Iter end)
 }
 
-#define LOG_VERBOSE(MSG, ...) \
-    do \
-    { \
-        if (server_verbose) \
-        { \
-            printf("VERBOSE", __func__, __LINE__, MSG, __VA_ARGS__); \
-        } \
-    } while (0)
-#define LOG_ERROR(MSG, ...) printf("ERROR", __func__, __LINE__, MSG, __VA_ARGS__)
-#define LOG_WARNING(MSG, ...) printf("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
-#define LOG_INFO(MSG, ...) printf("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
+// format incomplete utf-8 multibyte character for output
+static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
+{
+    std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token);
+    // if the size is 1 and first bit is 1, meaning it's a partial character
+    // (size > 1 meaning it's already a known token)
+    if (out.size() == 1 && (out[0] & 0x80) == 0x80)
+    {
+        std::stringstream ss;
+        ss << std::hex << (out[0] & 0xff);
+        std::string res(ss.str());
+        out = "byte: \\x" + res;
+    }
+    return out;
+}
 
 struct llama_server_context
 {
@@ -184,11 +194,13 @@ struct llama_server_context
 
     bool loadModel(const gpt_params &params_)
     {
+        printf("load model %s\n", params_.model.c_str());
+
         params = params_;
         std::tie(model, ctx) = llama_init_from_gpt_params(params);
         if (model == nullptr)
         {
-            LOG_ERROR("unable to load model", params_.model);
+            printf("unable to load model %s\n", params_.model.c_str());
             return false;
         }
         n_ctx = llama_n_ctx(ctx);
@@ -244,7 +256,7 @@ struct llama_server_context
         parsed_grammar = grammar_parser::parse(params.grammar.c_str());
         // will be empty (default) if there are parse errors
         if (parsed_grammar.rules.empty()) {
-            LOG_ERROR("grammar parse error", params.grammar);
+            printf("grammar parse error");
             return false;
         }
         grammar_parser::print_grammar(stderr, parsed_grammar);
@@ -252,7 +264,7 @@
         {
             auto it = params.logit_bias.find(llama_token_eos(ctx));
             if (it != params.logit_bias.end() && it->second == -INFINITY) {
-                LOG_WARNING("EOS token is disabled, which will cause most grammars to fail","");
+                printf("EOS token is disabled, which will cause most grammars to fail");
             }
         }
 
@@ -563,7 +575,6 @@ struct llama_server_context
             stopped_limit = true;
         }
 
-
         return token_with_probs;
     }
 
@@ -572,7 +583,7 @@ struct llama_server_context
         static const int n_embd = llama_n_embd(model);
         if (!params.embedding)
         {
-            LOG_WARNING("embedding disabled", "");
+            printf("embedding disabled");
             return std::vector<float>(n_embd, 0.0f);
         }
         const float *data = llama_get_embeddings(ctx);
@@ -587,7 +598,7 @@ static void parse_options_completion(bool streaming,const backend::PredictOption
     gpt_params default_params;
 
     llama.stream = streaming;
-    llama.params.n_predict = predict->tokens();
+    llama.params.n_predict = predict->tokens() == 0 ? -1 : predict->tokens();
     llama.params.top_k = predict->topk();
     llama.params.top_p = predict->topp();
     llama.params.tfs_z = predict->tailfreesamplingz();
@@ -652,13 +663,7 @@ static void parse_options_completion(bool streaming,const backend::PredictOption
 static void params_parse(const backend::ModelOptions* request,
                                 gpt_params & params) {
-
-    std::string arg;
-    bool invalid_param = false;
-
     params.model = request->modelfile();
-    // params.model_alias ??
     params.model_alias =  request->modelfile();
     params.n_ctx = request->contextsize();
@@ -666,24 +671,30 @@
     params.n_threads = request->threads();
     params.n_gpu_layers = request->ngpulayers();
     params.n_batch = request->nbatch();
-
-    std::string arg_next = request->tensorsplit();
-    // split string by , and /
-    const std::regex regex{ R"([,/]+)" };
-    std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 };
-    std::vector<std::string> split_arg{ it, {} };
-    GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
-
-    for (size_t i_device = 0; i_device < LLAMA_MAX_DEVICES; ++i_device) {
-        if (i_device < split_arg.size()) {
-            params.tensor_split[i_device] = std::stof(split_arg[i_device]);
-        }
-        else {
-            params.tensor_split[i_device] = 0.0f;
+    if (!request->tensorsplit().empty()) {
+        std::string arg_next = request->tensorsplit();
+
+        // split string by , and /
+        const std::regex regex{ R"([,/]+)" };
+        std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 };
+        std::vector<std::string> split_arg{ it, {} };
+
+        GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
+
+        for (size_t i_device = 0; i_device < LLAMA_MAX_DEVICES; ++i_device) {
+            if (i_device < split_arg.size()) {
+                params.tensor_split[i_device] = std::stof(split_arg[i_device]);
+            }
+            else {
+                params.tensor_split[i_device] = 0.0f;
+            }
         }
     }
-    params.main_gpu = std::stoi(request->maingpu());
+
+    if (!request->maingpu().empty()) {
+        params.main_gpu = std::stoi(request->maingpu());
+    }
     // TODO: lora needs also a scale factor
     //params.lora_adapter = request->loraadapter();
     //params.lora_base = request->lorabase();
@@ -749,37 +760,33 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
     }
 }
 
+// GRPC Server start
 class BackendServiceImpl final : public backend::Backend::Service {
   // The class has a llama instance that is shared across all RPCs
   llama_server_context llama;
 public:
   grpc::Status Health(ServerContext* context, const backend::HealthMessage* request, backend::Reply* reply) {
     // Implement Health RPC
+    reply->set_message("OK");
     return Status::OK;
   }
 
   grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) {
     // Implement LoadModel RPC
-
     gpt_params params;
-
-    // struct that contains llama context and inference
-    // llama_server_context llama;
-
     params_parse(request, params);
-
     llama_backend_init(params.numa);
-
     // load the model
     if (!llama.loadModel(params))
     {
-      // result->set_message(backend::Result::ERROR);
-
+      result->set_message("Failed loading model");
+      result->set_success(false);
      return Status::CANCELLED;
    }
-
+    result->set_message("Loading succeeded");
+    result->set_success(true);
    return Status::OK;
  }
 
  grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
@@ -802,79 +809,74 @@
     llama.loadPrompt(request->prompt());
     llama.beginCompletion();
 
-            size_t sent_count = 0;
-            size_t sent_token_probs_index = 0;
-
-            while (llama.has_next_token) {
-                const completion_token_output token_with_probs = llama.doCompletion();
-                if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
-                    continue;
-                }
-                const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok);
-
-                size_t pos = std::min(sent_count, llama.generated_text.size());
-
-                const std::string str_test = llama.generated_text.substr(pos);
-                bool is_stop_full = false;
-                size_t stop_pos =
-                    llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
-                if (stop_pos != std::string::npos) {
-                    is_stop_full = true;
-                    llama.generated_text.erase(
-                        llama.generated_text.begin() + pos + stop_pos,
-                        llama.generated_text.end());
-                    pos = std::min(sent_count, llama.generated_text.size());
-                } else {
-                    is_stop_full = false;
-                    stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
-                        STOP_PARTIAL);
-                }
-
-                if (
-                    stop_pos == std::string::npos ||
-                    // Send rest of the text if we are at the end of the generation
-                    (!llama.has_next_token && !is_stop_full && stop_pos > 0)
-                ) {
-                    const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
-
-                    sent_count += to_send.size();
-
-                    std::vector<completion_token_output> probs_output = {};
-
-                    if (llama.params.n_probs > 0) {
-                        const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
-                        size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
-                        size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
-                        if (probs_pos < probs_stop_pos) {
-                            probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
-                        }
-                        sent_token_probs_index = probs_stop_pos;
-                    }
-                    backend::Reply reply;
-                    reply.set_message(to_send);
-
-                    // Send the reply
-                    writer->Write(reply);
-                }
-            }
-
-            llama_print_timings(llama.ctx);
-
-
-
-            llama.mutex.unlock();
-            lock.release();
+    size_t sent_count = 0;
+    size_t sent_token_probs_index = 0;
+
+    while (llama.has_next_token) {
+        const completion_token_output token_with_probs = llama.doCompletion();
+        if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
+            continue;
+        }
+        const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok);
+
+        size_t pos = std::min(sent_count, llama.generated_text.size());
+
+        const std::string str_test = llama.generated_text.substr(pos);
+        bool is_stop_full = false;
+        size_t stop_pos =
+            llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
+        if (stop_pos != std::string::npos) {
+            is_stop_full = true;
+            llama.generated_text.erase(
+                llama.generated_text.begin() + pos + stop_pos,
+                llama.generated_text.end());
+            pos = std::min(sent_count, llama.generated_text.size());
+        } else {
+            is_stop_full = false;
+            stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
+                STOP_PARTIAL);
+        }
+
+        if (
+            stop_pos == std::string::npos ||
+            // Send rest of the text if we are at the end of the generation
+            (!llama.has_next_token && !is_stop_full && stop_pos > 0)
+        ) {
+            const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
+
+            sent_count += to_send.size();
+
+            std::vector<completion_token_output> probs_output = {};
+
+            if (llama.params.n_probs > 0) {
+                const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
+                size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
+                size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
+                if (probs_pos < probs_stop_pos) {
+                    probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
+                }
+                sent_token_probs_index = probs_stop_pos;
+            }
+            backend::Reply reply;
+            reply.set_message(to_send);
+
+            // Send the reply
+            writer->Write(reply);
+        }
+    }
+    llama_print_timings(llama.ctx);
+
+    llama.mutex.unlock();
+    lock.release();
     return grpc::Status::OK;
-    }
-    grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
+  }
 
-        auto lock = llama.lock();
+  grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
+    auto lock = llama.lock();
     llama.rewind();
-    llama_reset_timings(llama.ctx);
-
     parse_options_completion(false, request, llama);
 
     if (!llama.loadGrammar())
@@ -886,159 +888,43 @@ class BackendServiceImpl final : public backend::Backend::Service {
     llama.loadPrompt(request->prompt());
     llama.beginCompletion();
 
-        // if (!llama.stream) {
-            if (llama.params.n_beams) {
-                // Fill llama.generated_token_probs vector with final beam.
-                llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
-                                llama.n_past, llama.n_remain);
-                // Translate llama.generated_token_probs to llama.generated_text.
-                append_to_generated_text_from_generated_token_probs(llama);
-            } else {
-                size_t stop_pos = std::string::npos;
-
-                while (llama.has_next_token) {
-                    const completion_token_output token_with_probs = llama.doCompletion();
-                    const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(llama.ctx, token_with_probs.tok);
-
-                    stop_pos = llama.findStoppingStrings(llama.generated_text,
-                        token_text.size(), STOP_FULL);
-                }
-
-                if (stop_pos == std::string::npos) {
-                    stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
-                }
-                if (stop_pos != std::string::npos) {
-                    llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
-                        llama.generated_text.end());
-                }
-            }
-
-            auto probs = llama.generated_token_probs;
-            if (llama.params.n_probs > 0 && llama.stopped_word) {
-                const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
-                probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
-            }
-
-            // const json data = format_final_response(llama, llama.generated_text, probs);
+    if (llama.params.n_beams) {
+        // Fill llama.generated_token_probs vector with final beam.
+        llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
+                          llama.n_past, llama.n_remain);
+        // Translate llama.generated_token_probs to llama.generated_text.
+        append_to_generated_text_from_generated_token_probs(llama);
+    } else {
+        size_t stop_pos = std::string::npos;
+
+        while (llama.has_next_token) {
+            const completion_token_output token_with_probs = llama.doCompletion();
+            const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(llama.ctx, token_with_probs.tok);
+
+            stop_pos = llama.findStoppingStrings(llama.generated_text,
+                token_text.size(), STOP_FULL);
+        }
+
+        if (stop_pos == std::string::npos) {
+            stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
+        }
+        if (stop_pos != std::string::npos) {
+            llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
+                llama.generated_text.end());
+        }
+    }
+
+    auto probs = llama.generated_token_probs;
+    if (llama.params.n_probs > 0 && llama.stopped_word) {
+        const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
+        probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
+    }
 
     reply->set_message(llama.generated_text);
     return grpc::Status::OK;
-
-        //     llama_print_timings(llama.ctx);
-
-        //     res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace),
-        //                     "application/json");
-        // } else {
-        //     const auto chunked_content_provider = [&](size_t, DataSink & sink) {
-        //         size_t sent_count = 0;
-        //         size_t sent_token_probs_index = 0;
-
-        //         while (llama.has_next_token) {
-        //             const completion_token_output token_with_probs = llama.doCompletion();
-        //             if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
-        //                 continue;
-        //             }
-        //             const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok);
-
-        //             size_t pos = std::min(sent_count, llama.generated_text.size());
-
-        //             const std::string str_test = llama.generated_text.substr(pos);
-        //             bool is_stop_full = false;
-        //             size_t stop_pos =
-        //                 llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
-        //             if (stop_pos != std::string::npos) {
-        //                 is_stop_full = true;
-        //                 llama.generated_text.erase(
-        //                     llama.generated_text.begin() + pos + stop_pos,
-        //                     llama.generated_text.end());
-        //                 pos = std::min(sent_count, llama.generated_text.size());
-        //             } else {
-        //                 is_stop_full = false;
-        //                 stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
-        //                     STOP_PARTIAL);
-        //             }
-
-        //             if (
-        //                 stop_pos == std::string::npos ||
-        //                 // Send rest of the text if we are at the end of the generation
-        //                 (!llama.has_next_token && !is_stop_full && stop_pos > 0)
-        //             ) {
-        //                 const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
-
-        //                 sent_count += to_send.size();
-
-        //                 std::vector<completion_token_output> probs_output = {};
-
-        //                 if (llama.params.n_probs > 0) {
-        //                     const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
-        //                     size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
-        //                     size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
-        //                     if (probs_pos < probs_stop_pos) {
-        //                         probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
-        //                     }
-        //                     sent_token_probs_index = probs_stop_pos;
-        //                 }
-
-        //                 const json data = format_partial_response(llama, to_send, probs_output);
-
-        //                 const std::string str =
-        //                     "data: " +
-        //                     data.dump(-1, ' ', false, json::error_handler_t::replace) +
-        //                     "\n\n";
-
-        //                 LOG_VERBOSE("data stream", {
-        //                     { "to_send", str }
-        //                 });
-
-        //                 if (!sink.write(str.data(), str.size())) {
-        //                     LOG_VERBOSE("stream closed", {});
-        //                     llama_print_timings(llama.ctx);
-        //                     return false;
-        //                 }
-        //             }
-
-        //             if (!llama.has_next_token) {
-        //                 // Generation is done, send extra information.
-        //                 const json data = format_final_response(
-        //                     llama,
-        //                     "",
-        //                     std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index)
-        //                 );
-
-        //                 const std::string str =
-        //                     "data: " +
-        //                     data.dump(-1, ' ', false, json::error_handler_t::replace) +
-        //                     "\n\n";
-
-        //                 LOG_VERBOSE("data stream", {
-        //                     { "to_send", str }
-        //                 });
-
-        //                 if (!sink.write(str.data(), str.size())) {
-        //                     LOG_VERBOSE("stream closed", {});
-        //                     llama_print_timings(llama.ctx);
-        //                     return false;
-        //                 }
-        //             }
-        //         }
-
-        //         llama_print_timings(llama.ctx);
-        //         sink.done();
-        //         return true;
-        //     };
-        //     const auto on_complete = [&](bool) {
-        //         llama.mutex.unlock();
-        //     };
-        //     lock.release();
-        //     res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
-        // }
-        // Implement Predict RPC
-        // return Status::OK;
-    }
-
-    // Implement other RPCs following the same pattern as above
+  }
 };
 
-void RunServer() {
-  std::string server_address("0.0.0.0:50051");
+void RunServer(const std::string& server_address) {
   BackendServiceImpl service;
 
   ServerBuilder builder;
@@ -1051,6 +937,28 @@ void RunServer() {
 }
 
 int main(int argc, char** argv) {
-  RunServer();
+  std::string server_address("localhost:50051");
+
+  // Define long and short options
+  struct option long_options[] = {
+      {"addr", required_argument, nullptr, 'a'},
+      {nullptr, 0, nullptr, 0}
+  };
+
+  // Parse command-line arguments
+  int option;
+  int option_index = 0;
+  while ((option = getopt_long(argc, argv, "a:", long_options, &option_index)) != -1) {
+    switch (option) {
+      case 'a':
+        server_address = optarg;
+        break;
+      default:
+        std::cerr << "Usage: " << argv[0] << " [--addr=<address>] or [-a <address>]" << std::endl;
+        return 1;
+    }
+  }
+
+  RunServer(server_address);
   return 0;
 }
diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go
index 7773eb1e8df0..5ad9500ba148 100644
--- a/pkg/model/initializers.go
+++ b/pkg/model/initializers.go
@@ -17,6 +17,7 @@ import (
 const (
 	LlamaBackend       = "llama"
 	LlamaStableBackend = "llama-stable"
+	LLamaCPP           = "llama-cpp"
 	BloomzBackend      = "bloomz"
 	StarcoderBackend   = "starcoder"
 	GPTJBackend        = "gptj"
@@ -41,8 +42,9 @@ const (
 )
 
 var AutoLoadBackends []string = []string{
-	LlamaBackend,
+	LLamaCPP,
 	LlamaStableBackend,
+	LlamaBackend,
 	Gpt4All,
 	FalconBackend,
 	GPTNeoXBackend,
@@ -175,11 +177,6 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
 	}
 
 	switch backend {
-	case LlamaBackend, LlamaStableBackend, GPTJBackend, DollyBackend,
-		MPTBackend, Gpt2Backend, FalconBackend,
-		GPTNeoXBackend, ReplitBackend, StarcoderBackend, BloomzBackend,
-		RwkvBackend, LCHuggingFaceBackend, BertEmbeddingsBackend, FalconGGMLBackend, StableDiffusionBackend, WhisperBackend:
-		return ml.LoadModel(o.model, ml.grpcModel(backend, o))
 	case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All:
 		o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "gpt4all")
 		return ml.LoadModel(o.model, ml.grpcModel(Gpt4All, o))
@@ -187,7 +184,7 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
 		o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "espeak-ng-data")
 		return ml.LoadModel(o.model, ml.grpcModel(PiperBackend, o))
 	default:
-		return nil, fmt.Errorf("backend unsupported: %s", o.backendString)
+		return ml.LoadModel(o.model, ml.grpcModel(backend, o))
 	}