From 7bd438877b8ce6db907536e2df4a04d2030bd06e Mon Sep 17 00:00:00 2001
From: Squidward <56287847+Chi-Kai@users.noreply.github.com>
Date: Fri, 8 Nov 2024 13:48:32 +0800
Subject: [PATCH] add textin embedding for ai-cache (#1493)

---
 .../extensions/ai-cache/embedding/provider.go |  17 +-
 .../extensions/ai-cache/embedding/textin.go   | 175 ++++++++++++++++++
 2 files changed, 189 insertions(+), 3 deletions(-)
 create mode 100644 plugins/wasm-go/extensions/ai-cache/embedding/textin.go

diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go
index 909edf129c..28dc2cb794 100644
--- a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go
+++ b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go
@@ -9,6 +9,7 @@ import (
 
 const (
 	PROVIDER_TYPE_DASHSCOPE = "dashscope"
+	PROVIDER_TYPE_TEXTIN    = "textin"
 )
 
 type providerInitializer interface {
@@ -19,6 +20,7 @@ type providerInitializer interface {
 var (
 	providerInitializers = map[string]providerInitializer{
 		PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{},
+		PROVIDER_TYPE_TEXTIN:    &textInProviderInitializer{},
 	}
 )
 
@@ -38,6 +40,15 @@ type ProviderConfig struct {
 	// @Title zh-CN 文本特征提取服务 API Key
 	// @Description zh-CN 文本特征提取服务 API Key
 	apiKey string
+	// @Title zh-CN TextIn x-ti-app-id
+	// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
+	textinAppId string
+	// @Title zh-CN TextIn x-ti-secret-code
+	// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
+	textinSecretCode string
+	// @Title zh-CN TextIn request matryoshka_dim
+	// @Description zh-CN 仅适用于 TextIn 服务, 指定返回的向量维度。参考 https://www.textin.com/document/acge_text_embedding
+	textinMatryoshkaDim int
 	// @Title zh-CN 文本特征提取服务超时时间
 	// @Description zh-CN 文本特征提取服务超时时间
 	timeout uint32
@@ -52,6 +63,9 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
 	c.serviceHost = json.Get("serviceHost").String()
 	c.servicePort = json.Get("servicePort").Int()
 	c.apiKey = json.Get("apiKey").String()
+	c.textinAppId = json.Get("textinAppId").String()
+	c.textinSecretCode = json.Get("textinSecretCode").String()
+	c.textinMatryoshkaDim = int(json.Get("textinMatryoshkaDim").Int())
 	c.timeout = uint32(json.Get("timeout").Int())
 	c.model = json.Get("model").String()
 	if c.timeout == 0 {
@@ -63,9 +77,6 @@ func (c *ProviderConfig) Validate() error {
 	if c.serviceName == "" {
 		return errors.New("embedding service name is required")
 	}
-	if c.apiKey == "" {
-		return errors.New("embedding service API key is required")
-	}
 	if c.typ == "" {
 		return errors.New("embedding service type is required")
 	}
diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/textin.go b/plugins/wasm-go/extensions/ai-cache/embedding/textin.go
new file mode 100644
index 0000000000..9bc474041c
--- /dev/null
+++ b/plugins/wasm-go/extensions/ai-cache/embedding/textin.go
@@ -0,0 +1,175 @@
+package embedding
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"net/http"
+	"strconv"
+
+	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
+)
+
+// Defaults for the TextIn embedding service.
+const (
+	TEXTIN_DOMAIN             = "api.textin.com"
+	TEXTIN_PORT               = 443
+	TEXTIN_DEFAULT_MODEL_NAME = "acge-text-embedding"
+	TEXTIN_ENDPOINT           = "/ai/service/v1/acge_embedding"
+)
+
+// textInProviderInitializer validates TextIn settings and builds TIProvider values.
+type textInProviderInitializer struct {
+}
+
+// ValidateConfig reports an error when the TextIn app ID, secret code or
+// Matryoshka dimension is missing, or when the dimension is negative.
+func (t *textInProviderInitializer) ValidateConfig(config ProviderConfig) error {
+	if config.textinAppId == "" {
+		return errors.New("embedding service TextIn App ID is required")
+	}
+	if config.textinSecretCode == "" {
+		return errors.New("embedding service TextIn Secret Code is required")
+	}
+	if config.textinMatryoshkaDim == 0 {
+		return errors.New("embedding service TextIn Matryoshka Dim is required")
+	}
+	if config.textinMatryoshkaDim < 0 {
+		return errors.New("embedding service TextIn Matryoshka Dim must be positive")
+	}
+	return nil
+}
+
+// CreateProvider returns a TIProvider for c, falling back to TEXTIN_PORT and
+// TEXTIN_DOMAIN when the service port or host is not configured.
+func (t *textInProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
+	if c.servicePort == 0 {
+		c.servicePort = TEXTIN_PORT
+	}
+	if c.serviceHost == "" {
+		c.serviceHost = TEXTIN_DOMAIN
+	}
+	return &TIProvider{
+		config: c,
+		client: wrapper.NewClusterClient(wrapper.FQDNCluster{
+			FQDN: c.serviceName,
+			Host: c.serviceHost,
+			Port: int64(c.servicePort),
+		}),
+	}, nil
+}
+
+// GetProviderType returns PROVIDER_TYPE_TEXTIN.
+func (t *TIProvider) GetProviderType() string {
+	return PROVIDER_TYPE_TEXTIN
+}
+
+// TextInResponse is the JSON envelope decoded from a TextIn embedding reply.
+type TextInResponse struct {
+	Code     int          `json:"code"`
+	Message  string       `json:"message"`
+	Duration float64      `json:"duration"`
+	Result   TextInResult `json:"result"`
+}
+
+// TextInResult holds the "result" object of a TextInResponse.
+type TextInResult struct {
+	Embeddings    [][]float64 `json:"embedding"`
+	MatryoshkaDim int         `json:"matryoshka_dim"`
+}
+
+// TextInEmbeddingRequest is the JSON body posted to TEXTIN_ENDPOINT.
+type TextInEmbeddingRequest struct {
+	Input         []string `json:"input"`
+	MatryoshkaDim int      `json:"matryoshka_dim"`
+}
+
+// TIProvider requests text embeddings from TextIn through an HTTP cluster client.
+type TIProvider struct {
+	config ProviderConfig
+	client wrapper.HttpClient
+}
+
+// constructParameters builds the request path, headers and JSON body for
+// embedding texts. It fails when the body cannot be marshalled or when either
+// TextIn credential is empty.
+func (t *TIProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) {
+	data := TextInEmbeddingRequest{
+		Input:         texts,
+		MatryoshkaDim: t.config.textinMatryoshkaDim,
+	}
+
+	requestBody, err := json.Marshal(data)
+	if err != nil {
+		log.Errorf("failed to marshal request data: %v", err)
+		return "", nil, nil, err
+	}
+
+	if t.config.textinAppId == "" {
+		err := errors.New("textinAppId is empty")
+		log.Errorf("failed to construct headers: %v", err)
+		return "", nil, nil, err
+	}
+	if t.config.textinSecretCode == "" {
+		err := errors.New("textinSecretCode is empty")
+		log.Errorf("failed to construct headers: %v", err)
+		return "", nil, nil, err
+	}
+
+	headers := [][2]string{
+		{"x-ti-app-id", t.config.textinAppId},
+		{"x-ti-secret-code", t.config.textinSecretCode},
+		{"Content-Type", "application/json"},
+	}
+
+	return TEXTIN_ENDPOINT, headers, requestBody, nil
+}
+
+// parseTextEmbedding decodes responseBody into a TextInResponse.
+func (t *TIProvider) parseTextEmbedding(responseBody []byte) (*TextInResponse, error) {
+	var resp TextInResponse
+	if err := json.Unmarshal(responseBody, &resp); err != nil {
+		return nil, err
+	}
+	return &resp, nil
+}
+
+// GetEmbedding posts queryString to TextIn and hands the first returned
+// embedding, or an error, to callback. The returned error covers only request
+// construction and dispatch; response failures are reported through callback.
+func (t *TIProvider) GetEmbedding(
+	queryString string,
+	ctx wrapper.HttpContext,
+	log wrapper.Log,
+	callback func(emb []float64, err error)) error {
+	embUrl, embHeaders, embRequestBody, err := t.constructParameters([]string{queryString}, log)
+	if err != nil {
+		log.Errorf("failed to construct parameters: %v", err)
+		return err
+	}
+
+	// The handler keeps its own locals: writing to this method's err from a
+	// callback that may fire after Post has returned would be a stale write.
+	return t.client.Post(embUrl, embHeaders, embRequestBody,
+		func(statusCode int, responseHeaders http.Header, responseBody []byte) {
+			if statusCode != http.StatusOK {
+				callback(nil, errors.New("failed to get embedding due to status code: "+strconv.Itoa(statusCode)))
+				return
+			}
+
+			log.Debugf("get embedding response: %d, %s", statusCode, responseBody)
+
+			resp, err := t.parseTextEmbedding(responseBody)
+			if err != nil {
+				callback(nil, fmt.Errorf("failed to parse response: %w", err))
+				return
+			}
+
+			if len(resp.Result.Embeddings) == 0 {
+				callback(nil, errors.New("no embedding found in response"))
+				return
+			}
+
+			callback(resp.Result.Embeddings[0], nil)
+		}, t.config.timeout)
+}