Skip to content

Commit

Permalink
feat: redefine es8 retriever's SearchMode
Browse files Browse the repository at this point in the history
  • Loading branch information
BytePender committed Jan 12, 2025
1 parent 1881930 commit e6a7c99
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 25 deletions.
4 changes: 4 additions & 0 deletions components/retriever/es8/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ type RetrieverConfig struct {
Embedding embedding.Embedder
}

type SearchMode interface { // nolint: byted_s_interface_name
BuildRequest(ctx context.Context, conf *RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error)
}

type Retriever struct {
client *elasticsearch.TypedClient
config *RetrieverConfig
Expand Down
13 changes: 11 additions & 2 deletions components/retriever/es8/search_mode/approximate.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ import (

"github.com/cloudwego/eino/components/retriever"

"github.com/cloudwego/eino-ext/components/retriever/es8"
"github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping"
)

// SearchModeApproximate retrieve with multiple approximate strategy (filter+knn+rrf)
// knn: https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html
// rrf: https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
func SearchModeApproximate(config *ApproximateConfig) SearchMode {
func SearchModeApproximate(config *ApproximateConfig) es8.SearchMode {
return &approximate{config}
}

Expand Down Expand Up @@ -85,7 +86,15 @@ type approximate struct {
config *ApproximateConfig
}

func (a *approximate) BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) {
func (a *approximate) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) {

options := retriever.GetCommonOptions(&retriever.Options{
Index: &conf.Index,
TopK: &conf.TopK,
ScoreThreshold: conf.ScoreThreshold,
Embedding: conf.Embedding,
}, opts...)

var appReq ApproximateQuery
if err := json.Unmarshal([]byte(query), &appReq); err != nil {
return nil, fmt.Errorf("[BuildRequest][SearchModeApproximate] parse query failed, %w", err)
Expand Down
17 changes: 10 additions & 7 deletions components/retriever/es8/search_mode/approximate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/cloudwego/eino/components/embedding"
"github.com/cloudwego/eino/components/retriever"

"github.com/cloudwego/eino-ext/components/retriever/es8"
"github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping"
)

Expand Down Expand Up @@ -78,7 +79,8 @@ func TestSearchModeApproximate(t *testing.T) {
sq, err := aq.ToRetrieverQuery()
convey.So(err, convey.ShouldBeNil)

req, err := a.BuildRequest(ctx, sq, &retriever.Options{Embedding: nil})
conf := &es8.RetrieverConfig{}
req, err := a.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(nil))
convey.So(err, convey.ShouldBeNil)
b, err := json.Marshal(req)
convey.So(err, convey.ShouldBeNil)
Expand All @@ -104,7 +106,8 @@ func TestSearchModeApproximate(t *testing.T) {

sq, err := aq.ToRetrieverQuery()
convey.So(err, convey.ShouldBeNil)
req, err := a.BuildRequest(ctx, sq, &retriever.Options{Embedding: &mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}})
conf := &es8.RetrieverConfig{}
req, err := a.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}))
convey.So(err, convey.ShouldBeNil)
b, err := json.Marshal(req)
convey.So(err, convey.ShouldBeNil)
Expand Down Expand Up @@ -136,11 +139,11 @@ func TestSearchModeApproximate(t *testing.T) {

sq, err := aq.ToRetrieverQuery()
convey.So(err, convey.ShouldBeNil)
req, err := a.BuildRequest(ctx, sq, &retriever.Options{
Embedding: &mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}},
TopK: of(10),
ScoreThreshold: of(1.1),
})

conf := &es8.RetrieverConfig{}
req, err := a.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}),
retriever.WithTopK(10),
retriever.WithScoreThreshold(1.1))
convey.So(err, convey.ShouldBeNil)
b, err := json.Marshal(req)
convey.So(err, convey.ShouldBeNil)
Expand Down
14 changes: 12 additions & 2 deletions components/retriever/es8/search_mode/dense_vector_similarity.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ import (

"github.com/cloudwego/eino/components/retriever"

"github.com/cloudwego/eino-ext/components/retriever/es8"
"github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping"
)

