Skip to content

Commit

Permalink
feat: to nil ptr when zero value
Browse files Browse the repository at this point in the history
  • Loading branch information
BytePender committed Jan 12, 2025
1 parent e6a7c99 commit b9aee47
Show file tree
Hide file tree
Showing 11 changed files with 55 additions and 81 deletions.
7 changes: 3 additions & 4 deletions components/retriever/es8/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (

"github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping"
"github.com/cloudwego/eino-ext/components/retriever/es8/internal"
"github.com/cloudwego/eino-ext/components/retriever/es8/search_mode"
)

type RetrieverConfig struct {
Expand All @@ -47,12 +46,12 @@ type RetrieverConfig struct {
// use search_mode.SearchModeDenseVectorSimilarity with search_mode.DenseVectorSimilarityQuery
// use search_mode.SearchModeSparseVectorTextExpansion with search_mode.SparseVectorTextExpansionQuery
// use search_mode.SearchModeRawStringRequest with json search request
SearchMode search_mode.SearchMode `json:"search_mode"`
SearchMode SearchMode `json:"search_mode"`
// Embedding vectorization method, must provide when SearchMode needed
Embedding embedding.Embedder
}

type SearchMode interface { // nolint: byted_s_interface_name
type SearchMode interface {
BuildRequest(ctx context.Context, conf *RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error)
}

Expand Down Expand Up @@ -97,7 +96,7 @@ func (r *Retriever) Retrieve(ctx context.Context, query string, opts ...retrieve
ScoreThreshold: options.ScoreThreshold,
})

req, err := r.config.SearchMode.BuildRequest(ctx, query, options)
req, err := r.config.SearchMode.BuildRequest(ctx, r.config, query, opts...)
if err != nil {
return nil, err
}
Expand Down
18 changes: 9 additions & 9 deletions components/retriever/es8/search_mode/approximate.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* You may obtain a copy ptrWithoutZero the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
Expand Down Expand Up @@ -46,7 +46,7 @@ type ApproximateConfig struct {
// RrfRankConstant determines how much influence documents in
// individual result sets per query have over the final ranked result set
RrfRankConstant *int64
// RrfWindowSize determines the size of the individual result sets per query
// RrfWindowSize determines the size ptrWithoutZero the individual result sets per query
RrfWindowSize *int64
}

Expand All @@ -57,16 +57,16 @@ type ApproximateQuery struct {
// QueryVectorBuilderModelID the query vector builder model id
// see: https://www.elastic.co/guide/en/machine-learning/8.16/ml-nlp-text-emb-vector-search-example.html
QueryVectorBuilderModelID *string `json:"query_vector_builder_model_id,omitempty"`
// Boost Floating point number used to decrease or increase the relevance scores of the query.
// Boost values are relative to the default value of 1.0.
// Boost Floating point number used to decrease or increase the relevance scores ptrWithoutZero the query.
// Boost values are relative to the default value ptrWithoutZero 1.0.
// A boost value between 0 and 1.0 decreases the relevance score.
// A value greater than 1.0 increases the relevance score.
Boost *float32 `json:"boost,omitempty"`
// Filters for the kNN search query
Filters []types.Query `json:"filters,omitempty"`
// K The final number of nearest neighbors to return as top hits
// K The final number ptrWithoutZero nearest neighbors to return as top hits
K *int `json:"k,omitempty"`
// NumCandidates The number of nearest neighbor candidates to consider per shard
// NumCandidates The number ptrWithoutZero nearest neighbor candidates to consider per shard
NumCandidates *int `json:"num_candidates,omitempty"`
// Similarity The minimum similarity for a vector to be considered a match
Similarity *float32 `json:"similarity,omitempty"`
Expand All @@ -89,8 +89,8 @@ type approximate struct {
func (a *approximate) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig, query string, opts ...retriever.Option) (*search.Request, error) {

options := retriever.GetCommonOptions(&retriever.Options{
Index: &conf.Index,
TopK: &conf.TopK,
Index: ptrWithoutZero(conf.Index),
TopK: ptrWithoutZero(conf.TopK),
ScoreThreshold: conf.ScoreThreshold,
Embedding: conf.Embedding,
}, opts...)
Expand Down Expand Up @@ -159,7 +159,7 @@ func (a *approximate) BuildRequest(ctx context.Context, conf *es8.RetrieverConfi
}

if options.ScoreThreshold != nil {
req.MinScore = (*types.Float64)(of(*options.ScoreThreshold))
req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold))
}

return req, nil
Expand Down
40 changes: 20 additions & 20 deletions components/retriever/es8/search_mode/approximate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* You may obtain a copy ptrWithoutZero the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
Expand Down Expand Up @@ -44,10 +44,10 @@ func TestSearchModeApproximate(t *testing.T) {
Filters: []types.Query{
{Match: map[string]types.MatchQuery{"label": {Query: "good"}}},
},
Boost: of(float32(1.0)),
K: of(10),
NumCandidates: of(100),
Similarity: of(float32(0.5)),
Boost: ptrWithoutZero(float32(1.0)),
K: ptrWithoutZero(10),
NumCandidates: ptrWithoutZero(100),
Similarity: ptrWithoutZero(float32(0.5)),
}

sq, err := aq.ToRetrieverQuery()
Expand All @@ -66,14 +66,14 @@ func TestSearchModeApproximate(t *testing.T) {
FieldName: field_mapping.DocFieldNameContent,
Value: "content",
},
QueryVectorBuilderModelID: of("mock_model"),
QueryVectorBuilderModelID: ptrWithoutZero("mock_model"),
Filters: []types.Query{
{Match: map[string]types.MatchQuery{"label": {Query: "good"}}},
},
Boost: of(float32(1.0)),
K: of(10),
NumCandidates: of(100),
Similarity: of(float32(0.5)),
Boost: ptrWithoutZero(float32(1.0)),
K: ptrWithoutZero(10),
NumCandidates: ptrWithoutZero(100),
Similarity: ptrWithoutZero(float32(0.5)),
}

