Skip to content

Commit

Permalink
feat: refactor aws
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly committed Aug 25, 2024
1 parent bbc84bc commit 73cdc94
Show file tree
Hide file tree
Showing 23 changed files with 444 additions and 352 deletions.
1 change: 1 addition & 0 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func setDefault(config *app.Config) {
setDefaultInt(&config.PluginRemoteInstallingMaxConn, 128)
setDefaultInt(&config.MaxPluginPackageSize, 52428800)
setDefaultInt(&config.MaxAWSLambdaTransactionTimeout, 150)
setDefaultInt(&config.PluginMaxExecutionTimeout, 240)
setDefaultBool(&config.PluginRemoteInstallingEnabled, true)
setDefaultBool(&config.PluginWebhookEnabled, true)
setDefaultString(&config.DBSslMode, "disable")
Expand Down
4 changes: 2 additions & 2 deletions internal/core/plugin_daemon/backwards_invocation/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ func (bi *BackwardsInvocation) TenantID() (string, error) {
if bi.session == nil {
return "", fmt.Errorf("session is nil")
}
return bi.session.TenantID(), nil
return bi.session.TenantID, nil
}

func (bi *BackwardsInvocation) UserID() (string, error) {
if bi.session == nil {
return "", fmt.Errorf("session is nil")
}
return bi.session.UserID(), nil
return bi.session.UserID, nil
}
40 changes: 20 additions & 20 deletions internal/core/plugin_daemon/backwards_invocation/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
)