// SearchModeDenseVectorSimilarity calculate embedding similarity between dense_vector field and query
// see: https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html#vector-functions
func SearchModeDenseVectorSimilarity(typ DenseVectorSimilarityType) SearchMode {
func SearchModeDenseVectorSimilarity(typ DenseVectorSimilarityType) es8.SearchMode {
return &denseVectorSimilarity{script: denseVectorScriptMap[typ]}
}

Expand All @@ -54,7 +55,16 @@ type denseVectorSimilarity struct {
script string
}

func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) {
func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string,
opts ...retriever.Option) (*search.Request, error) {

options := retriever.GetCommonOptions(&retriever.Options{
Index: &conf.Index,
TopK: &conf.TopK,
ScoreThreshold: conf.ScoreThreshold,
Embedding: conf.Embedding,
}, opts...)

var dq DenseVectorSimilarityQuery
if err := json.Unmarshal([]byte(query), &dq); err != nil {
return nil, fmt.Errorf("[BuildRequest][SearchModeDenseVectorSimilarity] parse query failed, %w", err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

"github.com/cloudwego/eino/components/retriever"

"github.com/cloudwego/eino-ext/components/retriever/es8"
"github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping"
)

Expand Down Expand Up @@ -66,13 +67,16 @@ func TestSearchModeDenseVectorSimilarity(t *testing.T) {
sq, _ := dq.ToRetrieverQuery()

PatchConvey("test embedding not provided", func() {
req, err := d.BuildRequest(ctx, sq, &retriever.Options{Embedding: nil})

conf := &es8.RetrieverConfig{}
req, err := d.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(nil))
convey.So(err, convey.ShouldBeError, "[BuildRequest][SearchModeDenseVectorSimilarity] embedding not provided")
convey.So(req, convey.ShouldBeNil)
})

PatchConvey("test vector size invalid", func() {
req, err := d.BuildRequest(ctx, sq, &retriever.Options{Embedding: mockEmbedding{size: 2, mockVector: []float64{1.1, 1.2}}})
conf := &es8.RetrieverConfig{}
req, err := d.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(mockEmbedding{size: 2, mockVector: []float64{1.1, 1.2}}))
convey.So(err, convey.ShouldBeError, "[BuildRequest][SearchModeDenseVectorSimilarity] vector size invalid, expect=1, got=2")
convey.So(req, convey.ShouldBeNil)
})
Expand All @@ -87,11 +91,12 @@ func TestSearchModeDenseVectorSimilarity(t *testing.T) {

for typ, exp := range typ2Exp {
similarity := &denseVectorSimilarity{script: denseVectorScriptMap[typ]}
req, err := similarity.BuildRequest(ctx, sq, &retriever.Options{
Embedding: mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}},
TopK: of(10),
ScoreThreshold: of(1.1),
})

conf := &es8.RetrieverConfig{}
req, err := similarity.BuildRequest(ctx, conf, sq, retriever.WithEmbedding(&mockEmbedding{size: 1, mockVector: []float64{1.1, 1.2}}),
retriever.WithTopK(10),
retriever.WithScoreThreshold(1.1))

convey.So(err, convey.ShouldBeNil)
b, err := json.Marshal(req)
convey.So(err, convey.ShouldBeNil)
Expand Down
14 changes: 12 additions & 2 deletions components/retriever/es8/search_mode/exact_match.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,26 @@ import (

"github.com/cloudwego/eino/components/retriever"

"github.com/cloudwego/eino-ext/components/retriever/es8"
"github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping"
)

func SearchModeExactMatch() SearchMode {
func SearchModeExactMatch() es8.SearchMode {
return &exactMatch{}
}

type exactMatch struct{}

func (e exactMatch) BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) {
func (e exactMatch) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string,
opts ...retriever.Option) (*search.Request, error) {

options := retriever.GetCommonOptions(&retriever.Options{
Index: &conf.Index,
TopK: &conf.TopK,
ScoreThreshold: conf.ScoreThreshold,
Embedding: conf.Embedding,
}, opts...)

q := &types.Query{
Match: map[string]types.MatchQuery{
field_mapping.DocFieldNameContent: {Query: query},
Expand Down
8 changes: 6 additions & 2 deletions components/retriever/es8/search_mode/raw_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,19 @@ import (
"github.com/elastic/go-elasticsearch/v8/typedapi/core/search"

"github.com/cloudwego/eino/components/retriever"

"github.com/cloudwego/eino-ext/components/retriever/es8"
)

func SearchModeRawStringRequest() SearchMode {
func SearchModeRawStringRequest() es8.SearchMode {
return &rawString{}
}

type rawString struct{}

func (r rawString) BuildRequest(_ context.Context, query string, _ *retriever.Options) (*search.Request, error) {
func (r rawString) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string,
opts ...retriever.Option) (*search.Request, error) {

req, err := search.NewRequest().FromJSON(query)
if err != nil {
return nil, err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ import (

"github.com/cloudwego/eino/components/retriever"

"github.com/cloudwego/eino-ext/components/retriever/es8"
"github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping"
)

// SearchModeSparseVectorTextExpansion convert the query text into a list of token-weight pairs,
// which are then used in a query against a sparse vector
// see: https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-text-expansion-query.html
func SearchModeSparseVectorTextExpansion(modelID string) SearchMode {
func SearchModeSparseVectorTextExpansion(modelID string) es8.SearchMode {
return &sparseVectorTextExpansion{modelID}
}

Expand All @@ -55,7 +56,16 @@ type sparseVectorTextExpansion struct {
modelID string
}

func (s sparseVectorTextExpansion) BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) {
func (s sparseVectorTextExpansion) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string,
opts ...retriever.Option) (*search.Request, error) {

options := retriever.GetCommonOptions(&retriever.Options{
Index: &conf.Index,
TopK: &conf.TopK,
ScoreThreshold: conf.ScoreThreshold,
Embedding: conf.Embedding,
}, opts...)

var sq SparseVectorTextExpansionQuery
if err := json.Unmarshal([]byte(query), &sq); err != nil {
return nil, fmt.Errorf("[BuildRequest][SearchModeSparseVectorTextExpansion] parse query failed, %w", err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (

"github.com/cloudwego/eino/components/retriever"

"github.com/cloudwego/eino-ext/components/retriever/es8"
"github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping"
)

Expand Down Expand Up @@ -65,7 +66,12 @@ func TestSearchModeSparseVectorTextExpansion(t *testing.T) {
}

query, _ := sq.ToRetrieverQuery()
req, err := s.BuildRequest(ctx, query, &retriever.Options{TopK: of(10), ScoreThreshold: of(1.1)})

conf := &es8.RetrieverConfig{}
req, err := s.BuildRequest(ctx, conf, query,
retriever.WithTopK(10),
retriever.WithScoreThreshold(1.1))

convey.So(err, convey.ShouldBeNil)
convey.So(req, convey.ShouldNotBeNil)
b, err := json.Marshal(req)
Expand Down

0 comments on commit e6a7c99

Please sign in to comment.