sq, err := aq.ToRetrieverQuery()
Expand All @@ -98,10 +98,10 @@ func TestSearchModeApproximate(t *testing.T) {
Filters: []types.Query{
{Match: map[string]types.MatchQuery{"label": {Query: "good"}}},
},
Boost: of(float32(1.0)),
K: of(10),
NumCandidates: of(100),
Similarity: of(float32(0.5)),
Boost: ptrWithoutZero(float32(1.0)),
K: ptrWithoutZero(10),
NumCandidates: ptrWithoutZero(100),
Similarity: ptrWithoutZero(float32(0.5)),
}

sq, err := aq.ToRetrieverQuery()
Expand All @@ -118,8 +118,8 @@ func TestSearchModeApproximate(t *testing.T) {
a := &approximate{config: &ApproximateConfig{
Hybrid: true,
Rrf: true,
RrfRankConstant: of(int64(10)),
RrfWindowSize: of(int64(5)),
RrfRankConstant: ptrWithoutZero(int64(10)),
RrfWindowSize: ptrWithoutZero(int64(5)),
}}

aq := &ApproximateQuery{
Expand All @@ -131,10 +131,10 @@ func TestSearchModeApproximate(t *testing.T) {
Filters: []types.Query{
{Match: map[string]types.MatchQuery{"label": {Query: "good"}}},
},
Boost: of(float32(1.0)),
K: of(10),
NumCandidates: of(100),
Similarity: of(float32(0.5)),
Boost: ptrWithoutZero(float32(1.0)),
K: ptrWithoutZero(10),
NumCandidates: ptrWithoutZero(100),
Similarity: ptrWithoutZero(float32(0.5)),
}

sq, err := aq.ToRetrieverQuery()
Expand Down
10 changes: 5 additions & 5 deletions components/retriever/es8/search_mode/dense_vector_similarity.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* You may obtain a copy ptrWithoutZero the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
Expand Down Expand Up @@ -59,8 +59,8 @@ func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.Retr
opts ...retriever.Option) (*search.Request, error) {

options := retriever.GetCommonOptions(&retriever.Options{
Index: &conf.Index,
TopK: &conf.TopK,
Index: ptrWithoutZero(conf.Index),
TopK: ptrWithoutZero(conf.TopK),
ScoreThreshold: conf.ScoreThreshold,
Embedding: conf.Embedding,
}, opts...)
Expand Down Expand Up @@ -92,7 +92,7 @@ func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.Retr
q := &types.Query{
ScriptScore: &types.ScriptScoreQuery{
Script: types.Script{
Source: of(fmt.Sprintf(d.script, dq.FieldKV.FieldNameVector)),
Source: ptrWithoutZero(fmt.Sprintf(d.script, dq.FieldKV.FieldNameVector)),
Params: map[string]json.RawMessage{"embedding": vb},
},
},
Expand All @@ -110,7 +110,7 @@ func (d *denseVectorSimilarity) BuildRequest(ctx context.Context, conf *es8.Retr

req := &search.Request{Query: q, Size: options.TopK}
if options.ScoreThreshold != nil {
req.MinScore = (*types.Float64)(of(*options.ScoreThreshold))
req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold))
}

return req, nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* You may obtain a copy ptrWithoutZero the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
Expand Down
8 changes: 4 additions & 4 deletions components/retriever/es8/search_mode/exact_match.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* You may obtain a copy ptrWithoutZero the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
Expand Down Expand Up @@ -38,8 +38,8 @@ func (e exactMatch) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig,
opts ...retriever.Option) (*search.Request, error) {

options := retriever.GetCommonOptions(&retriever.Options{
Index: &conf.Index,
TopK: &conf.TopK,
Index: ptrWithoutZero(conf.Index),
TopK: ptrWithoutZero(conf.TopK),
ScoreThreshold: conf.ScoreThreshold,
Embedding: conf.Embedding,
}, opts...)
Expand All @@ -52,7 +52,7 @@ func (e exactMatch) BuildRequest(ctx context.Context, conf *es8.RetrieverConfig,

req := &search.Request{Query: q, Size: options.TopK}
if options.ScoreThreshold != nil {
req.MinScore = (*types.Float64)(of(*options.ScoreThreshold))
req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold))
}

return req, nil
Expand Down
29 changes: 0 additions & 29 deletions components/retriever/es8/search_mode/interface.go

This file was deleted.

2 changes: 1 addition & 1 deletion components/retriever/es8/search_mode/raw_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* You may obtain a copy ptrWithoutZero the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* You may obtain a copy ptrWithoutZero the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
Expand All @@ -30,7 +30,7 @@ import (
"github.com/cloudwego/eino-ext/components/retriever/es8/field_mapping"
)

// SearchModeSparseVectorTextExpansion convert the query text into a list of token-weight pairs,
// SearchModeSparseVectorTextExpansion convert the query text into a list ptrWithoutZero token-weight pairs,
// which are then used in a query against a sparse vector
// see: https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-text-expansion-query.html
func SearchModeSparseVectorTextExpansion(modelID string) es8.SearchMode {
Expand Down Expand Up @@ -60,8 +60,8 @@ func (s sparseVectorTextExpansion) BuildRequest(ctx context.Context, conf *es8.R
opts ...retriever.Option) (*search.Request, error) {

options := retriever.GetCommonOptions(&retriever.Options{
Index: &conf.Index,
TopK: &conf.TopK,
Index: ptrWithoutZero(conf.Index),
TopK: ptrWithoutZero(conf.TopK),
ScoreThreshold: conf.ScoreThreshold,
Embedding: conf.Embedding,
}, opts...)
Expand All @@ -88,7 +88,7 @@ func (s sparseVectorTextExpansion) BuildRequest(ctx context.Context, conf *es8.R

req := &search.Request{Query: q, Size: options.TopK}
if options.ScoreThreshold != nil {
req.MinScore = (*types.Float64)(of(*options.ScoreThreshold))
req.MinScore = (*types.Float64)(ptrWithoutZero(*options.ScoreThreshold))
}

return req, nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* You may obtain a copy ptrWithoutZero the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
Expand Down
8 changes: 6 additions & 2 deletions components/retriever/es8/search_mode/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* You may obtain a copy ptrWithoutZero the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
Expand Down Expand Up @@ -47,6 +47,10 @@ func f64To32(f64 []float64) []float32 {
return f32
}

func of[T any](v T) *T {
func ptrWithoutZero[T string | int64 | int | float64 | float32](v T) *T {
var zero T
if zero == v {
return nil
}
return &v
}

0 comments on commit b9aee47

Please sign in to comment.