diff --git a/core/backend/llm.go b/core/backend/llm.go index a4d1e5f35e42..bf968c4f2a71 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -57,7 +57,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im if _, err := os.Stat(modelFile); os.IsNotExist(err) { utils.ResetDownloadTimers() // if we failed to load the model, we try to download it - err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction) + err := gallery.InstallModelFromGallery(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction) if err != nil { return nil, err } diff --git a/core/cli/models.go b/core/cli/models.go index d513858534e7..0e01eddca7ae 100644 --- a/core/cli/models.go +++ b/core/cli/models.go @@ -7,6 +7,7 @@ import ( cliContext "github.com/go-skynet/LocalAI/core/cli/context" "github.com/go-skynet/LocalAI/pkg/gallery" + "github.com/go-skynet/LocalAI/pkg/startup" "github.com/rs/zerolog/log" "github.com/schollz/progressbar/v3" ) @@ -52,13 +53,12 @@ func (ml *ModelsList) Run(ctx *cliContext.Context) error { } func (mi *ModelsInstall) Run(ctx *cliContext.Context) error { + var galleries []gallery.Gallery + if err := json.Unmarshal([]byte(mi.Galleries), &galleries); err != nil { + log.Error().Err(err).Msg("unable to load galleries") + } for _, modelName := range mi.ModelArgs { - var galleries []gallery.Gallery - if err := json.Unmarshal([]byte(mi.Galleries), &galleries); err != nil { - log.Error().Err(err).Msg("unable to load galleries") - } - progressBar := progressbar.NewOptions( 1000, progressbar.OptionSetDescription(fmt.Sprintf("downloading model %s", modelName)), @@ -72,7 +72,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error { log.Error().Err(err).Str("filename", fileName).Int("value", v).Msg("error while updating progress bar") } } - + //startup.InstallModels() models, err := gallery.AvailableGalleryModels(galleries, mi.ModelsPath) if err != nil { return err @@ -85,7 +85,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error { } log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model") - err = gallery.InstallModelFromGalleryByName(galleries, modelName, mi.ModelsPath, gallery.GalleryModel{}, progressCallback) + err = startup.InstallModels(galleries, "", mi.ModelsPath, progressCallback, modelName) if err != nil { return err } diff --git a/core/services/gallery.go b/core/services/gallery.go index 384f0c124ca7..01e8d8b452a4 100644 --- a/core/services/gallery.go +++ b/core/services/gallery.go @@ -6,7 +6,6 @@ import ( "fmt" "os" "path/filepath" - "strings" "sync" "github.com/go-skynet/LocalAI/core/config" @@ -120,16 +119,20 @@ func (g *GalleryService) Start(c context.Context, cl *config.BackendConfigLoader } err = gallery.DeleteModelFromSystem(g.appConfig.ModelPath, op.GalleryModelName, files) + if err != nil { + updateError(err) + continue + } } else { // if the request contains a gallery name, we apply the gallery from the gallery list if op.GalleryModelName != "" { - if strings.Contains(op.GalleryModelName, "@") { - err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryModelName, g.appConfig.ModelPath, op.Req, progressCallback) - } else { - err = gallery.InstallModelFromGalleryByName(op.Galleries, op.GalleryModelName, g.appConfig.ModelPath, op.Req, progressCallback) - } + err = gallery.InstallModelFromGallery(op.Galleries, op.GalleryModelName, g.appConfig.ModelPath, op.Req, progressCallback) } else if op.ConfigURL != "" { - startup.PreloadModelsConfigurations(op.ConfigURL, g.appConfig.ModelPath, op.ConfigURL) + err = startup.InstallModels(op.Galleries, op.ConfigURL, g.appConfig.ModelPath, progressCallback, op.ConfigURL) + if err != nil { + updateError(err) + continue + } err = cl.Preload(g.appConfig.ModelPath) } else { err = prepareModel(g.appConfig.ModelPath, op.Req, progressCallback) @@ -179,13 +182,8 @@ func processRequests(modelPath string, galleries []gallery.Gallery, requests []g err = prepareModel(modelPath, r.GalleryModel, utils.DisplayDownloadFunction) } else { - if strings.Contains(r.ID, "@") { - err = gallery.InstallModelFromGallery( - galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) - } else { - err = gallery.InstallModelFromGalleryByName( - galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) - } + err = gallery.InstallModelFromGallery( + galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) } } return err diff --git a/core/startup/startup.go b/core/startup/startup.go index c276e4acd522..5f4818ee800d 100644 --- a/core/startup/startup.go +++ b/core/startup/startup.go @@ -59,8 +59,9 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode } } - // - pkgStartup.PreloadModelsConfigurations(options.ModelLibraryURL, options.ModelPath, options.ModelsURL...) + if err := pkgStartup.InstallModels(options.Galleries, options.ModelLibraryURL, options.ModelPath, nil, options.ModelsURL...); err != nil { + log.Error().Err(err).Msg("error installing models") + } cl := config.NewBackendConfigLoader(options.ModelPath) ml := model.NewModelLoader(options.ModelPath) diff --git a/pkg/gallery/gallery.go b/pkg/gallery/gallery.go index a18404f5fc56..9a8369fbbf5b 100644 --- a/pkg/gallery/gallery.go +++ b/pkg/gallery/gallery.go @@ -18,7 +18,7 @@ type Gallery struct { Name string `json:"name" yaml:"name"` } -// Installs a model from the gallery (galleryname@modelname) +// Installs a model from the gallery func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { applyModel := func(model *GalleryModel) error { @@ -114,11 +114,6 @@ func FindModel(models []*GalleryModel, name string, basePath string) *GalleryMod return model } -// InstallModelFromGalleryByName is planned for deprecation -func InstallModelFromGalleryByName(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { - return InstallModelFromGallery(galleries, name, basePath, req, downloadStatus) -} - // List available models // Models galleries are a list of yaml files that are hosted on a remote server (for example github). // Each yaml file contains a list of models that can be downloaded and optionally overrides to define a new model setting. diff --git a/pkg/startup/model_preload.go b/pkg/startup/model_preload.go index 240fc6bdf28b..aa732ab0be8c 100644 --- a/pkg/startup/model_preload.go +++ b/pkg/startup/model_preload.go @@ -2,19 +2,24 @@ package startup import ( "errors" + "fmt" "os" "path/filepath" "github.com/go-skynet/LocalAI/embedded" "github.com/go-skynet/LocalAI/pkg/downloader" + "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/utils" "github.com/rs/zerolog/log" ) -// PreloadModelsConfigurations will preload models from the given list of URLs +// InstallModels will preload models from the given list of URLs and galleries // It will download the model if it is not already present in the model path // It will also try to resolve if the model is an embedded model YAML configuration -func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, models ...string) { +func InstallModels(galleries []gallery.Gallery, modelLibraryURL string, modelPath string, downloadStatus func(string, string, string, float64), models ...string) error { + // create an error that groups all errors + var err error + for _, url := range models { // As a best effort, try to resolve the model from the remote library @@ -32,18 +37,20 @@ func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, model url = embedded.ModelShortURL(url) switch { case embedded.ExistsInModelsLibrary(url): - modelYAML, err := embedded.ResolveContent(url) + modelYAML, e := embedded.ResolveContent(url) // If we resolve something, just save it to disk and continue - if err != nil { - log.Error().Err(err).Msg("error resolving model content") + if e != nil { + log.Error().Err(e).Msg("error resolving model content") + err = errors.Join(err, e) continue } log.Debug().Msgf("[startup] resolved embedded model: %s", url) md5Name := utils.MD5(url) modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" - if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil { - log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition") + if e := os.WriteFile(modelDefinitionFilePath, modelYAML, 0600); err != nil { + log.Error().Err(e).Str("filepath", modelDefinitionFilePath).Msg("error writing model definition") + err = errors.Join(err, e) } case downloader.LooksLikeURL(url): log.Debug().Msgf("[startup] resolved model to download: %s", url) @@ -52,34 +59,70 @@ func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, model md5Name := utils.MD5(url) // check if file exists - if _, err := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(err, os.ErrNotExist) { + if _, e := os.Stat(filepath.Join(modelPath, md5Name)); errors.Is(e, os.ErrNotExist) { modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" - err := downloader.DownloadFile(url, modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) { + e := downloader.DownloadFile(url, modelDefinitionFilePath, "", 0, 0, func(fileName, current, total string, percent float64) { utils.DisplayDownloadFunction(fileName, current, total, percent) }) - if err != nil { - log.Error().Err(err).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model") + if e != nil { + log.Error().Err(e).Str("url", url).Str("filepath", modelDefinitionFilePath).Msg("error downloading model") + err = errors.Join(err, e) } } default: - if _, err := os.Stat(url); err == nil { + if _, e := os.Stat(url); e == nil { log.Debug().Msgf("[startup] resolved local model: %s", url) // copy to modelPath md5Name := utils.MD5(url) - modelYAML, err := os.ReadFile(url) - if err != nil { - log.Error().Err(err).Str("filepath", url).Msg("error reading model definition") + modelYAML, e := os.ReadFile(url) + if e != nil { + log.Error().Err(e).Str("filepath", url).Msg("error reading model definition") + err = errors.Join(err, e) continue } modelDefinitionFilePath := filepath.Join(modelPath, md5Name) + ".yaml" - if err := os.WriteFile(modelDefinitionFilePath, modelYAML, os.ModePerm); err != nil { + if e := os.WriteFile(modelDefinitionFilePath, modelYAML, 0600); e != nil { log.Error().Err(err).Str("filepath", modelDefinitionFilePath).Msg("error loading model: %s") + err = errors.Join(err, e) } } else { - log.Warn().Msgf("[startup] failed resolving model '%s'", url) + // Check if it's a model gallery, or print a warning + e, found := installModel(galleries, url, modelPath, downloadStatus) + if e != nil && found { + log.Error().Err(err).Msgf("[startup] failed installing model '%s'", url) + err = errors.Join(err, e) + } else if !found { + log.Warn().Msgf("[startup] failed resolving model '%s'", url) + err = errors.Join(err, fmt.Errorf("failed resolving model '%s'", url)) + } } } } + return err +} + +func installModel(galleries []gallery.Gallery, modelName, modelPath string, downloadStatus func(string, string, string, float64)) (error, bool) { + models, err := gallery.AvailableGalleryModels(galleries, modelPath) + if err != nil { + return err, false + } + + model := gallery.FindModel(models, modelName, modelPath) + if model == nil { + return err, false + } + + if downloadStatus == nil { + downloadStatus = utils.DisplayDownloadFunction + } + + log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model") + err = gallery.InstallModelFromGallery(galleries, modelName, modelPath, gallery.GalleryModel{}, downloadStatus) + if err != nil { + return err, true + } + + return nil, true } diff --git a/pkg/startup/model_preload_test.go b/pkg/startup/model_preload_test.go index 63a8f8b03e3b..e5c92bc1eafb 100644 --- a/pkg/startup/model_preload_test.go +++ b/pkg/startup/model_preload_test.go @@ -5,6 +5,7 @@ import ( "os" "path/filepath" + "github.com/go-skynet/LocalAI/pkg/gallery" . "github.com/go-skynet/LocalAI/pkg/startup" "github.com/go-skynet/LocalAI/pkg/utils" @@ -21,7 +22,7 @@ var _ = Describe("Preload test", func() { libraryURL := "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml" fileName := fmt.Sprintf("%s.yaml", "1701d57f28d47552516c2b6ecc3cc719") - PreloadModelsConfigurations(libraryURL, tmpdir, "phi-2") + InstallModels([]gallery.Gallery{}, libraryURL, tmpdir, nil, "phi-2") resultFile := filepath.Join(tmpdir, fileName) @@ -37,7 +38,7 @@ var _ = Describe("Preload test", func() { url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml" fileName := fmt.Sprintf("%s.yaml", utils.MD5(url)) - PreloadModelsConfigurations("", tmpdir, url) + InstallModels([]gallery.Gallery{}, "", tmpdir, nil, url) resultFile := filepath.Join(tmpdir, fileName) @@ -51,7 +52,7 @@ var _ = Describe("Preload test", func() { Expect(err).ToNot(HaveOccurred()) url := "phi-2" - PreloadModelsConfigurations("", tmpdir, url) + InstallModels([]gallery.Gallery{}, "", tmpdir, nil, url) entry, err := os.ReadDir(tmpdir) Expect(err).ToNot(HaveOccurred()) @@ -69,7 +70,7 @@ var _ = Describe("Preload test", func() { url := "mistral-openorca" fileName := fmt.Sprintf("%s.yaml", utils.MD5(url)) - PreloadModelsConfigurations("", tmpdir, url) + InstallModels([]gallery.Gallery{}, "", tmpdir, nil, url) resultFile := filepath.Join(tmpdir, fileName)