Skip to content

Commit

Permalink
新增知识库检索
Browse files Browse the repository at this point in the history
  • Loading branch information
userpj committed Dec 24, 2024
1 parent d55e6fc commit 60d0d27
Show file tree
Hide file tree
Showing 12 changed files with 891 additions and 19 deletions.
36 changes: 36 additions & 0 deletions go/appbuilder/knowledge_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -814,3 +814,39 @@ func (t *KnowledgeBase) DescribeChunks(req DescribeChunksRequest) (DescribeChunk

return rsp, nil
}

func (t *KnowledgeBase) QueryKnowledgeBase(req QueryKnowledgeBaseRequest) (QueryKnowledgeBaseResponse, error) {
request := http.Request{}
header := t.sdkConfig.AuthHeaderV2()
serviceURL, err := t.sdkConfig.ServiceURLV2("/knowledgebases/query")
if err != nil {
return QueryKnowledgeBaseResponse{}, err
}
request.URL = serviceURL
request.Method = "POST"
header.Set("Content-Type", "application/json")
request.Header = header
data, _ := json.Marshal(req)
request.Body = NopCloser(bytes.NewReader(data))
t.sdkConfig.BuildCurlCommand(&request)
resp, err := t.client.Do(&request)
if err != nil {
return QueryKnowledgeBaseResponse{}, err
}
defer resp.Body.Close()
requestID, err := checkHTTPResponse(resp)
if err != nil {
return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err)
}
data, err = io.ReadAll(resp.Body)
if err != nil {
return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err)
}

rsp := QueryKnowledgeBaseResponse{}
if err := json.Unmarshal(data, &rsp); err != nil {
return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err)
}

return rsp, nil
}
86 changes: 86 additions & 0 deletions go/appbuilder/knowledge_base_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,89 @@ type DescribeChunksResponse struct {
NextMarker string `json:"nextMarker"`
MaxKeys int `json:"maxKeys"`
}

type MetadataFilter struct {
Operator string `json:"operator"`
Field string `json:"field"`
Value any `json:"value"` // 因为Value的类型可以是str或list[str],所以我们使用interface{}来表示任何类型
}

type MetadataFilters struct {
Filters []MetadataFilter `json:"filters"`
Condition string `json:"condition"`
}

type PreRankingConfig struct {
Bm25Weight float64 `json:"bm25_weight"`
VecWeight float64 `json:"vec_weight"`
Bm25B float64 `json:"bm25_b"`
Bm25K1 float64 `json:"bm25_k1"`
Bm25MaxScore float64 `json:"bm25_max_score"`
}

type ElasticSearchRetrieveConfig struct {
Name string `json:"name"`
Type string `json:"type"`
Threshold float64 `json:"threshold"`
Top int `json:"top"`
}

type RankingConfig struct {
Name string `json:"name"`
Type string `json:"type"`
Inputs []string `json:"inputs"`
ModelName string `json:"model_name"`
Top int `json:"top"`
}

type QueryPipelineConfig struct {
ID string `json:"id"`
Pipeline []any `json:"pipeline"`
}

type QueryKnowledgeBaseRequest struct {
Query string `json:"query"`
Type string `json:"type"`
Top int `json:"top"`
Skip int `json:"skip"`
KnowledgebaseIDs []string `json:"knowledgebase_ids"`
MetadataFileters MetadataFilters `json:"metadata_fileters"`
PipelineConfig QueryPipelineConfig `json:"pipeline_config"`
}

type RowLine struct {
Key string `json:"key"`
Index int `json:"index"`
Value string `json:"value"`
EnableIndexing bool `json:"enable_indexing"`
EnableResponse bool `json:"enable_response"`
}

type ChunkLocation struct {
PageNum []int `json:"paget_num"`
Box [][]int `json:"box"`
}

type Chunk struct {
ChunkID string `json:"chunk_id"`
KnowledgebaseID string `json:"knowledgebase_id"`
DocumentID string `json:"document_id"`
DocumentName string `json:"document_name"`
Meta map[string]any `json:"meta"`
Type string `json:"type"`
Content string `json:"content"`
CreateTime string `json:"create_time"`
UpdateTime string `json:"update_time"`
RetrievalScore float64 `json:"retrieval_score"`
RankScore float64 `json:"rank_score"`
Locations ChunkLocation `json:"locations"`
Children []Chunk `json:"children"`
}

