diff --git a/core/backend/embeddings.go b/core/backend/embeddings.go index 264d947b9913..a96e9829af16 100644 --- a/core/backend/embeddings.go +++ b/core/backend/embeddings.go @@ -11,17 +11,9 @@ import ( func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { - var inferenceModel interface{} - var err error + opts := ModelOptions(backendConfig, appConfig) - opts := ModelOptions(backendConfig, appConfig, []model.Option{}) - - if backendConfig.Backend == "" { - inferenceModel, err = loader.GreedyLoader(opts...) - } else { - opts = append(opts, model.WithBackendString(backendConfig.Backend)) - inferenceModel, err = loader.BackendLoader(opts...) - } + inferenceModel, err := loader.Load(opts...) if err != nil { return nil, err } diff --git a/core/backend/image.go b/core/backend/image.go index 72c0007c5842..38ca43570fe8 100644 --- a/core/backend/image.go +++ b/core/backend/image.go @@ -9,9 +9,8 @@ import ( func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() error, error) { - opts := ModelOptions(backendConfig, appConfig, []model.Option{}) - - inferenceModel, err := loader.BackendLoader( + opts := ModelOptions(backendConfig, appConfig) + inferenceModel, err := loader.Load( opts..., ) if err != nil { diff --git a/core/backend/llm.go b/core/backend/llm.go index 199a62338c84..4491a191eeb4 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -16,7 +16,6 @@ import ( "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/gallery" - "github.com/mudler/LocalAI/pkg/grpc" "github.com/mudler/LocalAI/pkg/grpc/proto" model "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/utils" @@ -35,15 +34,6 @@ type TokenUsage struct { func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { modelFile := c.Model - var inferenceModel grpc.Backend - var err error - - opts := ModelOptions(c, o, []model.Option{}) - - if c.Backend != "" { - opts = append(opts, model.WithBackendString(c.Backend)) - } - // Check if the modelFile exists, if it doesn't try to load it from the gallery if o.AutoloadGalleries { // experimental if _, err := os.Stat(modelFile); os.IsNotExist(err) { @@ -56,12 +46,8 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im } } - if c.Backend == "" { - inferenceModel, err = loader.GreedyLoader(opts...) - } else { - inferenceModel, err = loader.BackendLoader(opts...) - } - + opts := ModelOptions(c, o) + inferenceModel, err := loader.Load(opts...) if err != nil { return nil, err } diff --git a/core/backend/options.go b/core/backend/options.go index 6586eccf13fd..c65912222a58 100644 --- a/core/backend/options.go +++ b/core/backend/options.go @@ -11,7 +11,7 @@ import ( "github.com/rs/zerolog/log" ) -func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts []model.Option) []model.Option { +func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option { name := c.Name if name == "" { name = c.Model diff --git a/core/backend/rerank.go b/core/backend/rerank.go index f600e2e6eaff..8152ef7fc357 100644 --- a/core/backend/rerank.go +++ b/core/backend/rerank.go @@ -11,8 +11,8 @@ import ( func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) { - opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)}) - rerankModel, err := loader.BackendLoader(opts...) + opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile)) + rerankModel, err := loader.Load(opts...) if err != nil { return nil, err } diff --git a/core/backend/soundgeneration.go b/core/backend/soundgeneration.go index b1b458b447ab..a8d46478c7cd 100644 --- a/core/backend/soundgeneration.go +++ b/core/backend/soundgeneration.go @@ -25,9 +25,8 @@ func SoundGeneration( backendConfig config.BackendConfig, ) (string, *proto.Result, error) { - opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)}) - - soundGenModel, err := loader.BackendLoader(opts...) + opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile)) + soundGenModel, err := loader.Load(opts...) if err != nil { return "", nil, err } diff --git a/core/backend/stores.go b/core/backend/stores.go index 1b514584cbeb..f5ee9166df8b 100644 --- a/core/backend/stores.go +++ b/core/backend/stores.go @@ -8,16 +8,15 @@ import ( ) func StoreBackend(sl *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string) (grpc.Backend, error) { - if storeName == "" { - storeName = "default" - } + if storeName == "" { + storeName = "default" + } - sc := []model.Option{ - model.WithBackendString(model.LocalStoreBackend), - model.WithAssetDir(appConfig.AssetsDestination), - model.WithModel(storeName), - } + sc := []model.Option{ + model.WithBackendString(model.LocalStoreBackend), + model.WithAssetDir(appConfig.AssetsDestination), + model.WithModel(storeName), + } - return sl.BackendLoader(sc...) + return sl.Load(sc...) } - diff --git a/core/backend/token_metrics.go b/core/backend/token_metrics.go index acd256634a0a..cc71c8681e54 100644 --- a/core/backend/token_metrics.go +++ b/core/backend/token_metrics.go @@ -15,10 +15,8 @@ func TokenMetrics( appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.MetricsResponse, error) { - opts := ModelOptions(backendConfig, appConfig, []model.Option{ - model.WithModel(modelFile), - }) - model, err := loader.BackendLoader(opts...) + opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile)) + model, err := loader.Load(opts...) if err != nil { return nil, err } diff --git a/core/backend/tokenize.go b/core/backend/tokenize.go index c8ec8d1cb260..2f813e18736b 100644 --- a/core/backend/tokenize.go +++ b/core/backend/tokenize.go @@ -14,15 +14,13 @@ func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.Bac var inferenceModel grpc.Backend var err error - opts := ModelOptions(backendConfig, appConfig, []model.Option{ - model.WithModel(modelFile), - }) + opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile)) if backendConfig.Backend == "" { - inferenceModel, err = loader.GreedyLoader(opts...) + inferenceModel, err = loader.Load(opts...) } else { opts = append(opts, model.WithBackendString(backendConfig.Backend)) - inferenceModel, err = loader.BackendLoader(opts...) + inferenceModel, err = loader.Load(opts...) } if err != nil { return schema.TokenizeResponse{}, err diff --git a/core/backend/transcript.go b/core/backend/transcript.go index c6ad9b597795..372f6984237c 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -18,9 +18,9 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL backendConfig.Backend = model.WhisperBackend } - opts := ModelOptions(backendConfig, appConfig, []model.Option{}) + opts := ModelOptions(backendConfig, appConfig) - transcriptionModel, err := ml.BackendLoader(opts...) + transcriptionModel, err := ml.Load(opts...) if err != nil { return nil, err } diff --git a/core/backend/tts.go b/core/backend/tts.go index 20aa358e7257..f9be6955bcd6 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -28,12 +28,8 @@ func ModelTTS( bb = model.PiperBackend } - opts := ModelOptions(backendConfig, appConfig, []model.Option{ - model.WithBackendString(bb), - model.WithModel(modelFile), - }) - - ttsModel, err := loader.BackendLoader(opts...) + opts := ModelOptions(backendConfig, appConfig, model.WithBackendString(bb), model.WithModel(modelFile)) + ttsModel, err := loader.Load(opts...) if err != nil { return "", nil, err } diff --git a/core/startup/startup.go b/core/startup/startup.go index 17e54bc0603b..0eb5fa585585 100644 --- a/core/startup/startup.go +++ b/core/startup/startup.go @@ -160,15 +160,10 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode log.Debug().Msgf("Auto loading model %s into memory from file: %s", m, cfg.Model) - o := backend.ModelOptions(*cfg, options, []model.Option{}) + o := backend.ModelOptions(*cfg, options) var backendErr error - if cfg.Backend != "" { - o = append(o, model.WithBackendString(cfg.Backend)) - _, backendErr = ml.BackendLoader(o...) - } else { - _, backendErr = ml.GreedyLoader(o...) - } + _, backendErr = ml.Load(o...) if backendErr != nil { return nil, nil, nil, err } diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 5723e3e41db2..a5bedf79a7a6 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -455,7 +455,7 @@ func (ml *ModelLoader) ListAvailableBackends(assetdir string) ([]string, error) return orderBackends(backends) } -func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) { +func (ml *ModelLoader) backendLoader(opts ...Option) (client grpc.Backend, err error) { o := NewOptions(opts...) log.Info().Msgf("Loading model '%s' with backend %s", o.modelID, o.backendString) @@ -500,7 +500,7 @@ func (ml *ModelLoader) stopActiveBackends(modelID string, singleActiveBackend bo } } -func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) { +func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) { o := NewOptions(opts...) // Return earlier if we have a model already loaded @@ -513,6 +513,10 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) { ml.stopActiveBackends(o.modelID, o.singleActiveBackend) + if o.backendString != "" { + return ml.backendLoader(opts...) + } + var err error // get backends embedded in the binary @@ -536,7 +540,7 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) { WithBackendString(key), }...) - model, modelerr := ml.BackendLoader(options...) + model, modelerr := ml.backendLoader(options...) if modelerr == nil && model != nil { log.Info().Msgf("[%s] Loads OK", key) return model, nil diff --git a/tests/integration/stores_test.go b/tests/integration/stores_test.go index 4244d817fc04..5ed46b19649f 100644 --- a/tests/integration/stores_test.go +++ b/tests/integration/stores_test.go @@ -57,7 +57,7 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs" } sl = model.NewModelLoader("") - sc, err = sl.BackendLoader(storeOpts...) + sc, err = sl.Load(storeOpts...) Expect(err).ToNot(HaveOccurred()) Expect(sc).ToNot(BeNil()) })