Skip to content

Commit

Permalink
refactor: installing plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly committed Oct 14, 2024
1 parent a74a1a4 commit 3de55a8
Show file tree
Hide file tree
Showing 8 changed files with 300 additions and 43 deletions.
12 changes: 12 additions & 0 deletions internal/core/plugin_manager/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type PluginManager struct {

maxPluginPackageSize int64
workingDirectory string
packageCachePath string

// mediaManager is used to manage media files like plugin icons, images, etc.
mediaManager *media_manager.MediaManager
Expand Down Expand Up @@ -55,6 +56,7 @@ var (
func NewManager(configuration *app.Config) *PluginManager {
manager = &PluginManager{
maxPluginPackageSize: configuration.MaxPluginPackageSize,
packageCachePath: configuration.PluginPackageCachePath,
workingDirectory: configuration.PluginWorkingPath,
mediaManager: media_manager.NewMediaManager(
configuration.PluginMediaCachePath,
Expand All @@ -74,6 +76,7 @@ func NewManager(configuration *app.Config) *PluginManager {
os.MkdirAll(configuration.PluginWorkingPath, 0755)
os.MkdirAll(configuration.PluginStoragePath, 0755)
os.MkdirAll(configuration.PluginMediaCachePath, 0755)
os.MkdirAll(configuration.PluginPackageCachePath, 0755)
os.MkdirAll(filepath.Dir(configuration.ProcessCachingPath), 0755)

return manager
Expand Down Expand Up @@ -146,3 +149,12 @@ func (p *PluginManager) Init(configuration *app.Config) {
func (p *PluginManager) BackwardsInvocation() dify_invocation.BackwardsInvocation {
return p.backwardsInvocation
}

func (p *PluginManager) SavePackage(plugin_unique_identifier plugin_entities.PluginUniqueIdentifier, pkg []byte) error {
// save to storage
return os.WriteFile(filepath.Join(p.packageCachePath, plugin_unique_identifier.String()), pkg, 0644)
}

func (p *PluginManager) GetPackage(plugin_unique_identifier plugin_entities.PluginUniqueIdentifier) ([]byte, error) {
return os.ReadFile(filepath.Join(p.packageCachePath, plugin_unique_identifier.String()))
}
1 change: 1 addition & 0 deletions internal/db/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func autoMigrate() error {
models.ServerlessRuntime{},
models.ToolInstallation{},
models.AIModelInstallation{},
models.InstallTask{},
)
}

Expand Down
2 changes: 1 addition & 1 deletion internal/server/controllers/plugins.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func InstallPluginFromIdentifiers(app *app.Config) gin.HandlerFunc {
return func(c *gin.Context) {
BindRequest(c, func(request struct {
TenantID string `uri:"tenant_id" validate:"required"`
PluginUniqueIdentifiers []plugin_entities.PluginUniqueIdentifier `json:"plugin_unique_identifiers" validate:"required,dive,plugin_unique_identifier"`
PluginUniqueIdentifiers []plugin_entities.PluginUniqueIdentifier `json:"plugin_unique_identifiers" validate:"required,max=64,dive,plugin_unique_identifier"`
Source string `json:"source" validate:"required"`
Meta map[string]any `json:"meta" validate:"omitempty"`
}) {
Expand Down
243 changes: 207 additions & 36 deletions internal/service/install_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"fmt"
"io"
"mime/multipart"
"time"

"github.com/gin-gonic/gin"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/decoder"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/verifier"
"github.com/langgenius/dify-plugin-daemon/internal/db"
Expand All @@ -15,6 +17,9 @@ import (
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
"github.com/langgenius/dify-plugin-daemon/internal/types/models"
"github.com/langgenius/dify-plugin-daemon/internal/types/models/curd"
"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
"gorm.io/gorm"
)

func UploadPluginFromPkg(
Expand Down Expand Up @@ -57,29 +62,226 @@ func InstallPluginFromIdentifiers(
source string,
meta map[string]any,
) *entities.Response {
var response struct {
AllInstalled bool `json:"all_installed"`
TaskID string `json:"task_id"`
}

// TODO: create installation task and dispatch to workers
for _, plugin_unique_identifier := range plugin_unique_identifiers {
if err := InstallPluginFromIdentifier(tenant_id, plugin_unique_identifier, source, meta); err != nil {
plugins_wait_for_installation := []plugin_entities.PluginUniqueIdentifier{}

task := &models.InstallTask{
Status: models.InstallTaskStatusRunning,
TotalPlugins: len(plugins_wait_for_installation),
CompletedPlugins: 0,
Plugins: []models.InstallTaskPluginStatus{},
}

for i, plugin_unique_identifier := range plugin_unique_identifiers {
// check if plugin is already installed
plugin, err := db.GetOne[models.Plugin](
db.Equal("plugin_unique_identifier", plugin_unique_identifier.String()),
)

task.Plugins = append(task.Plugins, models.InstallTaskPluginStatus{
PluginUniqueIdentifier: plugin_unique_identifier,
PluginID: plugin_unique_identifier.PluginID(),
Status: models.InstallTaskStatusPending,
Message: "",
})

task.TotalPlugins++

if err == nil {
// already installed by other tenant
declaration := plugin.Declaration
if _, _, err := curd.InstallPlugin(
tenant_id,
plugin_unique_identifier,
plugin.InstallType,
&declaration,
source,
meta,
); err != nil {
return entities.NewErrorResponse(-500, err.Error())
}

task.CompletedPlugins++
task.Plugins[i].Status = models.InstallTaskStatusSuccess
task.Plugins[i].Message = "Installed"
continue
}

if err != db.ErrDatabaseNotFound {
return entities.NewErrorResponse(-500, err.Error())
}

plugins_wait_for_installation = append(plugins_wait_for_installation, plugin_unique_identifier)
}

return entities.NewSuccessResponse(true)
if len(plugins_wait_for_installation) == 0 {
response.AllInstalled = true
response.TaskID = ""
return entities.NewSuccessResponse(response)
}

err := db.Create(task)
if err != nil {
return entities.NewErrorResponse(-500, err.Error())
}

response.TaskID = task.ID

manager := plugin_manager.Manager()

tasks := []func(){}
for _, plugin_unique_identifier := range plugins_wait_for_installation {
tasks = append(tasks, func() {
updateTaskStatus := func(modifier func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus)) {
if err := db.WithTransaction(func(tx *gorm.DB) error {
task, err := db.GetOne[models.InstallTask](
db.WithTransactionContext(tx),
db.Equal("id", task.ID),
db.WLock(), // write lock, multiple tasks can't update the same task
)
if err != nil {
return err
}

task_pointer := &task
var plugin_status *models.InstallTaskPluginStatus
for _, plugin := range task.Plugins {
if plugin.PluginUniqueIdentifier == plugin_unique_identifier {
plugin_status = &plugin
}
}
modifier(task_pointer, plugin_status)
return db.Update(task_pointer, tx)
}); err != nil {
log.Error("failed to update install task status %s", err.Error())
}
}

pkg, err := manager.GetPackage(plugin_unique_identifier)
if err != nil {
updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
task.Status = models.InstallTaskStatusFailed
plugin.Status = models.InstallTaskStatusFailed
plugin.Message = err.Error()
})
return
}

decoder, err := decoder.NewZipPluginDecoder(pkg)
if err != nil {
updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
task.Status = models.InstallTaskStatusFailed
plugin.Status = models.InstallTaskStatusFailed
plugin.Message = err.Error()
})
return
}

updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
plugin.Status = models.InstallTaskStatusRunning
plugin.Message = "Installing"
})

stream, err := manager.Install(tenant_id, decoder, source, meta)
if err != nil {
updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
task.Status = models.InstallTaskStatusFailed
plugin.Status = models.InstallTaskStatusFailed
plugin.Message = err.Error()
})
return
}

for stream.Next() {
message, err := stream.Read()
if err != nil {
updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
task.Status = models.InstallTaskStatusFailed
plugin.Status = models.InstallTaskStatusFailed
plugin.Message = err.Error()
})
return
}

if message.Event == plugin_manager.PluginInstallEventError {
updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
task.Status = models.InstallTaskStatusFailed
plugin.Status = models.InstallTaskStatusFailed
plugin.Message = message.Data
})
return
}
}

updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
plugin.Status = models.InstallTaskStatusSuccess
plugin.Message = "Installed"
task.CompletedPlugins++

// check if all plugins are installed
if task.CompletedPlugins == task.TotalPlugins {
task.Status = models.InstallTaskStatusSuccess
}
})
})
}

