Skip to content

Commit

Permalink
RetrievalModel 参数整体传入 & 添加Reranking相关参数
Browse files Browse the repository at this point in the history
  • Loading branch information
wangle201210 committed Jan 23, 2025
1 parent 9f8057a commit 7b7d744
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 119 deletions.
24 changes: 16 additions & 8 deletions components/retriever/dify/dify.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,26 @@ const (
keywordsKey = "keywords"
)

type retrievalModel struct {
SearchMethod SearchMethod `json:"search_method"`
Weights float64 `json:"weights"`
TopK *int `json:"top_k"`
ScoreThresholdEnabled bool `json:"score_threshold_enabled"`
ScoreThreshold *float64 `json:"score_threshold"`
type RetrievalModel struct {
SearchMethod SearchMethod `json:"search_method"`
RerankingEnable *bool `json:"reranking_enable"`
RerankingMode *string `json:"reranking_mode"`
RerankingModel *RerankingModel `json:"reranking_model"`
Weights *float64 `json:"weights"`
TopK *int `json:"top_k,omitempty"`
ScoreThresholdEnabled *bool `json:"score_threshold_enabled"`
ScoreThreshold *float64 `json:"score_threshold"`
}

type RerankingModel struct {
RerankingProviderName string `json:"reranking_provider_name"`
RerankingModelName string `json:"reranking_model_name"`
}

// request Body
type request struct {
Query string `json:"query"`
RetrievalModel *retrievalModel `json:"retrieval_model,omitempty"`
RetrievalModel *RetrievalModel `json:"retrieval_model,omitempty"`
}

type errorResponse struct {
Expand Down Expand Up @@ -105,7 +113,7 @@ func (r *Retriever) getAuth() string {
func (r *Retriever) doPost(ctx context.Context, query string) (res *successResponse, err error) {
data := &request{
Query: query,
RetrievalModel: r.retrievalModel,
RetrievalModel: r.config.RetrievalModel,
}

reqData, err := sonic.MarshalString(data)
Expand Down
17 changes: 10 additions & 7 deletions components/retriever/dify/examples/dify.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ import (

var (
Endpoint = "https://api.dify.ai/v1"
APIKey = "dataset-api-key"
DatasetID = "dataset-id"
APIKey = "dataset-dQ9RIuhs0YA5BgZnDhaIUuD2"
DatasetID = "d5e591c2-f8dc-4896-8fc6-288413047ff1"
)

func main() {
Expand All @@ -48,7 +48,6 @@ func basicExample() {
APIKey: APIKey,
Endpoint: Endpoint,
DatasetID: DatasetID,
TopK: ptrOf(5), // 返回前5个最相关的文档
})
if err != nil {
log.Fatalf("Failed to create retriever: %v", err)
Expand All @@ -74,10 +73,14 @@ func scoreThresholdExample() {
// 创建带有分数阈值的 Dify Retriever
threshold := 0.7 // 设置相关性分数阈值
ret, err := dify.NewRetriever(ctx, &dify.RetrieverConfig{
APIKey: APIKey,
Endpoint: Endpoint,
DatasetID: DatasetID,
ScoreThreshold: &threshold,
APIKey: APIKey,
Endpoint: Endpoint,
DatasetID: DatasetID,
RetrievalModel: &dify.RetrievalModel{
SearchMethod: dify.SearchMethodHybrid,
TopK: ptrOf(10),
ScoreThreshold: ptrOf(threshold),
},
})
if err != nil {
log.Fatalf("Failed to create retriever: %v", err)
Expand Down
58 changes: 20 additions & 38 deletions components/retriever/dify/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,20 @@ import (
// RetrieverConfig 定义了 Dify Retriever 的配置参数
type RetrieverConfig struct {
// APIKey 是 Dify API 的认证密钥
APIKey string `json:"api_key"`
APIKey string
// Endpoint 是 Dify API 的服务地址
Endpoint string `json:"endpoint"`
Endpoint string
// DatasetID 是知识库的唯一标识
DatasetID string `json:"dataset_id"`
// ScoreThreshold 是文档相关性评分的阈值
ScoreThreshold *float64 `json:"score_threshold,omitempty"`
// *retrievalModel `json:"retrieval_model,omitempty"`
// TopK 定义了返回结果的最大数量
TopK *int `json:"top_k,omitempty"`
DatasetID string
// RetrievalModel 检索参数 选填,如不填,按照默认方式召回
RetrievalModel *RetrievalModel
// Timeout 定义了 HTTP 连接超时时间
Timeout time.Duration `json:"timeout,omitempty"`
// SearchMethod 检索方法,如果不填则按照默认方式召回
SearchMethod SearchMethod `json:"search_method"`
// Weights 混合检索模式下语意检索的权重设置
Weights float64 `json:"weights"`
// ScoreThresholdEnabled 是否开启 score 阈值
ScoreThresholdEnabled bool `json:"score_threshold_enabled"`
Timeout time.Duration `json:"timeout"`
}

type Retriever struct {
config *RetrieverConfig
client *http.Client
retrievalModel *retrievalModel
config *RetrieverConfig
client *http.Client
}

func NewRetriever(ctx context.Context, config *RetrieverConfig) (*Retriever, error) {
Expand All @@ -67,6 +57,10 @@ func NewRetriever(ctx context.Context, config *RetrieverConfig) (*Retriever, err
return nil, fmt.Errorf("dataset_id is required")
}

if config.RetrievalModel != nil && config.RetrievalModel.SearchMethod == "" {
return nil, fmt.Errorf("if retrieval_model is set, search_method is required")
}

if config.Endpoint == "" {
config.Endpoint = defaultEndpoint
}
Expand All @@ -75,9 +69,8 @@ func NewRetriever(ctx context.Context, config *RetrieverConfig) (*Retriever, err
httpClient.Timeout = config.Timeout
}
return &Retriever{
config: config,
client: httpClient,
retrievalModel: config.toRetrievalModel(),
config: config,
client: httpClient,
}, nil
}

Expand All @@ -91,10 +84,12 @@ func (r *Retriever) Retrieve(ctx context.Context, query string, opts ...retrieve
}()

// 合并检索选项
options := retriever.GetCommonOptions(&retriever.Options{
TopK: r.config.TopK,
ScoreThreshold: r.config.ScoreThreshold,
}, opts...)
baseOptions := &retriever.Options{}
if r.config.RetrievalModel != nil {
baseOptions.TopK = r.config.RetrievalModel.TopK
baseOptions.ScoreThreshold = r.config.RetrievalModel.ScoreThreshold
}
options := retriever.GetCommonOptions(baseOptions, opts...)

// 开始检索回调
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
Expand Down Expand Up @@ -136,16 +131,3 @@ func (r *Retriever) GetType() string {
func (r *Retriever) IsCallbacksEnabled() bool {
return true
}

func (x *RetrieverConfig) toRetrievalModel() *retrievalModel {
if x == nil {
return nil
}
return &retrievalModel{
SearchMethod: x.SearchMethod,
Weights: x.Weights,
TopK: x.TopK,
ScoreThresholdEnabled: x.ScoreThresholdEnabled,
ScoreThreshold: x.ScoreThreshold,
}
}
104 changes: 38 additions & 66 deletions components/retriever/dify/retriever_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ func TestRetrieve(t *testing.T) {
APIKey: "test",
Endpoint: "https://api.dify.ai/v1",
DatasetID: "test",
TopK: ptrOf(10),
},
client: &http.Client{},
}
Expand Down Expand Up @@ -171,53 +170,54 @@ func TestRetrieve(t *testing.T) {
})
}

func TestNewRetrieverWithSearchMethod(t *testing.T) {
PatchConvey("test NewRetriever with search method", t, func() {
func TestNewRetrieverWithRetrievalModel(t *testing.T) {
PatchConvey("test NewRetriever with retrieval model", t, func() {
ctx := context.Background()

PatchConvey("test full text search", func() {
ret, err := NewRetriever(ctx, &RetrieverConfig{
APIKey: "test",
Endpoint: "https://api.dify.ai/v1",
DatasetID: "test",
SearchMethod: SearchMethodFullText,
})
convey.So(err, convey.ShouldBeNil)
convey.So(ret, convey.ShouldNotBeNil)
convey.So(ret.retrievalModel, convey.ShouldNotBeNil)
convey.So(ret.retrievalModel.SearchMethod, convey.ShouldEqual, SearchMethodFullText)
})

PatchConvey("test hybrid search with weights", func() {
ret, err := NewRetriever(ctx, &RetrieverConfig{
APIKey: "test",
Endpoint: "https://api.dify.ai/v1",
DatasetID: "test",
SearchMethod: SearchMethodHybrid,
Weights: 0.7,
PatchConvey("test retrieval model validation", func() {
PatchConvey("test empty search method", func() {
ret, err := NewRetriever(ctx, &RetrieverConfig{
APIKey: "test",
Endpoint: "https://api.dify.ai/v1",
DatasetID: "test",
RetrievalModel: &RetrievalModel{},
})
convey.So(err, convey.ShouldNotBeNil)
convey.So(err.Error(), convey.ShouldContainSubstring, "search_method is required")
convey.So(ret, convey.ShouldBeNil)
})
convey.So(err, convey.ShouldBeNil)
convey.So(ret, convey.ShouldNotBeNil)
convey.So(ret.retrievalModel, convey.ShouldNotBeNil)
convey.So(ret.retrievalModel.SearchMethod, convey.ShouldEqual, SearchMethodHybrid)
convey.So(ret.retrievalModel.Weights, convey.ShouldEqual, 0.7)
})

PatchConvey("test with score threshold", func() {
PatchConvey("test with valid retrieval model", func() {
threshold := 0.8
ret, err := NewRetriever(ctx, &RetrieverConfig{
APIKey: "test",
Endpoint: "https://api.dify.ai/v1",
DatasetID: "test",
SearchMethod: SearchMethodFullText,
ScoreThreshold: &threshold,
ScoreThresholdEnabled: true,
APIKey: "test",
Endpoint: "https://api.dify.ai/v1",
DatasetID: "test",
RetrievalModel: &RetrievalModel{
SearchMethod: SearchMethodSemantic,
RerankingEnable: ptrOf(true),
RerankingMode: ptrOf("hybrid"),
Weights: ptrOf(0.7),
TopK: ptrOf(10),
ScoreThreshold: &threshold,
ScoreThresholdEnabled: ptrOf(true),
RerankingModel: &RerankingModel{
RerankingProviderName: "openai",
RerankingModelName: "gpt-3.5-turbo",
},
},
})
convey.So(err, convey.ShouldBeNil)
convey.So(ret, convey.ShouldNotBeNil)
convey.So(ret.retrievalModel, convey.ShouldNotBeNil)
convey.So(ret.retrievalModel.ScoreThresholdEnabled, convey.ShouldBeTrue)
convey.So(*ret.retrievalModel.ScoreThreshold, convey.ShouldEqual, threshold)
convey.So(ret.config.RetrievalModel, convey.ShouldNotBeNil)
convey.So(ret.config.RetrievalModel.SearchMethod, convey.ShouldEqual, SearchMethodSemantic)
convey.So(*ret.config.RetrievalModel.RerankingMode, convey.ShouldEqual, "hybrid")
convey.So(*ret.config.RetrievalModel.Weights, convey.ShouldEqual, 0.7)
convey.So(*ret.config.RetrievalModel.TopK, convey.ShouldEqual, 10)
convey.So(*ret.config.RetrievalModel.ScoreThreshold, convey.ShouldEqual, threshold)
convey.So(ret.config.RetrievalModel.RerankingModel.RerankingProviderName, convey.ShouldEqual, "openai")
convey.So(ret.config.RetrievalModel.RerankingModel.RerankingModelName, convey.ShouldEqual, "gpt-3.5-turbo")
})
})
}
Expand All @@ -235,31 +235,3 @@ func TestIsCallbacksEnabled(t *testing.T) {
convey.So(r.IsCallbacksEnabled(), convey.ShouldBeTrue)
})
}

func TestToRetrievalModel(t *testing.T) {
PatchConvey("test toRetrievalModel", t, func() {
PatchConvey("test nil config", func() {
var config *RetrieverConfig
model := config.toRetrievalModel()
convey.So(model, convey.ShouldBeNil)
})

PatchConvey("test with search method", func() {
threshold := 0.8
config := &RetrieverConfig{
SearchMethod: SearchMethodFullText,
Weights: 0.7,
TopK: ptrOf(10),
ScoreThreshold: &threshold,
ScoreThresholdEnabled: true,
}
model := config.toRetrievalModel()
convey.So(model, convey.ShouldNotBeNil)
convey.So(model.SearchMethod, convey.ShouldEqual, SearchMethodFullText)
convey.So(model.Weights, convey.ShouldEqual, 0.7)
convey.So(*model.TopK, convey.ShouldEqual, 10)
convey.So(*model.ScoreThreshold, convey.ShouldEqual, threshold)
convey.So(model.ScoreThresholdEnabled, convey.ShouldBeTrue)
})
})
}

0 comments on commit 7b7d744

Please sign in to comment.