diff --git a/components/retriever/es8/retriever.go b/components/retriever/es8/retriever.go index 82063ff..b441f08 100644 --- a/components/retriever/es8/retriever.go +++ b/components/retriever/es8/retriever.go @@ -31,7 +31,6 @@ import ( "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" "github.com/cloudwego/eino-ext/components/retriever/es8/internal" - "github.com/cloudwego/eino-ext/components/retriever/es8/search_mode" ) type RetrieverConfig struct { @@ -47,12 +46,12 @@ type RetrieverConfig struct { // use search_mode.SearchModeDenseVectorSimilarity with search_mode.DenseVectorSimilarityQuery // use search_mode.SearchModeSparseVectorTextExpansion with search_mode.SparseVectorTextExpansionQuery // use search_mode.SearchModeRawStringRequest with json search request - SearchMode search_mode.SearchMode `json:"search_mode"` + SearchMode SearchMode `json:"search_mode"` // Embedding vectorization method, must provide when SearchMode needed Embedding embedding.Embedder } -type SearchMode interface { // nolint: byted_s_interface_name +type SearchMode interface { BuildRequest(ctx context.Context, conf *RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) } @@ -97,7 +96,7 @@ func (r *Retriever) Retrieve(ctx context.Context, query string, opts ...retrieve ScoreThreshold: options.ScoreThreshold, }) - req, err := r.config.SearchMode.BuildRequest(ctx, query, options) + req, err := r.config.SearchMode.BuildRequest(ctx, r.config, query, opts...) if err != nil { return nil, err } diff --git a/components/retriever/es8/search_mode/approximate.go b/components/retriever/es8/search_mode/approximate.go index 163c332..0474b55 100644 --- a/components/retriever/es8/search_mode/approximate.go +++ b/components/retriever/es8/search_mode/approximate.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -46,7 +46,7 @@ type ApproximateConfig struct { // RrfRankConstant determines how much influence documents in // individual result sets per query have over the final ranked result set RrfRankConstant *int64 - // RrfWindowSize determines the size of the individual result sets per query + // RrfWindowSize determines the size ptrWithoutZero the individual result sets per query RrfWindowSize *int64 } @@ -57,16 +57,16 @@ type ApproximateQuery struct { // QueryVectorBuilderModelID the query vector builder model id // see: https://www.elastic.co/guide/en/machine-learning/8.16/ml-nlp-text-emb-vector-search-example.html QueryVectorBuilderModelID *string `json:"query_vector_builder_model_id,omitempty"` - // Boost Floating point number used to decrease or increase the relevance scores of the query. - // Boost values are relative to the default value of 1.0. + // Boost Floating point number used to decrease or increase the relevance scores ptrWithoutZero the query. + // Boost values are relative to the default value ptrWithoutZero 1.0. // A boost value between 0 and 1.0 decreases the relevance score. // A value greater than 1.0 increases the relevance score. Boost *float32 `json:"boost,omitempty"` // Filters for the kNN search query Filters []types.Query `json:"filters,omitempty"` - // K The final number of nearest neighbors to return as top hits + // K The final number ptrWithoutZero nearest neighbors to return as top hits K *int `json:"k,omitempty"` - // NumCandidates The number of nearest neighbor candidates to consider per shard + // NumCandidates The number ptrWithoutZero nearest neighbor candidates to consider per shard NumCandidates *int `json:"num_candidates,omitempty"` // Similarity The minimum similarity for a vector to be considered a match Similarity *float32 `json:"similarity,omitempty"` @@ -89,8 +89,8 @@ type approximate struct { func (a *approximate) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) { options := retriever.GetCommonOptions(&retriever.Options{ - Index: &conf.Index, - TopK: &conf.TopK, + Index: ptrWithoutZero(conf.Index), + TopK: ptrWithoutZero(conf.TopK), ScoreThreshold: conf.ScoreThreshold, Embedding: conf.Embedding, }, opts...) @@ -159,7 +159,7 @@ func (a *approximate) BuildRequest(ctx context.Context, conf *es8.RetrieverConfi } if options.ScoreThreshold != nil { - req.MinScore = (*types.Float64)(of(*options.ScoreThreshold)) + req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold)) } return req, nil diff --git a/components/retriever/es8/search_mode/approximate_test.go b/components/retriever/es8/search_mode/approximate_test.go index 634a4ba..4d79280 100644 --- a/components/retriever/es8/search_mode/approximate_test.go +++ b/components/retriever/es8/search_mode/approximate_test.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -44,10 +44,10 @@ func TestSearchModeApproximate(t *testing.T) { Filters: []types.Query{ {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, }, - Boost: of(float32(1.0)), - K: of(10), - NumCandidates: of(100), - Similarity: of(float32(0.5)), + Boost: ptrWithoutZero(float32(1.0)), + K: ptrWithoutZero(10), + NumCandidates: ptrWithoutZero(100), + Similarity: ptrWithoutZero(float32(0.5)), } sq, err := aq.ToRetrieverQuery() @@ -66,14 +66,14 @@ func TestSearchModeApproximate(t *testing.T) { FieldName: field_mapping.DocFieldNameContent, Value: "content", }, - QueryVectorBuilderModelID: of("mock_model"), + QueryVectorBuilderModelID: ptrWithoutZero("mock_model"), Filters: []types.Query{ {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, }, - Boost: of(float32(1.0)), - K: of(10), - NumCandidates: of(100), - Similarity: of(float32(0.5)), + Boost: ptrWithoutZero(float32(1.0)), + K: ptrWithoutZero(10), + NumCandidates: ptrWithoutZero(100), + Similarity: ptrWithoutZero(float32(0.5)), } sq, err := aq.ToRetrieverQuery() @@ -98,10 +98,10 @@ func TestSearchModeApproximate(t *testing.T) { Filters: []types.Query{ {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, }, - Boost: of(float32(1.0)), - K: of(10), - NumCandidates: of(100), - Similarity: of(float32(0.5)), + Boost: ptrWithoutZero(float32(1.0)), + K: ptrWithoutZero(10), + NumCandidates: ptrWithoutZero(100), + Similarity: ptrWithoutZero(float32(0.5)), } sq, err := aq.ToRetrieverQuery() @@ -118,8 +118,8 @@ func TestSearchModeApproximate(t *testing.T) { a := &approximate{config: &ApproximateConfig{ Hybrid: true, Rrf: true, - RrfRankConstant: of(int64(10)), - RrfWindowSize: of(int64(5)), + RrfRankConstant: ptrWithoutZero(int64(10)), + RrfWindowSize: ptrWithoutZero(int64(5)), }} aq := &ApproximateQuery{ @@ -131,10 +131,10 @@ func TestSearchModeApproximate(t *testing.T) { Filters: []types.Query{ {Match: map[string]types.MatchQuery{"label": {Query: "good"}}}, }, - Boost: of(float32(1.0)), - K: of(10), - NumCandidates: of(100), - Similarity: of(float32(0.5)), + Boost: ptrWithoutZero(float32(1.0)), + K: ptrWithoutZero(10), + NumCandidates: ptrWithoutZero(100), + Similarity: ptrWithoutZero(float32(0.5)), } sq, err := aq.ToRetrieverQuery() diff --git a/components/retriever/es8/search_mode/dense_vector_similarity.go b/components/retriever/es8/search_mode/dense_vector_similarity.go index 07b54d5..40c4333 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -59,8 +59,8 @@ func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.Retr opts ...retriever.Option) (*search.Request, error) { options := retriever.GetCommonOptions(&retriever.Options{ - Index: &conf.Index, - TopK: &conf.TopK, + Index: ptrWithoutZero(conf.Index), + TopK: ptrWithoutZero(conf.TopK), ScoreThreshold: conf.ScoreThreshold, Embedding: conf.Embedding, }, opts...) @@ -92,7 +92,7 @@ func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.Retr q := &types.Query{ ScriptScore: &types.ScriptScoreQuery{ Script: types.Script{ - Source: of(fmt.Sprintf(d.script, dq.FieldKV.FieldNameVector)), + Source: ptrWithoutZero(fmt.Sprintf(d.script, dq.FieldKV.FieldNameVector)), Params: map[string]json.RawMessage{"embedding": vb}, }, }, @@ -110,7 +110,7 @@ func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.Retr req := &search.Request{Query: q, Size: options.TopK} if options.ScoreThreshold != nil { - req.MinScore = (*types.Float64)(of(*options.ScoreThreshold)) + req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold)) } return req, nil diff --git a/components/retriever/es8/search_mode/dense_vector_similarity_test.go b/components/retriever/es8/search_mode/dense_vector_similarity_test.go index 648dd05..72af92b 100644 --- a/components/retriever/es8/search_mode/dense_vector_similarity_test.go +++ b/components/retriever/es8/search_mode/dense_vector_similarity_test.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * diff --git a/components/retriever/es8/search_mode/exact_match.go b/components/retriever/es8/search_mode/exact_match.go index 8f5f731..20e178c 100644 --- a/components/retriever/es8/search_mode/exact_match.go +++ b/components/retriever/es8/search_mode/exact_match.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -38,8 +38,8 @@ func (e exactMatch) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, opts ...retriever.Option) (*search.Request, error) { options := retriever.GetCommonOptions(&retriever.Options{ - Index: &conf.Index, - TopK: &conf.TopK, + Index: ptrWithoutZero(conf.Index), + TopK: ptrWithoutZero(conf.TopK), ScoreThreshold: conf.ScoreThreshold, Embedding: conf.Embedding, }, opts...) @@ -52,7 +52,7 @@ func (e exactMatch) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, req := &search.Request{Query: q, Size: options.TopK} if options.ScoreThreshold != nil { - req.MinScore = (*types.Float64)(of(*options.ScoreThreshold)) + req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold)) } return req, nil diff --git a/components/retriever/es8/search_mode/interface.go b/components/retriever/es8/search_mode/interface.go deleted file mode 100644 index ab4ef5a..0000000 --- a/components/retriever/es8/search_mode/interface.go +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package search_mode - -import ( - "context" - - "github.com/elastic/go-elasticsearch/v8/typedapi/core/search" - - "github.com/cloudwego/eino/components/retriever" -) - -type SearchMode interface { // nolint: byted_s_interface_name - BuildRequest(ctx context.Context, query string, options *retriever.Options) (*search.Request, error) -} diff --git a/components/retriever/es8/search_mode/raw_string.go b/components/retriever/es8/search_mode/raw_string.go index 674d0a2..855840c 100644 --- a/components/retriever/es8/search_mode/raw_string.go +++ b/components/retriever/es8/search_mode/raw_string.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go index 7476bb0..f4c2c42 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -30,7 +30,7 @@ import ( "github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping" ) -// SearchModeSparseVectorTextExpansion convert the query text into a list of token-weight pairs, +// SearchModeSparseVectorTextExpansion convert the query text into a list ptrWithoutZero token-weight pairs, // which are then used in a query against a sparse vector // see: https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-text-expansion-query.html func SearchModeSparseVectorTextExpansion(modelID string) es8.SearchMode { @@ -60,8 +60,8 @@ func (s sparseVectorTextExpansion) BuildRequest(ctx context.Context, conf *es8.R opts ...retriever.Option) (*search.Request, error) { options := retriever.GetCommonOptions(&retriever.Options{ - Index: &conf.Index, - TopK: &conf.TopK, + Index: ptrWithoutZero(conf.Index), + TopK: ptrWithoutZero(conf.TopK), ScoreThreshold: conf.ScoreThreshold, Embedding: conf.Embedding, }, opts...) @@ -88,7 +88,7 @@ func (s sparseVectorTextExpansion) BuildRequest(ctx context.Context, conf *es8.R req := &search.Request{Query: q, Size: options.TopK} if options.ScoreThreshold != nil { - req.MinScore = (*types.Float64)(of(*options.ScoreThreshold)) + req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold)) } return req, nil diff --git a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go index e59ba2e..b3f5826 100644 --- a/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go +++ b/components/retriever/es8/search_mode/sparse_vector_text_expansion_test.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * diff --git a/components/retriever/es8/search_mode/utils.go b/components/retriever/es8/search_mode/utils.go index 7501ada..b69ed15 100644 --- a/components/retriever/es8/search_mode/utils.go +++ b/components/retriever/es8/search_mode/utils.go @@ -3,7 +3,7 @@ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * You may obtain a copy ptrWithoutZero the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -47,6 +47,10 @@ func f64To32(f64 []float64) []float32 { return f32 } -func of[T any](v T) *T { +func ptrWithoutZero[T string | int64 | int | float64 | float32](v T) *T { + var zero T + if zero == v { + return nil + } return &v }