// submit async tasks
routine.WithMaxRoutine(3, tasks, func() {
time.AfterFunc(time.Second*5, func() {
// get task
task, err := db.GetOne[models.InstallTask](
db.Equal("id", task.ID),
)
if err != nil {
return
}

if task.CompletedPlugins == task.TotalPlugins {
// delete task if all plugins are installed successfully
db.Delete(&task)
}
})
})

return entities.NewSuccessResponse(response)
}

func FetchPluginInstallationTasks(
tenant_id string,
page int,
page_size int,
) *entities.Response {
return nil
tasks, err := db.GetAll[models.InstallTask](
db.Equal("tenant_id", tenant_id),
db.OrderBy("created_at", true),
db.Page(page, page_size),
)
if err != nil {
return entities.NewErrorResponse(-500, err.Error())
}

return entities.NewSuccessResponse(tasks)
}

func FetchPluginInstallationTask(
tenant_id string,
task_id string,
) *entities.Response {
return nil
task, err := db.GetOne[models.InstallTask](
db.Equal("id", task_id),
db.Equal("tenant_id", tenant_id),
)
if err != nil {
return entities.NewErrorResponse(-500, err.Error())
}

return entities.NewSuccessResponse(task)
}

func FetchPluginManifest(
Expand All @@ -89,37 +291,6 @@ func FetchPluginManifest(
return nil
}

func InstallPluginFromIdentifier(
tenant_id string,
plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
source string,
meta map[string]any,
) error {
// TODO: refactor
// check if identifier exists
plugin, err := db.GetOne[models.Plugin](
db.Equal("plugin_unique_identifier", plugin_unique_identifier.String()),
)
if err == db.ErrDatabaseNotFound {
return errors.New("plugin not found")
}
if err != nil {
return err
}

if plugin.InstallType == plugin_entities.PLUGIN_RUNTIME_TYPE_REMOTE {
return errors.New("remote plugin not supported")
}

declaration := plugin.Declaration
// install to this workspace
if _, _, err := curd.InstallPlugin(tenant_id, plugin_unique_identifier, plugin.InstallType, &declaration, source, meta); err != nil {
return err
}

return nil
}

func FetchPluginFromIdentifier(
plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
) *entities.Response {
Expand Down
17 changes: 11 additions & 6 deletions internal/types/app/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ type Config struct {

PluginEndpointEnabled bool `envconfig:"PLUGIN_ENDPOINT_ENABLED"`

PluginStoragePath string `envconfig:"STORAGE_PLUGIN_PATH" validate:"required"`
PluginWorkingPath string `envconfig:"PLUGIN_WORKING_PATH"`
PluginMediaCacheSize uint16 `envconfig:"PLUGIN_MEDIA_CACHE_SIZE"`
PluginMediaCachePath string `envconfig:"PLUGIN_MEDIA_CACHE_PATH"`
ProcessCachingPath string `envconfig:"PROCESS_CACHING_PATH"`
PluginStoragePath string `envconfig:"STORAGE_PLUGIN_PATH" validate:"required"`
PluginPackageCachePath string `envconfig:"PLUGIN_PACKAGE_CACHE_PATH"`
PluginWorkingPath string `envconfig:"PLUGIN_WORKING_PATH"`
PluginMediaCacheSize uint16 `envconfig:"PLUGIN_MEDIA_CACHE_SIZE"`
PluginMediaCachePath string `envconfig:"PLUGIN_MEDIA_CACHE_PATH"`
ProcessCachingPath string `envconfig:"PROCESS_CACHING_PATH"`

PluginMaxExecutionTimeout int `envconfig:"PLUGIN_MAX_EXECUTION_TIMEOUT" validate:"required"`

Expand Down Expand Up @@ -128,10 +129,14 @@ func (c *Config) Validate() error {
c.PersistenceStorageS3AccessKey == "" ||
c.PersistenceStorageS3SecretKey == "" ||
c.PersistenceStorageS3Bucket == "" {
return fmt.Errorf("s3 region, access key, secret key, bucket is empty")
return fmt.Errorf("s3 region, access key, secret key or bucket is empty")
}
}

if c.PluginPackageCachePath == "" {
return fmt.Errorf("plugin package cache path is empty")
}

return nil
}

Expand Down
1 change: 1 addition & 0 deletions internal/types/app/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ func (config *Config) SetDefault() {
setDefaultString(&config.PluginMediaCachePath, "./storage/assets")
setDefaultString(&config.PersistenceStorageLocalPath, "./storage/persistence")
setDefaultString(&config.ProcessCachingPath, "./storage/subprocesses")
setDefaultString(&config.PluginPackageCachePath, "./storage/plugin_packages")
}

func setDefaultInt[T constraints.Integer](value *T, defaultValue T) {
Expand Down
Loading

0 comments on commit 3de55a8

Please sign in to comment.