diff --git a/components/retriever/es8/retriever.go b/components/retriever/es8/retriever.go index a44d0dc..82063ff 100644 --- a/components/retriever/es8/retriever.go +++ b/components/retriever/es8/retriever.go @@ -52,6 +52,10 @@ type RetrieverConfig struct { Embedding embedding.Embedder } +type SearchMode interface { // nolint: byted_s_interface_name + BuildRequest(ctx context.Context, conf *RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) +} + type Retriever struct { client *elasticsearch.TypedClient config *RetrieverConfig diff --git a/components/retriever/es8/search_mode/approximate.go b/components/retriever/es8/search_mode/approximate.go index e497434..163c332 100644 --- a/components/retriever/es8/search_mode/approximate.go +++ b/components/retriever/es8/search_mode/approximate.go @@ -26,13 +26,14 @@ import ( "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8" "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) // SearchModeApproximate retrieve with multiple approximate strategy (filter+knn+rrf) // knn: https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html // rrf: https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html -func SearchModeApproximate(config *ApproximateConfig) SearchMode { +func SearchModeApproximate(config *ApproximateConfig) es8.SearchMode { return &approximate{config} } @@ -85,7 +86,15 @@ type approximate struct { config *ApproximateConfig } -func (a *approximate) BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) { +func (a *approximate) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) { + + options := retriever.GetCommonOptions(&retriever.Options{ + Index: &conf.Index, + TopK: &conf.TopK, + ScoreThreshold: conf.ScoreThreshold, + Embedding: conf.Embedding, + }, opts...) + var appReq ApproximateQuery if err := json.Unmarshal([]byte(query), &appReq); err != nil { return nil, fmt.Errorf("[BuildRequest][SearchModeApproximate] parse query failed, %w", err) diff --git a/components/retriever/es8/search_mode/approximate_test.go b/components/retriever/es8/search_mode/approximate_test.go index 05c6ea5..634a4ba 100644 --- a/components/retriever/es8/search_mode/approximate_test.go +++ b/components/retriever/es8/search_mode/approximate_test.go @@ -28,6 +28,7 @@ import ( "github.com/cloudwego/eino/components/embedding" "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8" "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) @@ -78,7 +79,8 @@ func TestSearchModeApproximate(t *testing.T) { sq, err := aq.ToRetrieverQuery() convey.So(err, convey.ShouldBeNil) - req, err := a.BuildRequest(ctx, sq, &retriever.Options{Embedding: nil}) + conf := &es8.RetrieverConfig{} + req, err := a.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(nil)) convey.So(err, convey.ShouldBeNil) b, err := json.Marshal(req) convey.So(err, convey.ShouldBeNil) @@ -104,7 +106,8 @@ func TestSearchModeApproximate(t *testing.T) { sq, err := aq.ToRetrieverQuery() convey.So(err, convey.ShouldBeNil) - req, err := a.BuildRequest(ctx, sq, &retriever.Options{Embedding: &mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}}) + conf := &es8.RetrieverConfig{} + req, err := a.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}})) convey.So(err, convey.ShouldBeNil) b, err := json.Marshal(req) convey.So(err, convey.ShouldBeNil) @@ -136,11 +139,11 @@ func TestSearchModeApproximate(t *testing.T) { sq, err := aq.ToRetrieverQuery() convey.So(err, convey.ShouldBeNil) - req, err := a.BuildRequest(ctx, sq, &retriever.Options{ - Embedding: &mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}, - TopK: of(10), - ScoreThreshold: of(1.1), - }) + + conf := &es8.RetrieverConfig{} + req, err := a.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}), + retriever.WithTopK(10), + retriever.WithScoreThreshold(1.1)) convey.So(err, convey.ShouldBeNil) b, err := json.Marshal(req) convey.So(err, convey.ShouldBeNil) diff --git a/components/retriever/es8/search_mode/dense_vector_similarity.go b/components/retriever/es8/search_mode/dense_vector_similarity.go index 460e764..07b54d5 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity.go @@ -26,12 +26,13 @@ import ( "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8" "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) // SearchModeDenseVectorSimilarity calculate embedding similarity between dense_vector field and query // see: https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html#vector-functions -func SearchModeDenseVectorSimilarity(typ DenseVectorSimilarityType) SearchMode { +func SearchModeDenseVectorSimilarity(typ DenseVectorSimilarityType) es8.SearchMode { return &denseVectorSimilarity{script: denseVectorScriptMap[typ]} } @@ -54,7 +55,16 @@ type denseVectorSimilarity struct { script string } -func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) { +func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, + opts ...retriever.Option) (*search.Request, error) { + + options := retriever.GetCommonOptions(&retriever.Options{ + Index: &conf.Index, + TopK: &conf.TopK, + ScoreThreshold: conf.ScoreThreshold, + Embedding: conf.Embedding, + }, opts...) + var dq DenseVectorSimilarityQuery if err := json.Unmarshal([]byte(query), &dq); err != nil { return nil, fmt.Errorf("[BuildRequest][SearchModeDenseVectorSimilarity] parse query failed, %w", err) diff --git a/components/retriever/es8/search_mode/dense_vector_similarity_test.go b/components/retriever/es8/search_mode/dense_vector_similarity_test.go index ea13b57..648dd05 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity_test.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity_test.go @@ -28,6 +28,7 @@ import ( "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8" "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) @@ -66,13 +67,16 @@ func TestSearchModeDenseVectorSimilarity(t *testing.T) { sq, _ := dq.ToRetrieverQuery() PatchConvey("test embedding not provided", func() { - req, err := d.BuildRequest(ctx, sq, &retriever.Options{Embedding: nil}) + + conf := &es8.RetrieverConfig{} + req, err := d.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(nil)) convey.So(err, convey.ShouldBeError, "[BuildRequest][SearchModeDenseVectorSimilarity] embedding not provided") convey.So(req, convey.ShouldBeNil) }) PatchConvey("test vector size invalid", func() { - req, err := d.BuildRequest(ctx, sq, &retriever.Options{Embedding: mockEmbedding{size: 2, mockVector: []float64{1.1, 1.2}}}) + conf := &es8.RetrieverConfig{} + req, err := d.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(mockEmbedding{size: 2, mockVector: []float64{1.1, 1.2}})) convey.So(err, convey.ShouldBeError, "[BuildRequest][SearchModeDenseVectorSimilarity] vector size invalid, expect=1, got=2") convey.So(req, convey.ShouldBeNil) }) @@ -87,11 +91,12 @@ func TestSearchModeDenseVectorSimilarity(t *testing.T) { for typ, exp := range typ2Exp { similarity := &denseVectorSimilarity{script: denseVectorScriptMap[typ]} - req, err := similarity.BuildRequest(ctx, sq, &retriever.Options{ - Embedding: mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}, - TopK: of(10), - ScoreThreshold: of(1.1), - }) + + conf := &es8.RetrieverConfig{} + req, err := similarity.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}), + retriever.WithTopK(10), + retriever.WithScoreThreshold(1.1)) + convey.So(err, convey.ShouldBeNil) b, err := json.Marshal(req) convey.So(err, convey.ShouldBeNil) diff --git a/components/retriever/es8/search_mode/exact_match.go b/components/retriever/es8/search_mode/exact_match.go index d8041f5..8f5f731 100644 --- a/components/retriever/es8/search_mode/exact_match.go +++ b/components/retriever/es8/search_mode/exact_match.go @@ -24,16 +24,26 @@ import ( "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8" "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) -func SearchModeExactMatch() SearchMode { +func SearchModeExactMatch() es8.SearchMode { return &exactMatch{} } type exactMatch struct{} -func (e exactMatch) BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) { +func (e exactMatch) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, + opts ...retriever.Option) (*search.Request, error) { + + options := retriever.GetCommonOptions(&retriever.Options{ + Index: &conf.Index, + TopK: &conf.TopK, + ScoreThreshold: conf.ScoreThreshold, + Embedding: conf.Embedding, + }, opts...) + q := &types.Query{ Match: map[string]types.MatchQuery{ field_mapping.DocFieldNameContent: {Query: query}, diff --git a/components/retriever/es8/search_mode/raw_string.go b/components/retriever/es8/search_mode/raw_string.go index da6c00f..674d0a2 100644 --- a/components/retriever/es8/search_mode/raw_string.go +++ b/components/retriever/es8/search_mode/raw_string.go @@ -22,15 +22,19 @@ import ( "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" "github.com/cloudwego/eino/components/retriever" + + "github.com/cloudwego/eino-ext/components/retriever/es8" ) -func SearchModeRawStringRequest() SearchMode { +func SearchModeRawStringRequest() es8.SearchMode { return &rawString{} } type rawString struct{} -func (r rawString) BuildRequest(_ context.Context, query string, _ *retriever.Options) (*search.Request, error) { +func (r rawString) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, + opts ...retriever.Option) (*search.Request, error) { + req, err := search.NewRequest().FromJSON(query) if err != nil { return nil, err diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go index ab08596..7476bb0 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go @@ -26,13 +26,14 @@ import ( "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8" "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) // SearchModeSparseVectorTextExpansion convert the query text into a list of token-weight pairs, // which are then used in a query against a sparse vector // see: https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-text-expansion-query.html -func SearchModeSparseVectorTextExpansion(modelID string) SearchMode { +func SearchModeSparseVectorTextExpansion(modelID string) es8.SearchMode { return &sparseVectorTextExpansion{modelID} } @@ -55,7 +56,16 @@ type sparseVectorTextExpansion struct { modelID string } -func (s sparseVectorTextExpansion) BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) { +func (s sparseVectorTextExpansion) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, + opts ...retriever.Option) (*search.Request, error) { + + options := retriever.GetCommonOptions(&retriever.Options{ + Index: &conf.Index, + TopK: &conf.TopK, + ScoreThreshold: conf.ScoreThreshold, + Embedding: conf.Embedding, + }, opts...) + var sq SparseVectorTextExpansionQuery if err := json.Unmarshal([]byte(query), &sq); err != nil { return nil, fmt.Errorf("[BuildRequest][SearchModeSparseVectorTextExpansion] parse query failed, %w", err) diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go index eae9178..e59ba2e 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go @@ -27,6 +27,7 @@ import ( "github.com/cloudwego/eino/components/retriever" + "github.com/cloudwego/eino-ext/components/retriever/es8" "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) @@ -65,7 +66,12 @@ func TestSearchModeSparseVectorTextExpansion(t *testing.T) { } query, _ := sq.ToRetrieverQuery() - req, err := s.BuildRequest(ctx, query, &retriever.Options{TopK: of(10), ScoreThreshold: of(1.1)}) + + conf := &es8.RetrieverConfig{} + req, err := s.BuildRequest(ctx, conf, query, + retriever.WithTopK(10), + retriever.WithScoreThreshold(1.1)) + convey.So(err, convey.ShouldBeNil) convey.So(req, convey.ShouldNotBeNil) b, err := json.Marshal(req)