type QueryKnowledgeBaseResponse struct {
RequestId string `json:"requestId"`
Code string `json:"code"`
Message string `json:"message"`
Chunks []Chunk `json:"chunks"`
TotalCount int `json:"total_count"`
}
92 changes: 92 additions & 0 deletions go/appbuilder/knowledge_base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package appbuilder

import (
"bytes"
"encoding/json"
"fmt"
"os"
"strings"
Expand Down Expand Up @@ -1478,3 +1479,94 @@ func TestChunk(t *testing.T) {
t.Logf("%s========== OK: %s ==========%s", "\033[32m", t.Name(), "\033[0m")
}
}

func TestQueryKnowledgeBase(t *testing.T) {
t.Parallel() // 并发运行
var logBuffer bytes.Buffer

os.Setenv("APPBUILDER_LOGLEVEL", "DEBUG")

log := func(format string, args ...any) {
fmt.Fprintf(&logBuffer, format+"\n", args...)
}

config, err := NewSDKConfig("", os.Getenv(SecretKey))
if err != nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("new http client config failed: %v", err)
}

client, err := NewKnowledgeBase(config)
if err != nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("new Knowledge base instance failed")
}

jsonStr := `
{
"type": "fulltext",
"query": "民法典第三条",
"knowledgebase_ids": [
"70c6375a-1595-41f2-9a3b-e81bc9060b7f"
],
"metadata_filters": {
"filters": [
],
"condition": "or"
},
"pipeline_config": {
"id": "pipeline_001",
"pipeline": [
{
"name": "step1",
"type": "elastic_search",
"threshold": 0.1,
"top": 400,
"pre_ranking": {
"bm25_weight": 0.25,
"vec_weight": 0.75,
"bm25_b": 0.75,
"bm25_k1": 1.5,
"bm25_max_score": 50
}
},
{
"name": "step2",
"type": "ranking",
"inputs": ["step1"],
"model_name": "ranker-v1",
"top": 20
}
]
},
"top": 5,
"skip": 0
}`

var request QueryKnowledgeBaseRequest
err = json.Unmarshal([]byte(jsonStr), &request)
if err != nil {
fmt.Println("unmarshal tool error:", err)
}

queryKnowledgeBaseResponse, err := client.QueryKnowledgeBase(request)
if err != nil {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("create chunk failed: %v", err)
}
chunkID := queryKnowledgeBaseResponse.Chunks[0].ChunkID
log("query got chunk ID: %s", chunkID)
if len(chunkID) == 0 {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
t.Fatalf("query knowledge base failed: %v", err)
}

// 如果测试失败,则输出缓冲区中的日志
if t.Failed() {
t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m")
fmt.Println(logBuffer.String())
} else { // else 紧跟在右大括号后面
// 测试通过,打印文件名和测试函数名
t.Logf("%s========== OK: %s ==========%s", "\033[32m", t.Name(), "\033[0m")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ public class AppBuilderConfig {
public static final String CHUNKS_DESCRIBE_URL = "/knowledgeBase?Action=DescribeChunks";
// 删除切片
public static final String CHUNK_DELETE_URL = "/knowledgeBase?Action=DeleteChunk";
// 知识库检索
public static final String QUERY_KNOWLEDGEBASE_URL = "/knowledgebases/query";


// 运行rag
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -691,4 +691,17 @@ public ChunksDescribeResponse describeChunks(String documentId, String marker, I
ChunksDescribeResponse respBody = response.getBody();
return respBody;
}

public QueryKnowledgeBaseResponse queryKnowledgeBaseResponse(QueryKnowledgeBaseRequest request)
throws IOException, AppBuilderServerException {
String url = AppBuilderConfig.QUERY_KNOWLEDGEBASE_URL;

String jsonBody = JsonUtils.serialize(request);
ClassicHttpRequest postRequest = httpClient.createPostRequestV2(url,
new StringEntity(jsonBody, StandardCharsets.UTF_8));
postRequest.setHeader("Content-Type", "application/json");
HttpResponse<QueryKnowledgeBaseResponse> response = httpClient.execute(postRequest, QueryKnowledgeBaseResponse.class);
QueryKnowledgeBaseResponse respBody = response.getBody();
return respBody;
}
}
Loading

0 comments on commit 60d0d27

Please sign in to comment.