diff --git a/go/appbuilder/knowledge_base.go b/go/appbuilder/knowledge_base.go index f7ef3ddf8..55b0280ea 100644 --- a/go/appbuilder/knowledge_base.go +++ b/go/appbuilder/knowledge_base.go @@ -814,3 +814,39 @@ func (t *KnowledgeBase) DescribeChunks(req DescribeChunksRequest) (DescribeChunk return rsp, nil } + +func (t *KnowledgeBase) QueryKnowledgeBase(req QueryKnowledgeBaseRequest) (QueryKnowledgeBaseResponse, error) { + request := http.Request{} + header := t.sdkConfig.AuthHeaderV2() + serviceURL, err := t.sdkConfig.ServiceURLV2("/knowledgebases/query") + if err != nil { + return QueryKnowledgeBaseResponse{}, err + } + request.URL = serviceURL + request.Method = "POST" + header.Set("Content-Type", "application/json") + request.Header = header + data, _ := json.Marshal(req) + request.Body = NopCloser(bytes.NewReader(data)) + t.sdkConfig.BuildCurlCommand(&request) + resp, err := t.client.Do(&request) + if err != nil { + return QueryKnowledgeBaseResponse{}, err + } + defer resp.Body.Close() + requestID, err := checkHTTPResponse(resp) + if err != nil { + return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err) + } + data, err = io.ReadAll(resp.Body) + if err != nil { + return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err) + } + + rsp := QueryKnowledgeBaseResponse{} + if err := json.Unmarshal(data, &rsp); err != nil { + return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err) + } + + return rsp, nil +} diff --git a/go/appbuilder/knowledge_base_data.go b/go/appbuilder/knowledge_base_data.go index 4e9ab6d14..888939f0b 100644 --- a/go/appbuilder/knowledge_base_data.go +++ b/go/appbuilder/knowledge_base_data.go @@ -245,3 +245,89 @@ type DescribeChunksResponse struct { NextMarker string `json:"nextMarker"` MaxKeys int `json:"maxKeys"` } + +type MetadataFilter struct { + Operator string `json:"operator"` + Field string `json:"field"` + Value any `json:"value"` // 因为Value的类型可以是str或list[str],所以我们使用interface{}来表示任何类型 +} + +type MetadataFilters struct { + Filters []MetadataFilter `json:"filters"` + Condition string `json:"condition"` +} + +type PreRankingConfig struct { + Bm25Weight float64 `json:"bm25_weight"` + VecWeight float64 `json:"vec_weight"` + Bm25B float64 `json:"bm25_b"` + Bm25K1 float64 `json:"bm25_k1"` + Bm25MaxScore float64 `json:"bm25_max_score"` +} + +type ElasticSearchRetrieveConfig struct { + Name string `json:"name"` + Type string `json:"type"` + Threshold float64 `json:"threshold"` + Top int `json:"top"` +} + +type RankingConfig struct { + Name string `json:"name"` + Type string `json:"type"` + Inputs []string `json:"inputs"` + ModelName string `json:"model_name"` + Top int `json:"top"` +} + +type QueryPipelineConfig struct { + ID string `json:"id"` + Pipeline []any `json:"pipeline"` +} + +type QueryKnowledgeBaseRequest struct { + Query string `json:"query"` + Type string `json:"type"` + Top int `json:"top"` + Skip int `json:"skip"` + KnowledgebaseIDs []string `json:"knowledgebase_ids"` + MetadataFileters MetadataFilters `json:"metadata_fileters"` + PipelineConfig QueryPipelineConfig `json:"pipeline_config"` +} + +type RowLine struct { + Key string `json:"key"` + Index int `json:"index"` + Value string `json:"value"` + EnableIndexing bool `json:"enable_indexing"` + EnableResponse bool `json:"enable_response"` +} + +type ChunkLocation struct { + PageNum []int `json:"paget_num"` + Box [][]int `json:"box"` +} + +type Chunk struct { + ChunkID string `json:"chunk_id"` + KnowledgebaseID string `json:"knowledgebase_id"` + DocumentID string `json:"document_id"` + DocumentName string `json:"document_name"` + Meta map[string]any `json:"meta"` + Type string `json:"type"` + Content string `json:"content"` + CreateTime string `json:"create_time"` + UpdateTime string `json:"update_time"` + RetrievalScore float64 `json:"retrieval_score"` + RankScore float64 `json:"rank_score"` + Locations ChunkLocation `json:"locations"` + Children []Chunk `json:"children"` +} + +type QueryKnowledgeBaseResponse struct { + RequestId string `json:"requestId"` + Code string `json:"code"` + Message string `json:"message"` + Chunks []Chunk `json:"chunks"` + TotalCount int `json:"total_count"` +} diff --git a/go/appbuilder/knowledge_base_test.go b/go/appbuilder/knowledge_base_test.go index 825c3d209..5a6ef52b6 100644 --- a/go/appbuilder/knowledge_base_test.go +++ b/go/appbuilder/knowledge_base_test.go @@ -16,6 +16,7 @@ package appbuilder import ( "bytes" + "encoding/json" "fmt" "os" "strings" @@ -1478,3 +1479,94 @@ func TestChunk(t *testing.T) { t.Logf("%s========== OK: %s ==========%s", "\033[32m", t.Name(), "\033[0m") } } + +func TestQueryKnowledgeBase(t *testing.T) { + t.Parallel() // 并发运行 + var logBuffer bytes.Buffer + + os.Setenv("APPBUILDER_LOGLEVEL", "DEBUG") + + log := func(format string, args ...any) { + fmt.Fprintf(&logBuffer, format+"\n", args...) + } + + config, err := NewSDKConfig("", os.Getenv(SecretKey)) + if err != nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("new http client config failed: %v", err) + } + + client, err := NewKnowledgeBase(config) + if err != nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("new Knowledge base instance failed") + } + + jsonStr := ` +{ + "type": "fulltext", + "query": "民法典第三条", + "knowledgebase_ids": [ + "70c6375a-1595-41f2-9a3b-e81bc9060b7f" + ], + "metadata_filters": { + "filters": [ + ], + "condition": "or" + }, + "pipeline_config": { + "id": "pipeline_001", + "pipeline": [ + { + "name": "step1", + "type": "elastic_search", + "threshold": 0.1, + "top": 400, + "pre_ranking": { + "bm25_weight": 0.25, + "vec_weight": 0.75, + "bm25_b": 0.75, + "bm25_k1": 1.5, + "bm25_max_score": 50 + } + }, + { + "name": "step2", + "type": "ranking", + "inputs": ["step1"], + "model_name": "ranker-v1", + "top": 20 + } + ] + }, + "top": 5, + "skip": 0 +}` + + var request QueryKnowledgeBaseRequest + err = json.Unmarshal([]byte(jsonStr), &request) + if err != nil { + fmt.Println("unmarshal tool error:", err) + } + + queryKnowledgeBaseResponse, err := client.QueryKnowledgeBase(request) + if err != nil { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("create chunk failed: %v", err) + } + chunkID := queryKnowledgeBaseResponse.Chunks[0].ChunkID + log("query got chunk ID: %s", chunkID) + if len(chunkID) == 0 { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + t.Fatalf("query knowledge base failed: %v", err) + } + + // 如果测试失败,则输出缓冲区中的日志 + if t.Failed() { + t.Logf("%s========== FAIL: %s ==========%s", "\033[31m", t.Name(), "\033[0m") + fmt.Println(logBuffer.String()) + } else { // else 紧跟在右大括号后面 + // 测试通过,打印文件名和测试函数名 + t.Logf("%s========== OK: %s ==========%s", "\033[32m", t.Name(), "\033[0m") + } +} diff --git a/java/src/main/java/com/baidubce/appbuilder/base/config/AppBuilderConfig.java b/java/src/main/java/com/baidubce/appbuilder/base/config/AppBuilderConfig.java index 33b5b40a1..bd3870fb9 100644 --- a/java/src/main/java/com/baidubce/appbuilder/base/config/AppBuilderConfig.java +++ b/java/src/main/java/com/baidubce/appbuilder/base/config/AppBuilderConfig.java @@ -72,6 +72,8 @@ public class AppBuilderConfig { public static final String CHUNKS_DESCRIBE_URL = "/knowledgeBase?Action=DescribeChunks"; // 删除切片 public static final String CHUNK_DELETE_URL = "/knowledgeBase?Action=DeleteChunk"; + // 知识库检索 + public static final String QUERY_KNOWLEDGEBASE_URL = "/knowledgebases/query"; // 运行rag diff --git a/java/src/main/java/com/baidubce/appbuilder/console/knowledgebase/Knowledgebase.java b/java/src/main/java/com/baidubce/appbuilder/console/knowledgebase/Knowledgebase.java index aed73056b..c582fd007 100644 --- a/java/src/main/java/com/baidubce/appbuilder/console/knowledgebase/Knowledgebase.java +++ b/java/src/main/java/com/baidubce/appbuilder/console/knowledgebase/Knowledgebase.java @@ -691,4 +691,17 @@ public ChunksDescribeResponse describeChunks(String documentId, String marker, I ChunksDescribeResponse respBody = response.getBody(); return respBody; } + + public QueryKnowledgeBaseResponse queryKnowledgeBaseResponse(QueryKnowledgeBaseRequest request) + throws IOException, AppBuilderServerException { + String url = AppBuilderConfig.QUERY_KNOWLEDGEBASE_URL; + + String jsonBody = JsonUtils.serialize(request); + ClassicHttpRequest postRequest = httpClient.createPostRequestV2(url, + new StringEntity(jsonBody, StandardCharsets.UTF_8)); + postRequest.setHeader("Content-Type", "application/json"); + HttpResponse response = httpClient.execute(postRequest, QueryKnowledgeBaseResponse.class); + QueryKnowledgeBaseResponse respBody = response.getBody(); + return respBody; + } } diff --git a/java/src/main/java/com/baidubce/appbuilder/model/knowledgebase/QueryKnowledgeBaseRequest.java b/java/src/main/java/com/baidubce/appbuilder/model/knowledgebase/QueryKnowledgeBaseRequest.java new file mode 100644 index 000000000..5e0d591d7 --- /dev/null +++ b/java/src/main/java/com/baidubce/appbuilder/model/knowledgebase/QueryKnowledgeBaseRequest.java @@ -0,0 +1,276 @@ +package com.baidubce.appbuilder.model.knowledgebase; + +import java.util.List; + +public class QueryKnowledgeBaseRequest { + private String query; + private String type; + private Integer top; + private Integer skip; + private String[] knowledgebase_ids; + private MetadataFilters metadata_filters; + private QueryPipelineConfig pipeline_config; + + public String getQuery() { + return query; + } + + public void setQuery(String query) { + this.query = query; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public Integer getTop() { + return top; + } + + public void setTop(Integer top) { + this.top = top; + } + + public Integer getSkip() { + return skip; + } + + public void setSkip(Integer skip) { + this.skip = skip; + } + + public String[] getKnowledgebase_ids() { + return knowledgebase_ids; + } + + public void setKnowledgebase_ids(String[] knowledgebase_ids) { + this.knowledgebase_ids = knowledgebase_ids; + } + + public MetadataFilters getMetadata_filters() { + return metadata_filters; + } + + public void setMetadata_filters(MetadataFilters metadata_filters) { + this.metadata_filters = metadata_filters; + } + + public QueryPipelineConfig getPipeline_config() { + return pipeline_config; + } + + public void setPipeline_config(QueryPipelineConfig pipeline_config) { + this.pipeline_config = pipeline_config; + } + + public static class MetadataFilter { + private String operator; + private String field; + private Object value; + + public String getOperator() { + return operator; + } + + public void setOperator(String operator) { + this.operator = operator; + } + + public String getField() { + return field; + } + + public void setField(String field) { + this.field = field; + } + + public Object getValue() { + return value; + } + + public void setValue(Object value) { + this.value = value; + } + } + + public static class MetadataFilters { + private List filters; + private String condition; + + public List getFilters() { + return filters; + } + + public void setFilters(List filters) { + this.filters = filters; + } + + public String getCondition() { + return condition; + } + + public void setCondition(String condition) { + this.condition = condition; + } + } + + public static class PreRankingConfig { + private Double bm25_weight; + private Double vec_weight; + private Double bm25_b; + private Double bm25_k1; + private Double bm25_max_score; + + public Double getBm25_weight() { + return bm25_weight; + } + + public void setBm25_weight(Double bm25_weight) { + this.bm25_weight = bm25_weight; + } + + public Double getVec_weight() { + return vec_weight; + } + + public void setVec_weight(Double vec_weight) { + this.vec_weight = vec_weight; + } + + public Double getBm25_b() { + return bm25_b; + } + + public void setBm25_b(Double bm25_b) { + this.bm25_b = bm25_b; + } + + public Double getBm25_k1() { + return bm25_k1; + } + + public void setBm25_k1(Double bm25_k1) { + this.bm25_k1 = bm25_k1; + } + + public Double getBm25_max_score() { + return bm25_max_score; + } + + public void setBm25_max_score(Double bm25_max_score) { + this.bm25_max_score = bm25_max_score; + } + } + + public static class ElasticSearchRetrieveConfig { + private String name; + private String type; + private Double threshold; + private Integer top; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public Double getThreshold() { + return threshold; + } + + public void setThreshold(Double threshold) { + this.threshold = threshold; + } + + public Integer getTop() { + return top; + } + + public void setTop(Integer top) { + this.top = top; + } + } + + public static class RankingConfig { + private String name; + private String type; + private List inputs; + private String model_name; + private Integer top; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public List getInputs() { + return inputs; + } + + public void setInputs(List inputs) { + this.inputs = inputs; + } + + public String getModel_name() { + return model_name; + } + + public void setModel_name(String model_name) { + this.model_name = model_name; + } + + public Integer getTop() { + return top; + } + + public void setTop(Integer top) { + this.top = top; + } + } + + public static class QueryPipelineConfig { + private String id; + private List pipeline; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public List getPipeline() { + return pipeline; + } + + public void setPipeline(List pipeline) { + this.pipeline = pipeline; + } + } +} diff --git a/java/src/main/java/com/baidubce/appbuilder/model/knowledgebase/QueryKnowledgeBaseResponse.java b/java/src/main/java/com/baidubce/appbuilder/model/knowledgebase/QueryKnowledgeBaseResponse.java new file mode 100644 index 000000000..8c572722f --- /dev/null +++ b/java/src/main/java/com/baidubce/appbuilder/model/knowledgebase/QueryKnowledgeBaseResponse.java @@ -0,0 +1,122 @@ +package com.baidubce.appbuilder.model.knowledgebase; + +import java.time.LocalDateTime; +import java.util.List; +import java.util.Map; + +public class QueryKnowledgeBaseResponse { + private String requestId; + private String code; + private String message; + private List chunks; + private int total_count; + + public String getRequestId() { return requestId; } + + public void setRequestId(String requestId) { this.requestId = requestId; } + + public String getCode() { return code; } + + public void setCode(String code) { this.code = code; } + + public String getMessage() { return message; } + + public void setMessage(String message) { this.message = message; } + + public List getChunks() { return chunks; } + + public void setChunks(List chunks) { this.chunks = chunks; } + + public int getTotal_count() { return total_count; } + + public void setTotal_count(int total_count) { this.total_count = total_count; } + + public static class Chunk { + private String chunk_id; + private String knowledgebase_id; + private String document_id; + private String document_name; + private Map meta; + private String type; + private String content; + private LocalDateTime create_time; + private LocalDateTime update_time; + private float retrieval_score; + private float rank_score; + private ChunkLocation locations; + private List children; + + public String getChunk_id() { return chunk_id; } + + public void setChunk_id(String chunk_id) { this.chunk_id = chunk_id; } + + public String getKnowledgebase_id() { return knowledgebase_id; } + + public void setKnowledgebase_id(String knowledgebase_id) { this.knowledgebase_id = knowledgebase_id; } + + public String getDocument_id() { return document_id; } + + public void setDocument_id(String document_id) { this.document_id = document_id; } + + public String getDocument_name() { return document_name; } + + public void setDocument_name(String document_name) { this.document_name = document_name; } + + public Map getMeta() { return meta; } + + public void setMeta(Map meta) { this.meta = meta; } + + public String getType() { return type; } + + public void setType(String type) { this.type = type; } + + public String getContent() { return content; } + + public void setContent(String content) { this.content = content; } + + public LocalDateTime getCreate_time() { return create_time; } + + public void setCreate_time(LocalDateTime create_time) { this.create_time = create_time; } + + public LocalDateTime getUpdate_time() { return update_time; } + + public void setUpdate_time(LocalDateTime update_time) { this.update_time = update_time; } + + public float getRetrieval_score() { return retrieval_score; } + + public void setRetrieval_score(float retrieval_score) { this.retrieval_score = retrieval_score; } + + public float getRank_score() { return rank_score; } + + public void setRank_score(float rank_score) { this.rank_score = rank_score; } + + public ChunkLocation getLocations() { return locations; } + + public void setLocations(ChunkLocation locations) { this.locations = locations; } + + public List getChildren() { return children; } + + public void setChildren(List children) { this.children = children; } + } + + public static class ChunkLocation { + private List paget_num; + private List> box; + + public List getPaget_num() { + return paget_num; + } + + public void setPaget_num(List paget_num) { + this.paget_num = paget_num; + } + + public List> getBox() { + return box; + } + + public void setBox(List> box) { + this.box = box; + } + } +} diff --git a/java/src/test/java/com/baidubce/appbuilder/KnowledgebaseTest.java b/java/src/test/java/com/baidubce/appbuilder/KnowledgebaseTest.java index 4431d632b..188468da8 100644 --- a/java/src/test/java/com/baidubce/appbuilder/KnowledgebaseTest.java +++ b/java/src/test/java/com/baidubce/appbuilder/KnowledgebaseTest.java @@ -3,6 +3,8 @@ import static org.junit.Assert.assertNotNull; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; import org.junit.Before; import org.junit.Test; @@ -10,6 +12,8 @@ import com.baidubce.appbuilder.base.exception.AppBuilderServerException; import com.baidubce.appbuilder.console.knowledgebase.Knowledgebase; import com.baidubce.appbuilder.model.knowledgebase.*; +import com.google.gson.Gson; + import static org.junit.Assert.assertTrue; @@ -154,4 +158,16 @@ public void testCreateChunk() throws IOException, AppBuilderServerException { // 删除切片 knowledgebase.deleteChunk(chunkId); } + + @Test + public void testQueryKnowledgeBase() throws IOException, AppBuilderServerException { + System.setProperty("APPBUILDER_TOKEN", System.getenv("APPBUILDER_TOKEN")); + Knowledgebase knowledgebase = new Knowledgebase(); + // 查询知识库 + Gson gson = new Gson(); + String requestJson = new String(Files.readAllBytes(Paths.get("src/test/java/com/baidubce/appbuilder/files/query_knowledgebase.json"))); + QueryKnowledgeBaseRequest request = gson.fromJson(requestJson, QueryKnowledgeBaseRequest.class); + QueryKnowledgeBaseResponse response = knowledgebase.queryKnowledgeBaseResponse(request); + assertNotNull(response.getChunks().get(0).getChunk_id()); + } } diff --git a/java/src/test/java/com/baidubce/appbuilder/files/query_knowledgebase.json b/java/src/test/java/com/baidubce/appbuilder/files/query_knowledgebase.json new file mode 100644 index 000000000..012c75e5c --- /dev/null +++ b/java/src/test/java/com/baidubce/appbuilder/files/query_knowledgebase.json @@ -0,0 +1,40 @@ +{ + "type": "fulltext", + "query": "民法典第三条", + "knowledgebase_ids": [ + "70c6375a-1595-41f2-9a3b-e81bc9060b7f" + ], + "metadata_filters": { + "filters": [], + "condition": "or" + }, + "pipeline_config": { + "id": "pipeline_001", + "pipeline": [ + { + "name": "step1", + "type": "elastic_search", + "threshold": 0.1, + "top": 400, + "pre_ranking": { + "bm25_weight": 0.25, + "vec_weight": 0.75, + "bm25_b": 0.75, + "bm25_k1": 1.5, + "bm25_max_score": 50 + } + }, + { + "name": "step2", + "type": "ranking", + "inputs": [ + "step1" + ], + "model_name": "ranker-v1", + "top": 20 + } + ] + }, + "top": 5, + "skip": 0 +} \ No newline at end of file diff --git a/python/core/console/knowledge_base/data_class.py b/python/core/console/knowledge_base/data_class.py index ee339f825..cc689f96f 100644 --- a/python/core/console/knowledge_base/data_class.py +++ b/python/core/console/knowledge_base/data_class.py @@ -11,12 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from pydantic import BaseModel -from pydantic import Field -from typing import Union -from typing import Optional -import datetime +from __future__ import annotations +from datetime import datetime +from pydantic import BaseModel, Field +from typing import Union, Optional, List class KnowledgeBaseUploadFileResponse(BaseModel): @@ -26,7 +24,8 @@ class KnowledgeBaseUploadFileResponse(BaseModel): class CustomProcessRule(BaseModel): - separators: list[str] = Field(..., description="分段符号列表", example=[",", "?"]) + separators: list[str] = Field(..., description="分段符号列表", example=[ + ",", "?"]) target_length: int = Field(..., description="分段最大长度", ge=300, le=1200) overlap_rate: float = Field( ..., description="分段重叠最大字数占比,推荐值0.25", ge=0, le=0.3, example=0.2 @@ -219,6 +218,7 @@ class KnowledgeBaseCreateDocumentsRequest(BaseModel): None, description="文档处理选项" ) + class KnowledgeBaseCreateDocumentsResponse(BaseModel): requestId: str = Field(..., description="请求ID") documentIds: list[str] = Field(..., description="文档ID列表") @@ -285,3 +285,118 @@ class DescribeChunksResponse(BaseModel): ) nextMarker: str = Field(..., description="下一页起始位置") maxKeys: int = Field(..., description="本次查询包含的最大结果集数量") + + +class MetadataFilter(BaseModel): + operator: str = Field(..., description="操作符名称。==:等于,in:在数组中,not_in:不在数组中") + field: str = Field(None, description="字段名,目前支持doc_id") + value: Union[str, list[str]] = Field( + ..., description="字段值,如果是in操作符,value为数组" + ) + + +class MetadataFilters(BaseModel): + filters: list[MetadataFilter] = Field(..., description="过滤条件") + condition: str = Field(..., description="文档组合条件。and:与,or:或") + + +class PreRankingConfig(BaseModel): + bm25_weight: float = Field( + None, description="粗排bm25比重,取值范围在 [0, 1],默认0.75" + ) + vec_weight: float = Field( + None, description="粗排向量余弦分比重,取值范围在 [0, 1],默认0.25" + ) + bm25_b: float = Field( + None, description="控制文档长度对评分影响的参数,取值范围在 [0, 1],默认0.75" + ) + bm25_k1: float = Field( + None, + description="词频饱和因子,控制词频(TF)对评分的影响,常取值范围在 [1.2, 2.0],默认1.5", + ) + bm25_max_score: float = Field( + None, description="得分归一化参数,不建议修改,默认50" + ) + + +class ElasticSearchRetrieveConfig(BaseModel): + name: str = Field(..., description="配置名称") + type: str = Field(None, description="elastic_search标志,该节点为es全文检索") + threshold: float = Field(None, description="得分阈值,默认0.1") + top: int = Field(None, description="召回数量,默认400") + + +class RankingConfig(BaseModel): + name: str = Field(..., description="配置名称") + type: str = Field(None, description="ranking标志,该节点为ranking节点") + inputs: list[str] = Field( + ..., + description='输入的节点名,如es检索配置的名称为pipeline_001,则该inputs为["pipeline_001"]', + ) + model_name: str = Field(None, description="ranking模型名(当前仅一种,暂不生效)") + top: int = Field(None, description="取切片top进行排序,默认20,最大400") + + +class QueryPipelineConfig(BaseModel): + id: str = Field( + None, description="配置唯一标识,如果用这个id,则引用已经配置好的QueryPipeline" + ) + pipeline: list[Union[ElasticSearchRetrieveConfig, RankingConfig]] = Field( + None, description="配置的Pipeline,如果没有用id,可以用这个对象指定一个新的配置" + ) + + +class QueryKnowledgeBaseRequest(BaseModel): + query: str = Field(..., description="检索query") + type: str = Field(None, description="检索策略的枚举, fulltext:全文检索") + top: int = Field(None, description="返回结果数量") + skip: int = Field( + None, + description="跳过多少条记录, 通过top和skip可以实现类似分页的效果,比如top 10 skip 0,取第一页的10个,top 10 skip 10,取第二页的10个", + ) + knowledgebase_ids: list[str] = Field(..., description="知识库ID列表") + metadata_filters: MetadataFilters = Field(None, description="元数据过滤条件") + pipeline_config: QueryPipelineConfig = Field(None, description="检索配置") + + +class RowLine(BaseModel): + key: str = Field(..., description="列名") + index: int = Field(..., description="列号") + value: str = Field(..., description="列值") + enable_indexing: bool = Field(..., description="是否索引") + enable_response: bool = Field( + ..., + description="是否参与问答(即该列数据是否对大模型可见)。当前值固定为true。", + ) + + +class ChunkLocation(BaseModel): + page_num: list[int] = Field(..., description="页面") + box: list[list[int]] = Field( + ..., + description="文本内容位置,在视觉上是文本框,格式是长度为4的int数组,含义是[x, y, width, height]", + ) + + +class Chunk(BaseModel): + chunk_id: str = Field(..., description="切片ID") + knowledgebase_id: str = Field(..., description="知识库ID") + document_id: str = Field(..., description="文档ID") + document_name: str = Field(None, description="文档名称") + meta: dict = Field(None, description="文档元数据") + chunk_type: str = Field(..., description="切片类型") + content: str = Field(..., description="切片内容") + create_time: datetime = Field(..., description="创建时间") + update_time: datetime = Field(..., description="更新时间") + retrieval_score: float = Field(..., description="粗检索得分") + rank_score: float = Field(..., description="rerank得分") + locations: ChunkLocation = Field(None, description="切片位置") + children: List[Chunk] = Field(None, description="子切片") + + +class QueryKnowledgeBaseResponse(BaseModel): + requestId: str = Field(None, description="请求ID") + code: str = Field(None, description="状态码") + message: str = Field(None, description="状态信息") + chunks: list[Chunk] = Field(..., description="切片列表") + total_count: int = Field(..., description="切片总数") diff --git a/python/core/console/knowledge_base/knowledge_base.py b/python/core/console/knowledge_base/knowledge_base.py index b2f6b1775..214c6b65e 100644 --- a/python/core/console/knowledge_base/knowledge_base.py +++ b/python/core/console/knowledge_base/knowledge_base.py @@ -15,9 +15,6 @@ import os import json import uuid -from pydantic import BaseModel -from pydantic import Field -from typing import Union from typing import Optional from appbuilder.core._client import HTTPClient from appbuilder.core.console.knowledge_base import data_class @@ -897,3 +894,31 @@ def get_all_documents(self, knowledge_base_id: Optional[str] = None) -> list: doc_list.extend(response_per_time.data) return doc_list + + def query_knowledge_base( + self, + request: data_class.QueryKnowledgeBaseRequest + ) -> data_class.QueryKnowledgeBaseResponse: + """ + 检索知识库 + + Args: + request (data_class.QueryKnowledgeBaseRequest): 检索知识库的请求对象 + + Returns: + data_class.QueryKnowledgeBaseResponse: 检索知识库的响应对象 + """ + headers = self.http_client.auth_header_v2() + headers["content-type"] = "application/json" + + url = self.http_client.service_url_v2("/knowledgebases/query") + response = self.http_client.session.post( + url=url, headers=headers, json=request.model_dump(exclude_none=True) + ) + + self.http_client.check_response_header(response) + self.http_client.check_console_response(response) + data = response.json() + + resp = data_class.QueryKnowledgeBaseResponse(**data) + return resp diff --git a/python/tests/test_knowledge_base.py b/python/tests/test_knowledge_base.py index ddc0c8924..cae683f86 100644 --- a/python/tests/test_knowledge_base.py +++ b/python/tests/test_knowledge_base.py @@ -16,6 +16,8 @@ import os from appbuilder.core._exception import BadRequestException +from appbuilder.core.console.knowledge_base import data_class + class TestKnowLedge(unittest.TestCase): def setUp(self): @@ -26,7 +28,8 @@ def test_doc_knowledage(self): appbuilder.logger.setLoglevel('DEBUG') knowledge = appbuilder.KnowledgeBase(knowledge_id=dataset_id) - upload_res = knowledge.upload_file("./data/qa_appbuilder_client_demo.pdf") + upload_res = knowledge.upload_file( + "./data/qa_appbuilder_client_demo.pdf") add_res = knowledge.add_document( content_type="raw_text", file_ids=[upload_res.id], @@ -35,10 +38,11 @@ def test_doc_knowledage(self): ), ) list_res = knowledge.get_documents_list() - delete_res = knowledge.delete_document(document_id=add_res.document_ids[0]) + delete_res = knowledge.delete_document( + document_id=add_res.document_ids[0]) all_doc = knowledge.get_all_documents() self.assertIsInstance(all_doc, list) - + def test_get_documents_number_raise(self): knowledge = appbuilder.KnowledgeBase() with self.assertRaises(ValueError): @@ -49,9 +53,11 @@ def test_xlsx_knowledage(self): knowledge = appbuilder.KnowledgeBase(knowledge_id=dataset_id) upload_res = knowledge.upload_file("./data/qa_demo.xlsx") - add_res = knowledge.add_document(content_type="qa", file_ids=[upload_res.id]) + add_res = knowledge.add_document( + content_type="qa", file_ids=[upload_res.id]) list_res = knowledge.get_documents_list() - delete_res = knowledge.delete_document(document_id=add_res.document_ids[0]) + delete_res = knowledge.delete_document( + document_id=add_res.document_ids[0]) def test_create_knowledge_base(self): knowledge = appbuilder.KnowledgeBase() @@ -94,7 +100,8 @@ def test_create_knowledge_base(self): ), prependInfo=["title", "filename"], ), - knowledgeAugmentation=appbuilder.DocumentChoices(choices=["faq"]), + knowledgeAugmentation=appbuilder.DocumentChoices(choices=[ + "faq"]), ), ) self.assertIsInstance(create_documents_response.documentIds, list) @@ -117,12 +124,14 @@ def test_create_knowledge_base(self): ), prependInfo=["title", "filename"], ), - knowledgeAugmentation=appbuilder.DocumentChoices(choices=["faq"]), + knowledgeAugmentation=appbuilder.DocumentChoices(choices=[ + "faq"]), ), ) self.assertIsInstance(upload_documents_response.documentId, str) - list_res = knowledge.get_documents_list(knowledge_base_id=knowledge_base_id) + list_res = knowledge.get_documents_list( + knowledge_base_id=knowledge_base_id) document_id = list_res.data[-1].id res = knowledge.describe_chunks(document_id) resp = knowledge.create_chunk(document_id, content="test") @@ -136,10 +145,50 @@ def test_create_knowledge_base(self): knowledge_base_id=knowledge_base_id, name="test" ) - if self.whether_create_knowledge_base: knowledge.delete_knowledge_base(knowledge_base_id) + def test_query_knowledge_base(self): + knowledge = appbuilder.KnowledgeBase() + appbuilder.logger.setLoglevel("DEBUG") + client = appbuilder.KnowledgeBase() + request_json = { + "type": "fulltext", + "query": "民法典第三条", + "knowledgebase_ids": ["70c6375a-1595-41f2-9a3b-e81bc9060b7f"], + "metadata_filters": {"filters": [], "condition": "or"}, + "pipeline_config": { + "id": "pipeline_001", + "pipeline": [ + { + "name": "step1", + "type": "elastic_search", + "threshold": 0.1, + "top": 400, + "pre_ranking": { + "bm25_weight": 0.25, + "vec_weight": 0.75, + "bm25_b": 0.75, + "bm25_k1": 1.5, + "bm25_max_score": 50, + }, + }, + { + "name": "step2", + "type": "ranking", + "inputs": ["step1"], + "model_name": "ranker-v1", + "top": 20, + }, + ], + }, + "top": 5, + "skip": 0, + } + request = data_class.QueryKnowledgeBaseRequest(**request_json) + res = client.query_knowledge_base(request) + chunk_id = res.chunks[0].chunk_id + self.assertIsNotNone(chunk_id) if __name__ == "__main__":