func InvokeDify(
runtime plugin_entities.PluginRuntimeInterface,
declaration *plugin_entities.PluginDeclaration,
invoke_from access_types.PluginAccessType,
session *session_manager.Session,
writer BackwardsInvocationWriter,
Expand Down Expand Up @@ -43,7 +43,7 @@ func InvokeDify(
}

// check permission
if err := checkPermission(runtime, request_handle); err != nil {
if err := checkPermission(declaration, request_handle); err != nil {
request_handle.WriteError(err)
request_handle.EndResponse()
return nil
Expand All @@ -61,63 +61,63 @@ func InvokeDify(
var (
permissionMapping = map[dify_invocation.InvokeType]map[string]any{
dify_invocation.INVOKE_TYPE_TOOL: {
"func": func(runtime plugin_entities.PluginRuntimeTimeLifeInterface) bool {
return runtime.Configuration().Resource.Permission.AllowInvokeTool()
"func": func(declaration *plugin_entities.PluginDeclaration) bool {
return declaration.Resource.Permission.AllowInvokeTool()
},
"error": "permission denied, you need to enable tool access in plugin manifest",
},
dify_invocation.INVOKE_TYPE_LLM: {
"func": func(runtime plugin_entities.PluginRuntimeTimeLifeInterface) bool {
return runtime.Configuration().Resource.Permission.AllowInvokeLLM()
"func": func(declaration *plugin_entities.PluginDeclaration) bool {
return declaration.Resource.Permission.AllowInvokeLLM()
},
"error": "permission denied, you need to enable llm access in plugin manifest",
},
dify_invocation.INVOKE_TYPE_TEXT_EMBEDDING: {
"func": func(runtime plugin_entities.PluginRuntimeTimeLifeInterface) bool {
return runtime.Configuration().Resource.Permission.AllowInvokeTextEmbedding()
"func": func(declaration *plugin_entities.PluginDeclaration) bool {
return declaration.Resource.Permission.AllowInvokeTextEmbedding()
},
"error": "permission denied, you need to enable text-embedding access in plugin manifest",
},
dify_invocation.INVOKE_TYPE_RERANK: {
"func": func(runtime plugin_entities.PluginRuntimeTimeLifeInterface) bool {
return runtime.Configuration().Resource.Permission.AllowInvokeRerank()
"func": func(declaration *plugin_entities.PluginDeclaration) bool {
return declaration.Resource.Permission.AllowInvokeRerank()
},
"error": "permission denied, you need to enable rerank access in plugin manifest",
},
dify_invocation.INVOKE_TYPE_TTS: {
"func": func(runtime plugin_entities.PluginRuntimeTimeLifeInterface) bool {
return runtime.Configuration().Resource.Permission.AllowInvokeTTS()
"func": func(declaration *plugin_entities.PluginDeclaration) bool {
return declaration.Resource.Permission.AllowInvokeTTS()
},
"error": "permission denied, you need to enable tts access in plugin manifest",
},
dify_invocation.INVOKE_TYPE_SPEECH2TEXT: {
"func": func(runtime plugin_entities.PluginRuntimeTimeLifeInterface) bool {
return runtime.Configuration().Resource.Permission.AllowInvokeSpeech2Text()
"func": func(declaration *plugin_entities.PluginDeclaration) bool {
return declaration.Resource.Permission.AllowInvokeSpeech2Text()
},
"error": "permission denied, you need to enable speech2text access in plugin manifest",
},
dify_invocation.INVOKE_TYPE_MODERATION: {
"func": func(runtime plugin_entities.PluginRuntimeTimeLifeInterface) bool {
return runtime.Configuration().Resource.Permission.AllowInvokeModeration()
"func": func(declaration *plugin_entities.PluginDeclaration) bool {
return declaration.Resource.Permission.AllowInvokeModeration()
},
"error": "permission denied, you need to enable moderation access in plugin manifest",
},
dify_invocation.INVOKE_TYPE_NODE: {
"func": func(runtime plugin_entities.PluginRuntimeTimeLifeInterface) bool {
return runtime.Configuration().Resource.Permission.AllowInvokeNode()
"func": func(declaration *plugin_entities.PluginDeclaration) bool {
return declaration.Resource.Permission.AllowInvokeNode()
},
"error": "permission denied, you need to enable node access in plugin manifest",
},
}
)

func checkPermission(runtime plugin_entities.PluginRuntimeTimeLifeInterface, request_handle *BackwardsInvocation) error {
func checkPermission(runtime *plugin_entities.PluginDeclaration, request_handle *BackwardsInvocation) error {
permission, ok := permissionMapping[request_handle.Type()]
if !ok {
return fmt.Errorf("unsupported invoke type: %s", request_handle.Type())
}

permission_func, ok := permission["func"].(func(runtime plugin_entities.PluginRuntimeTimeLifeInterface) bool)
permission_func, ok := permission["func"].(func(runtime *plugin_entities.PluginDeclaration) bool)
if !ok {
return fmt.Errorf("permission function not found: %s", request_handle.Type())
}
Expand Down
46 changes: 19 additions & 27 deletions internal/core/plugin_daemon/backwards_invocation/task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,27 +38,23 @@ func (r *TPluginRuntime) Wait() (<-chan bool, error) {
}

func TestBackwardsInvocationAllPermittedPermission(t *testing.T) {
all_permitted_runtime := TPluginRuntime{
PluginRuntime: plugin_entities.PluginRuntime{
Config: plugin_entities.PluginDeclaration{
Resource: plugin_entities.PluginResourceRequirement{
Permission: &plugin_entities.PluginPermissionRequirement{
Tool: &plugin_entities.PluginPermissionToolRequirement{
Enabled: true,
},
Model: &plugin_entities.PluginPermissionModelRequirement{
Enabled: true,
LLM: true,
TextEmbedding: true,
Rerank: true,
Moderation: true,
TTS: true,
Speech2text: true,
},
Node: &plugin_entities.PluginPermissionNodeRequirement{
Enabled: true,
},
},
all_permitted_runtime := plugin_entities.PluginDeclaration{
Resource: plugin_entities.PluginResourceRequirement{
Permission: &plugin_entities.PluginPermissionRequirement{
Tool: &plugin_entities.PluginPermissionToolRequirement{
Enabled: true,
},
Model: &plugin_entities.PluginPermissionModelRequirement{
Enabled: true,
LLM: true,
TextEmbedding: true,
Rerank: true,
Moderation: true,
TTS: true,
Speech2text: true,
},
Node: &plugin_entities.PluginPermissionNodeRequirement{
Enabled: true,
},
},
},
Expand Down Expand Up @@ -106,12 +102,8 @@ func TestBackwardsInvocationAllPermittedPermission(t *testing.T) {
}

func TestBackwardsInvocationAllDeniedPermission(t *testing.T) {
all_denied_runtime := TPluginRuntime{
PluginRuntime: plugin_entities.PluginRuntime{
Config: plugin_entities.PluginDeclaration{
Resource: plugin_entities.PluginResourceRequirement{},
},
},
all_denied_runtime := plugin_entities.PluginDeclaration{
Resource: plugin_entities.PluginResourceRequirement{},
}

invoke_llm_request := NewBackwardsInvocation(dify_invocation.INVOKE_TYPE_LLM, "", nil, nil, nil)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import (
"time"

"github.com/gin-gonic/gin"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/aws_manager"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation"
"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
Expand Down Expand Up @@ -39,7 +40,6 @@ func (w *awsTransactionWriteCloser) Close() error {
func (h *AWSTransactionHandler) Handle(
ctx *gin.Context,
session_id string,
runtime *aws_manager.AWSPluginRuntime,
) {
writer := &awsTransactionWriteCloser{
ResponseWriter: ctx.Writer,
Expand Down Expand Up @@ -67,17 +67,26 @@ func (h *AWSTransactionHandler) Handle(
return
}

data.RuntimeType = plugin_entities.PLUGIN_RUNTIME_TYPE_AWS
data.SessionWriter = writer

// send the data to the plugin runtime
if err := runtime.PushRequest(session_id, data); err != nil {
log.Error("push request failed: %s", err.Error())
session := session_manager.GetSession(session_id)
if err != nil {
log.Error("get session failed: %s", err.Error())
writer.WriteHeader(http.StatusInternalServerError)
writer.Write([]byte(err.Error()))
return
}

aws_response_writer := NewAWSTransactionWriter(session, writer)

if err := backwards_invocation.InvokeDify(
session.Declaration,
session.InvokeFrom,
session,
aws_response_writer,
data.Data,
); err != nil {
log.Error("invoke dify failed: %s", err.Error())
}

select {
case <-writer.done:
return
Expand Down
30 changes: 14 additions & 16 deletions internal/core/plugin_daemon/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package plugin_daemon
import (
"errors"

"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation/transaction"
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
Expand All @@ -18,17 +17,15 @@ func genericInvokePlugin[Req any, Rsp any](
session *session_manager.Session,
request *Req,
response_buffer_size int,
typ access_types.PluginAccessType,
action access_types.PluginAccessAction,
) (*stream.StreamResponse[Rsp], error) {
runtime := plugin_manager.GetGlobalPluginManager().Get(session.PluginIdentity())
runtime := plugin_manager.GetGlobalPluginManager().Get(session.PluginIdentity)
if runtime == nil {
return nil, errors.New("plugin not found")
}

response := stream.NewStreamResponse[Rsp](response_buffer_size)

listener := runtime.Listen(session.ID())
listener := runtime.Listen(session.ID)
listener.Listen(func(chunk plugin_entities.SessionMessage) {
switch chunk.Type {
case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
Expand All @@ -41,13 +38,18 @@ func genericInvokePlugin[Req any, Rsp any](
}
case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE:
// check if the request contains a aws_event_id
var writer backwards_invocation.BackwardsInvocationWriter
if chunk.RuntimeType == plugin_entities.PLUGIN_RUNTIME_TYPE_AWS {
writer = transaction.NewAWSTransactionWriter(session, chunk.SessionWriter)
} else {
writer = transaction.NewFullDuplexEventWriter(session)
if runtime.Type() == plugin_entities.PLUGIN_RUNTIME_TYPE_AWS {
response.WriteError(errors.New("aws event is not supported by full duplex"))
response.Close()
return
}
if err := backwards_invocation.InvokeDify(runtime, typ, session, writer, chunk.Data); err != nil {
if err := backwards_invocation.InvokeDify(
runtime.Configuration(),
session.InvokeFrom,
session,
transaction.NewFullDuplexEventWriter(session),
chunk.Data,
); err != nil {
log.Error("invoke dify failed: %s", err.Error())
return
}
Expand All @@ -74,8 +76,6 @@ func genericInvokePlugin[Req any, Rsp any](
session_manager.PLUGIN_IN_STREAM_EVENT_REQUEST,
getInvokePluginMap(
session,
typ,
action,
request,
),
)
Expand All @@ -85,11 +85,9 @@ func genericInvokePlugin[Req any, Rsp any](

func getInvokePluginMap(
session *session_manager.Session,
typ access_types.PluginAccessType,
action access_types.PluginAccessAction,
request any,
) map[string]any {
req := getBasicPluginAccessMap(session.UserID(), typ, action)
req := getBasicPluginAccessMap(session.UserID, session.InvokeFrom, session.Action)
for k, v := range parser.StructToMap(request) {
req[k] = v
}
Expand Down
17 changes: 0 additions & 17 deletions internal/core/plugin_daemon/model_service.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package plugin_daemon

import (
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
Expand All @@ -18,8 +17,6 @@ func InvokeLLM(
session,
request,
512,
access_types.PLUGIN_ACCESS_TYPE_MODEL,
access_types.PLUGIN_ACCESS_ACTION_INVOKE_LLM,
)
}

Expand All @@ -33,8 +30,6 @@ func InvokeTextEmbedding(
session,
request,
1,
access_types.PLUGIN_ACCESS_TYPE_MODEL,
access_types.PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING,
)
}

Expand All @@ -48,8 +43,6 @@ func InvokeRerank(
session,
request,
1,
access_types.PLUGIN_ACCESS_TYPE_MODEL,
access_types.PLUGIN_ACCESS_ACTION_INVOKE_RERANK,
)
}

Expand All @@ -63,8 +56,6 @@ func InvokeTTS(
session,
request,
1,
access_types.PLUGIN_ACCESS_TYPE_MODEL,
access_types.PLUGIN_ACCESS_ACTION_INVOKE_TTS,
)
}

Expand All @@ -78,8 +69,6 @@ func InvokeSpeech2Text(
session,
request,
1,
access_types.PLUGIN_ACCESS_TYPE_MODEL,
access_types.PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT,
)
}

Expand All @@ -93,8 +82,6 @@ func InvokeModeration(
session,
request,
1,
access_types.PLUGIN_ACCESS_TYPE_MODEL,
access_types.PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
)
}

Expand All @@ -108,8 +95,6 @@ func ValidateProviderCredentials(
session,
request,
1,
access_types.PLUGIN_ACCESS_TYPE_MODEL,
access_types.PLUGIN_ACCESS_ACTION_VALIDATE_PROVIDER_CREDENTIALS,
)
}

Expand All @@ -123,7 +108,5 @@ func ValidateModelCredentials(
session,
request,
1,
access_types.PLUGIN_ACCESS_TYPE_MODEL,
access_types.PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS,
)
}
Loading

0 comments on commit 73cdc94

Please sign in to comment.