From 3de55a81ce851017da6d1d0c222022caf5f8a457 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Mon, 14 Oct 2024 21:07:56 +0800 Subject: [PATCH] refactor: installing plugin --- internal/core/plugin_manager/manager.go | 12 ++ internal/db/init.go | 1 + internal/server/controllers/plugins.go | 2 +- internal/service/install_plugin.go | 243 ++++++++++++++++++++---- internal/types/app/config.go | 17 +- internal/types/app/default.go | 1 + internal/types/models/task.go | 27 +++ internal/utils/routine/pool.go | 40 ++++ 8 files changed, 300 insertions(+), 43 deletions(-) create mode 100644 internal/types/models/task.go diff --git a/internal/core/plugin_manager/manager.go b/internal/core/plugin_manager/manager.go index c22991a..84a4b66 100644 --- a/internal/core/plugin_manager/manager.go +++ b/internal/core/plugin_manager/manager.go @@ -24,6 +24,7 @@ type PluginManager struct { maxPluginPackageSize int64 workingDirectory string + packageCachePath string // mediaManager is used to manage media files like plugin icons, images, etc. mediaManager *media_manager.MediaManager @@ -55,6 +56,7 @@ var ( func NewManager(configuration *app.Config) *PluginManager { manager = &PluginManager{ maxPluginPackageSize: configuration.MaxPluginPackageSize, + packageCachePath: configuration.PluginPackageCachePath, workingDirectory: configuration.PluginWorkingPath, mediaManager: media_manager.NewMediaManager( configuration.PluginMediaCachePath, @@ -74,6 +76,7 @@ func NewManager(configuration *app.Config) *PluginManager { os.MkdirAll(configuration.PluginWorkingPath, 0755) os.MkdirAll(configuration.PluginStoragePath, 0755) os.MkdirAll(configuration.PluginMediaCachePath, 0755) + os.MkdirAll(configuration.PluginPackageCachePath, 0755) os.MkdirAll(filepath.Dir(configuration.ProcessCachingPath), 0755) return manager @@ -146,3 +149,12 @@ func (p *PluginManager) Init(configuration *app.Config) { func (p *PluginManager) BackwardsInvocation() dify_invocation.BackwardsInvocation { return p.backwardsInvocation } + +func (p *PluginManager) SavePackage(plugin_unique_identifier plugin_entities.PluginUniqueIdentifier, pkg []byte) error { + // save to storage + return os.WriteFile(filepath.Join(p.packageCachePath, plugin_unique_identifier.String()), pkg, 0644) +} + +func (p *PluginManager) GetPackage(plugin_unique_identifier plugin_entities.PluginUniqueIdentifier) ([]byte, error) { + return os.ReadFile(filepath.Join(p.packageCachePath, plugin_unique_identifier.String())) +} diff --git a/internal/db/init.go b/internal/db/init.go index 9b7b97f..3c4455a 100644 --- a/internal/db/init.go +++ b/internal/db/init.go @@ -84,6 +84,7 @@ func autoMigrate() error { models.ServerlessRuntime{}, models.ToolInstallation{}, models.AIModelInstallation{}, + models.InstallTask{}, ) } diff --git a/internal/server/controllers/plugins.go b/internal/server/controllers/plugins.go index 1429166..727f557 100644 --- a/internal/server/controllers/plugins.go +++ b/internal/server/controllers/plugins.go @@ -59,7 +59,7 @@ func InstallPluginFromIdentifiers(app *app.Config) gin.HandlerFunc { return func(c *gin.Context) { BindRequest(c, func(request struct { TenantID string `uri:"tenant_id" validate:"required"` - PluginUniqueIdentifiers []plugin_entities.PluginUniqueIdentifier `json:"plugin_unique_identifiers" validate:"required,dive,plugin_unique_identifier"` + PluginUniqueIdentifiers []plugin_entities.PluginUniqueIdentifier `json:"plugin_unique_identifiers" validate:"required,max=64,dive,plugin_unique_identifier"` Source string `json:"source" validate:"required"` Meta map[string]any `json:"meta" validate:"omitempty"` }) { diff --git a/internal/service/install_plugin.go b/internal/service/install_plugin.go index 722d28b..8508462 100644 --- a/internal/service/install_plugin.go +++ b/internal/service/install_plugin.go @@ -5,8 +5,10 @@ import ( "fmt" "io" "mime/multipart" + "time" "github.com/gin-gonic/gin" + "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager" "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/decoder" "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/verifier" "github.com/langgenius/dify-plugin-daemon/internal/db" @@ -15,6 +17,9 @@ import ( "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" "github.com/langgenius/dify-plugin-daemon/internal/types/models" "github.com/langgenius/dify-plugin-daemon/internal/types/models/curd" + "github.com/langgenius/dify-plugin-daemon/internal/utils/log" + "github.com/langgenius/dify-plugin-daemon/internal/utils/routine" + "gorm.io/gorm" ) func UploadPluginFromPkg( @@ -57,14 +62,194 @@ func InstallPluginFromIdentifiers( source string, meta map[string]any, ) *entities.Response { + var response struct { + AllInstalled bool `json:"all_installed"` + TaskID string `json:"task_id"` + } + // TODO: create installation task and dispatch to workers - for _, plugin_unique_identifier := range plugin_unique_identifiers { - if err := InstallPluginFromIdentifier(tenant_id, plugin_unique_identifier, source, meta); err != nil { + plugins_wait_for_installation := []plugin_entities.PluginUniqueIdentifier{} + + task := &models.InstallTask{ + Status: models.InstallTaskStatusRunning, + TotalPlugins: len(plugins_wait_for_installation), + CompletedPlugins: 0, + Plugins: []models.InstallTaskPluginStatus{}, + } + + for i, plugin_unique_identifier := range plugin_unique_identifiers { + // check if plugin is already installed + plugin, err := db.GetOne[models.Plugin]( + db.Equal("plugin_unique_identifier", plugin_unique_identifier.String()), + ) + + task.Plugins = append(task.Plugins, models.InstallTaskPluginStatus{ + PluginUniqueIdentifier: plugin_unique_identifier, + PluginID: plugin_unique_identifier.PluginID(), + Status: models.InstallTaskStatusPending, + Message: "", + }) + + task.TotalPlugins++ + + if err == nil { + // already installed by other tenant + declaration := plugin.Declaration + if _, _, err := curd.InstallPlugin( + tenant_id, + plugin_unique_identifier, + plugin.InstallType, + &declaration, + source, + meta, + ); err != nil { + return entities.NewErrorResponse(-500, err.Error()) + } + + task.CompletedPlugins++ + task.Plugins[i].Status = models.InstallTaskStatusSuccess + task.Plugins[i].Message = "Installed" + continue + } + + if err != db.ErrDatabaseNotFound { return entities.NewErrorResponse(-500, err.Error()) } + + plugins_wait_for_installation = append(plugins_wait_for_installation, plugin_unique_identifier) } - return entities.NewSuccessResponse(true) + if len(plugins_wait_for_installation) == 0 { + response.AllInstalled = true + response.TaskID = "" + return entities.NewSuccessResponse(response) + } + + err := db.Create(task) + if err != nil { + return entities.NewErrorResponse(-500, err.Error()) + } + + response.TaskID = task.ID + + manager := plugin_manager.Manager() + + tasks := []func(){} + for _, plugin_unique_identifier := range plugins_wait_for_installation { + tasks = append(tasks, func() { + updateTaskStatus := func(modifier func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus)) { + if err := db.WithTransaction(func(tx *gorm.DB) error { + task, err := db.GetOne[models.InstallTask]( + db.WithTransactionContext(tx), + db.Equal("id", task.ID), + db.WLock(), // write lock, multiple tasks can't update the same task + ) + if err != nil { + return err + } + + task_pointer := &task + var plugin_status *models.InstallTaskPluginStatus + for _, plugin := range task.Plugins { + if plugin.PluginUniqueIdentifier == plugin_unique_identifier { + plugin_status = &plugin + } + } + modifier(task_pointer, plugin_status) + return db.Update(task_pointer, tx) + }); err != nil { + log.Error("failed to update install task status %s", err.Error()) + } + } + + pkg, err := manager.GetPackage(plugin_unique_identifier) + if err != nil { + updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) { + task.Status = models.InstallTaskStatusFailed + plugin.Status = models.InstallTaskStatusFailed + plugin.Message = err.Error() + }) + return + } + + decoder, err := decoder.NewZipPluginDecoder(pkg) + if err != nil { + updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) { + task.Status = models.InstallTaskStatusFailed + plugin.Status = models.InstallTaskStatusFailed + plugin.Message = err.Error() + }) + return + } + + updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) { + plugin.Status = models.InstallTaskStatusRunning + plugin.Message = "Installing" + }) + + stream, err := manager.Install(tenant_id, decoder, source, meta) + if err != nil { + updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) { + task.Status = models.InstallTaskStatusFailed + plugin.Status = models.InstallTaskStatusFailed + plugin.Message = err.Error() + }) + return + } + + for stream.Next() { + message, err := stream.Read() + if err != nil { + updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) { + task.Status = models.InstallTaskStatusFailed + plugin.Status = models.InstallTaskStatusFailed + plugin.Message = err.Error() + }) + return + } + + if message.Event == plugin_manager.PluginInstallEventError { + updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) { + task.Status = models.InstallTaskStatusFailed + plugin.Status = models.InstallTaskStatusFailed + plugin.Message = message.Data + }) + return + } + } + + updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) { + plugin.Status = models.InstallTaskStatusSuccess + plugin.Message = "Installed" + task.CompletedPlugins++ + + // check if all plugins are installed + if task.CompletedPlugins == task.TotalPlugins { + task.Status = models.InstallTaskStatusSuccess + } + }) + }) + } + + // submit async tasks + routine.WithMaxRoutine(3, tasks, func() { + time.AfterFunc(time.Second*5, func() { + // get task + task, err := db.GetOne[models.InstallTask]( + db.Equal("id", task.ID), + ) + if err != nil { + return + } + + if task.CompletedPlugins == task.TotalPlugins { + // delete task if all plugins are installed successfully + db.Delete(&task) + } + }) + }) + + return entities.NewSuccessResponse(response) } func FetchPluginInstallationTasks( @@ -72,14 +257,31 @@ func FetchPluginInstallationTasks( page int, page_size int, ) *entities.Response { - return nil + tasks, err := db.GetAll[models.InstallTask]( + db.Equal("tenant_id", tenant_id), + db.OrderBy("created_at", true), + db.Page(page, page_size), + ) + if err != nil { + return entities.NewErrorResponse(-500, err.Error()) + } + + return entities.NewSuccessResponse(tasks) } func FetchPluginInstallationTask( tenant_id string, task_id string, ) *entities.Response { - return nil + task, err := db.GetOne[models.InstallTask]( + db.Equal("id", task_id), + db.Equal("tenant_id", tenant_id), + ) + if err != nil { + return entities.NewErrorResponse(-500, err.Error()) + } + + return entities.NewSuccessResponse(task) } func FetchPluginManifest( @@ -89,37 +291,6 @@ func FetchPluginManifest( return nil } -func InstallPluginFromIdentifier( - tenant_id string, - plugin_unique_identifier plugin_entities.PluginUniqueIdentifier, - source string, - meta map[string]any, -) error { - // TODO: refactor - // check if identifier exists - plugin, err := db.GetOne[models.Plugin]( - db.Equal("plugin_unique_identifier", plugin_unique_identifier.String()), - ) - if err == db.ErrDatabaseNotFound { - return errors.New("plugin not found") - } - if err != nil { - return err - } - - if plugin.InstallType == plugin_entities.PLUGIN_RUNTIME_TYPE_REMOTE { - return errors.New("remote plugin not supported") - } - - declaration := plugin.Declaration - // install to this workspace - if _, _, err := curd.InstallPlugin(tenant_id, plugin_unique_identifier, plugin.InstallType, &declaration, source, meta); err != nil { - return err - } - - return nil -} - func FetchPluginFromIdentifier( plugin_unique_identifier plugin_entities.PluginUniqueIdentifier, ) *entities.Response { diff --git a/internal/types/app/config.go b/internal/types/app/config.go index fc15a27..072acf0 100644 --- a/internal/types/app/config.go +++ b/internal/types/app/config.go @@ -24,11 +24,12 @@ type Config struct { PluginEndpointEnabled bool `envconfig:"PLUGIN_ENDPOINT_ENABLED"` - PluginStoragePath string `envconfig:"STORAGE_PLUGIN_PATH" validate:"required"` - PluginWorkingPath string `envconfig:"PLUGIN_WORKING_PATH"` - PluginMediaCacheSize uint16 `envconfig:"PLUGIN_MEDIA_CACHE_SIZE"` - PluginMediaCachePath string `envconfig:"PLUGIN_MEDIA_CACHE_PATH"` - ProcessCachingPath string `envconfig:"PROCESS_CACHING_PATH"` + PluginStoragePath string `envconfig:"STORAGE_PLUGIN_PATH" validate:"required"` + PluginPackageCachePath string `envconfig:"PLUGIN_PACKAGE_CACHE_PATH"` + PluginWorkingPath string `envconfig:"PLUGIN_WORKING_PATH"` + PluginMediaCacheSize uint16 `envconfig:"PLUGIN_MEDIA_CACHE_SIZE"` + PluginMediaCachePath string `envconfig:"PLUGIN_MEDIA_CACHE_PATH"` + ProcessCachingPath string `envconfig:"PROCESS_CACHING_PATH"` PluginMaxExecutionTimeout int `envconfig:"PLUGIN_MAX_EXECUTION_TIMEOUT" validate:"required"` @@ -128,10 +129,14 @@ func (c *Config) Validate() error { c.PersistenceStorageS3AccessKey == "" || c.PersistenceStorageS3SecretKey == "" || c.PersistenceStorageS3Bucket == "" { - return fmt.Errorf("s3 region, access key, secret key, bucket is empty") + return fmt.Errorf("s3 region, access key, secret key or bucket is empty") } } + if c.PluginPackageCachePath == "" { + return fmt.Errorf("plugin package cache path is empty") + } + return nil } diff --git a/internal/types/app/default.go b/internal/types/app/default.go index 08ecdf2..d78425e 100644 --- a/internal/types/app/default.go +++ b/internal/types/app/default.go @@ -21,6 +21,7 @@ func (config *Config) SetDefault() { setDefaultString(&config.PluginMediaCachePath, "./storage/assets") setDefaultString(&config.PersistenceStorageLocalPath, "./storage/persistence") setDefaultString(&config.ProcessCachingPath, "./storage/subprocesses") + setDefaultString(&config.PluginPackageCachePath, "./storage/plugin_packages") } func setDefaultInt[T constraints.Integer](value *T, defaultValue T) { diff --git a/internal/types/models/task.go b/internal/types/models/task.go new file mode 100644 index 0000000..cb681d1 --- /dev/null +++ b/internal/types/models/task.go @@ -0,0 +1,27 @@ +package models + +import "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities" + +type InstallTaskStatus string + +const ( + InstallTaskStatusPending InstallTaskStatus = "pending" + InstallTaskStatusRunning InstallTaskStatus = "running" + InstallTaskStatusSuccess InstallTaskStatus = "success" + InstallTaskStatusFailed InstallTaskStatus = "failed" +) + +type InstallTaskPluginStatus struct { + PluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier `json:"plugin_unique_identifier"` + PluginID string `json:"plugin_id"` + Status InstallTaskStatus `json:"status"` + Message string `json:"message"` +} + +type InstallTask struct { + Model + Status InstallTaskStatus `json:"status" gorm:"not null"` + TotalPlugins int `json:"total_plugins" gorm:"not null"` + CompletedPlugins int `json:"completed_plugins" gorm:"not null"` + Plugins []InstallTaskPluginStatus `json:"plugins" gorm:"serializer:json"` +} diff --git a/internal/utils/routine/pool.go b/internal/utils/routine/pool.go index a34b9a3..382d467 100644 --- a/internal/utils/routine/pool.go +++ b/internal/utils/routine/pool.go @@ -2,6 +2,7 @@ package routine import ( "sync" + "sync/atomic" "github.com/langgenius/dify-plugin-daemon/internal/utils/log" "github.com/panjf2000/ants" @@ -31,3 +32,42 @@ func InitPool(size int) { func Submit(f func()) { p.Submit(f) } + +func WithMaxRoutine(max_routine int, tasks []func(), on_finish ...func()) { + if max_routine <= 0 { + max_routine = 1 + } + + if max_routine > len(tasks) { + max_routine = len(tasks) + } + + Submit(func() { + wg := sync.WaitGroup{} + task_index := int32(0) + + for i := 0; i < max_routine; i++ { + wg.Add(1) + Submit(func() { + defer wg.Done() + current_index := atomic.AddInt32(&task_index, 1) + + if current_index >= int32(len(tasks)) { + return + } + + for current_index < int32(len(tasks)) { + task := tasks[current_index] + task() + current_index++ + } + }) + } + + wg.Wait() + + if len(on_finish) > 0 { + on_finish[0]() + } + }) +}