From c8fa6abc41933d5dc6d49faf2b85ea4eed6216f8 Mon Sep 17 00:00:00 2001 From: userpj Date: Thu, 19 Dec 2024 19:25:31 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?=E6=A3=80=E7=B4=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go/appbuilder/knowledge_base.go | 36 +++ go/appbuilder/knowledge_base_data.go | 87 ++++++ .../base/config/AppBuilderConfig.java | 2 + .../console/knowledgebase/Knowledgebase.java | 13 + .../QueryKnowledgeBaseRequest.java | 267 ++++++++++++++++++ .../QueryKnowledgeBaseResponse.java | 122 ++++++++ .../core/console/knowledge_base/data_class.py | 115 +++++++- .../console/knowledge_base/knowledge_base.py | 31 +- 8 files changed, 664 insertions(+), 9 deletions(-) create mode 100644 java/src/main/java/com/baidubce/appbuilder/model/knowledgebase/QueryKnowledgeBaseRequest.java create mode 100644 java/src/main/java/com/baidubce/appbuilder/model/knowledgebase/QueryKnowledgeBaseResponse.java diff --git a/go/appbuilder/knowledge_base.go b/go/appbuilder/knowledge_base.go index f7ef3ddf8..9938af73e 100644 --- a/go/appbuilder/knowledge_base.go +++ b/go/appbuilder/knowledge_base.go @@ -814,3 +814,39 @@ func (t *KnowledgeBase) DescribeChunks(req DescribeChunksRequest) (DescribeChunk return rsp, nil } + +func (t *KnowledgeBase) QueryKnowledgeBase(req QueryKnowledgeBaseRequest) (QueryKnowledgeBaseResponse, error) { + request := http.Request{} + header := t.sdkConfig.AuthHeaderV2() + serviceURL, err := t.sdkConfig.ServiceURLV2("/v2/knowledgebases/query") + if err != nil { + return QueryKnowledgeBaseResponse{}, err + } + request.URL = serviceURL + request.Method = "POST" + header.Set("Content-Type", "application/json") + request.Header = header + data, _ := json.Marshal(req) + request.Body = NopCloser(bytes.NewReader(data)) + t.sdkConfig.BuildCurlCommand(&request) + resp, err := t.client.Do(&request) + if err != nil { + return QueryKnowledgeBaseResponse{}, err + } + defer resp.Body.Close() + requestID, err := checkHTTPResponse(resp) + if err != nil { + return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err) + } + data, err = io.ReadAll(resp.Body) + if err != nil { + return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err) + } + + rsp := QueryKnowledgeBaseResponse{} + if err := json.Unmarshal(data, &rsp); err != nil { + return QueryKnowledgeBaseResponse{}, fmt.Errorf("requestID=%s, err=%v", requestID, err) + } + + return rsp, nil +} diff --git a/go/appbuilder/knowledge_base_data.go b/go/appbuilder/knowledge_base_data.go index 4e9ab6d14..6049bc2b9 100644 --- a/go/appbuilder/knowledge_base_data.go +++ b/go/appbuilder/knowledge_base_data.go @@ -14,6 +14,8 @@ package appbuilder +import "time" + const ( ContentTypeRawText = "raw_text" ContentTypeQA = "qa" @@ -245,3 +247,88 @@ type DescribeChunksResponse struct { NextMarker string `json:"nextMarker"` MaxKeys int `json:"maxKeys"` } + +type MetadataFilter struct { + Operator string `json:"operator"` + Field string `json:"field"` + Value interface{} `json:"value"` // 因为Value的类型可以是str或list[str],所以我们使用interface{}来表示任何类型 +} + +type MetadataFilters struct { + Filters []MetadataFilter `json:"filters"` + Condition string `json:"condition"` +} + +type PreRankingConfig struct { + Bm25Weight float64 `json:"bm25_weight"` + VecWeight float64 `json:"vec_weight"` + Bm25B float64 `json:"bm25_b"` + Bm25K1 float64 `json:"bm25_k1"` + Bm25MaxScore float64 `json:"bm25_max_score"` +} + +type ElasticSearchRetrieveConfig struct { + Name string `json:"name"` + Type string `json:"type"` + Threshold float64 `json:"threshold"` + Top int `json:"top"` +} + +type RankingConfig struct { + Name string `json:"name"` + Type string `json:"type"` + Inputs []string `json:"inputs"` + ModelName string `json:"model_name"` + Top int `json:"top"` +} + +type QueryPipelineConfig struct { + ID string `json:"id"` + Pipeline []interface{} `json:"pipeline"` +} + +type QueryKnowledgeBaseRequest struct { + Query string `json:"query"` + Type string `json:"type"` + Top int `json:"top"` + Skip int `json:"skip"` + MetadataFileters MetadataFilters `json:"metadata_fileters"` + PipelineConfig QueryPipelineConfig `json:"pipeline_config"` +} + +type RowLine struct { + Key string `json:"key"` + Index int `json:"index"` + Value string `json:"value"` + EnableIndexing bool `json:"enable_indexing"` + EnableResponse bool `json:"enable_response"` +} + +type ChunkLocation struct { + PageNum []int `json:"paget_num"` + Box [][]int `json:"box"` +} + +type Chunk struct { + ChunkID string `json:"chunk_id"` + KnowledgebaseID string `json:"knowledgebase_id"` + DocumentID string `json:"document_id"` + DocumentName string `json:"document_name"` + Meta map[string]interface{} `json:"meta"` + Type string `json:"type"` + Content string `json:"content"` + CreateTime time.Time `json:"create_time"` + UpdateTime time.Time `json:"update_time"` + RetrievalScore float64 `json:"retrieval_score"` + RankScore float64 `json:"rank_score"` + Locations []ChunkLocation `json:"locations"` + Children []Chunk `json:"children"` +} + +type QueryKnowledgeBaseResponse struct { + RequestId string `json:"requestId"` + Code string `json:"code"` + Message string `json:"message"` + Chunks []Chunk `json:"chunks"` + TotalCount int `json:"total_count"` +} diff --git a/java/src/main/java/com/baidubce/appbuilder/base/config/AppBuilderConfig.java b/java/src/main/java/com/baidubce/appbuilder/base/config/AppBuilderConfig.java index 33b5b40a1..84b7fe927 100644 --- a/java/src/main/java/com/baidubce/appbuilder/base/config/AppBuilderConfig.java +++ b/java/src/main/java/com/baidubce/appbuilder/base/config/AppBuilderConfig.java @@ -72,6 +72,8 @@ public class AppBuilderConfig { public static final String CHUNKS_DESCRIBE_URL = "/knowledgeBase?Action=DescribeChunks"; // 删除切片 public static final String CHUNK_DELETE_URL = "/knowledgeBase?Action=DeleteChunk"; + // 知识库检索 + public static final String QUERY_KNOWLEDGEBASE_URL = "/v2/knowledgebases/query"; // 运行rag diff --git a/java/src/main/java/com/baidubce/appbuilder/console/knowledgebase/Knowledgebase.java b/java/src/main/java/com/baidubce/appbuilder/console/knowledgebase/Knowledgebase.java index aed73056b..c582fd007 100644 --- a/java/src/main/java/com/baidubce/appbuilder/console/knowledgebase/Knowledgebase.java +++ b/java/src/main/java/com/baidubce/appbuilder/console/knowledgebase/Knowledgebase.java @@ -691,4 +691,17 @@ public ChunksDescribeResponse describeChunks(String documentId, String marker, I ChunksDescribeResponse respBody = response.getBody(); return respBody; } + + public QueryKnowledgeBaseResponse queryKnowledgeBaseResponse(QueryKnowledgeBaseRequest request) + throws IOException, AppBuilderServerException { + String url = AppBuilderConfig.QUERY_KNOWLEDGEBASE_URL; + + String jsonBody = JsonUtils.serialize(request); + ClassicHttpRequest postRequest = httpClient.createPostRequestV2(url, + new StringEntity(jsonBody, StandardCharsets.UTF_8)); + postRequest.setHeader("Content-Type", "application/json"); + HttpResponse response = httpClient.execute(postRequest, QueryKnowledgeBaseResponse.class); + QueryKnowledgeBaseResponse respBody = response.getBody(); + return respBody; + } } diff --git a/java/src/main/java/com/baidubce/appbuilder/model/knowledgebase/QueryKnowledgeBaseRequest.java b/java/src/main/java/com/baidubce/appbuilder/model/knowledgebase/QueryKnowledgeBaseRequest.java new file mode 100644 index 000000000..042b8812f --- /dev/null +++ b/java/src/main/java/com/baidubce/appbuilder/model/knowledgebase/QueryKnowledgeBaseRequest.java @@ -0,0 +1,267 @@ +package com.baidubce.appbuilder.model.knowledgebase; + +import java.util.List; + +public class QueryKnowledgeBaseRequest { + private String query; + private String type; + private Integer top; + private Integer skip; + private MetadataFilters metadata_filters; + private QueryPipelineConfig pipeline_config; + + public String getQuery() { + return query; + } + + public void setQuery(String query) { + this.query = query; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public Integer getTop() { + return top; + } + + public void setTop(Integer top) { + this.top = top; + } + + public Integer getSkip() { + return skip; + } + + public void setSkip(Integer skip) { + this.skip = skip; + } + + public MetadataFilters getMetadata_filters() { + return metadata_filters; + } + + public void setMetadata_filters(MetadataFilters metadata_filters) { + this.metadata_filters = metadata_filters; + } + + public QueryPipelineConfig getPipeline_config() { + return pipeline_config; + } + + public void setPipeline_config(QueryPipelineConfig pipeline_config) { + this.pipeline_config = pipeline_config; + } + + public static class MetadataFilter { + private String operator; + private String field; + private Object value; + + public String getOperator() { + return operator; + } + + public void setOperator(String operator) { + this.operator = operator; + } + + public String getField() { + return field; + } + + public void setField(String field) { + this.field = field; + } + + public Object getValue() { + return value; + } + + public void setValue(Object value) { + this.value = value; + } + } + + public static class MetadataFilters { + private List filters; + private String condition; + + public List getFilters() { + return filters; + } + + public void setFilters(List filters) { + this.filters = filters; + } + + public String getCondition() { + return condition; + } + + public void setCondition(String condition) { + this.condition = condition; + } + } + + public static class PreRankingConfig { + private Double bm25_weight; + private Double vec_weight; + private Double bm25_b; + private Double bm25_k1; + private Double bm25_max_score; + + public Double getBm25_weight() { + return bm25_weight; + } + + public void setBm25_weight(Double bm25_weight) { + this.bm25_weight = bm25_weight; + } + + public Double getVec_weight() { + return vec_weight; + } + + public void setVec_weight(Double vec_weight) { + this.vec_weight = vec_weight; + } + + public Double getBm25_b() { + return bm25_b; + } + + public void setBm25_b(Double bm25_b) { + this.bm25_b = bm25_b; + } + + public Double getBm25_k1() { + return bm25_k1; + } + + public void setBm25_k1(Double bm25_k1) { + this.bm25_k1 = bm25_k1; + } + + public Double getBm25_max_score() { + return bm25_max_score; + } + + public void setBm25_max_score(Double bm25_max_score) { + this.bm25_max_score = bm25_max_score; + } + } + + public static class ElasticSearchRetrieveConfig { + private String name; + private String type; + private Double threshold; + private Integer top; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public Double getThreshold() { + return threshold; + } + + public void setThreshold(Double threshold) { + this.threshold = threshold; + } + + public Integer getTop() { + return top; + } + + public void setTop(Integer top) { + this.top = top; + } + } + + public static class RankingConfig { + private String name; + private String type; + private List inputs; + private String model_name; + private Integer top; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public List getInputs() { + return inputs; + } + + public void setInputs(List inputs) { + this.inputs = inputs; + } + + public String getModel_name() { + return model_name; + } + + public void setModel_name(String model_name) { + this.model_name = model_name; + } + + public Integer getTop() { + return top; + } + + public void setTop(Integer top) { + this.top = top; + } + } + + public static class QueryPipelineConfig { + private String id; + private List pipeline; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public List getPipeline() { + return pipeline; + } + + public void setPipeline(List pipeline) { + this.pipeline = pipeline; + } + } +} diff --git a/java/src/main/java/com/baidubce/appbuilder/model/knowledgebase/QueryKnowledgeBaseResponse.java b/java/src/main/java/com/baidubce/appbuilder/model/knowledgebase/QueryKnowledgeBaseResponse.java new file mode 100644 index 000000000..437a33c6b --- /dev/null +++ b/java/src/main/java/com/baidubce/appbuilder/model/knowledgebase/QueryKnowledgeBaseResponse.java @@ -0,0 +1,122 @@ +package com.baidubce.appbuilder.model.knowledgebase; + +import java.time.LocalDateTime; +import java.util.List; +import java.util.Map; + +public class QueryKnowledgeBaseResponse { + private String requestId; + private String code; + private String message; + private List chunks; + private int total_count; + + public String getRequestId() { return requestId; } + + public void setRequestId(String requestId) { this.requestId = requestId; } + + public String getCode() { return code; } + + public void setCode(String code) { this.code = code; } + + public String getMessage() { return message; } + + public void setMessage(String message) { this.message = message; } + + public List getChunks() { return chunks; } + + public void setChunks(List chunks) { this.chunks = chunks; } + + public int getTotal_count() { return total_count; } + + public void setTotal_count(int total_count) { this.total_count = total_count; } + + public static class Chunk { + private String chunk_id; + private String knowledgebase_id; + private String document_id; + private String document_name; + private Map meta; + private String type; + private String content; + private LocalDateTime create_time; + private LocalDateTime update_time; + private float retrieval_score; + private float rank_score; + private List locations; + private List children; + + public String getChunk_id() { return chunk_id; } + + public void setChunk_id(String chunk_id) { this.chunk_id = chunk_id; } + + public String getKnowledgebase_id() { return knowledgebase_id; } + + public void setKnowledgebase_id(String knowledgebase_id) { this.knowledgebase_id = knowledgebase_id; } + + public String getDocument_id() { return document_id; } + + public void setDocument_id(String document_id) { this.document_id = document_id; } + + public String getDocument_name() { return document_name; } + + public void setDocument_name(String document_name) { this.document_name = document_name; } + + public Map getMeta() { return meta; } + + public void setMeta(Map meta) { this.meta = meta; } + + public String getType() { return type; } + + public void setType(String type) { this.type = type; } + + public String getContent() { return content; } + + public void setContent(String content) { this.content = content; } + + public LocalDateTime getCreate_time() { return create_time; } + + public void setCreate_time(LocalDateTime create_time) { this.create_time = create_time; } + + public LocalDateTime getUpdate_time() { return update_time; } + + public void setUpdate_time(LocalDateTime update_time) { this.update_time = update_time; } + + public float getRetrieval_score() { return retrieval_score; } + + public void setRetrieval_score(float retrieval_score) { this.retrieval_score = retrieval_score; } + + public float getRank_score() { return rank_score; } + + public void setRank_score(float rank_score) { this.rank_score = rank_score; } + + public List getLocations() { return locations; } + + public void setLocations(List locations) { this.locations = locations; } + + public List getChildren() { return children; } + + public void setChildren(List children) { this.children = children; } + } + + public static class ChunkLocation { + private List paget_num; + private List> box; + + public List getPaget_num() { + return paget_num; + } + + public void setPaget_num(List paget_num) { + this.paget_num = paget_num; + } + + public List> getBox() { + return box; + } + + public void setBox(List> box) { + this.box = box; + } + } +} diff --git a/python/core/console/knowledge_base/data_class.py b/python/core/console/knowledge_base/data_class.py index ee339f825..587f36dd4 100644 --- a/python/core/console/knowledge_base/data_class.py +++ b/python/core/console/knowledge_base/data_class.py @@ -11,12 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from pydantic import BaseModel -from pydantic import Field -from typing import Union -from typing import Optional -import datetime +from __future__ import annotations +from datetime import datetime +from pydantic import BaseModel, Field +from typing import Union, Optional, List class KnowledgeBaseUploadFileResponse(BaseModel): @@ -285,3 +283,108 @@ class DescribeChunksResponse(BaseModel): ) nextMarker: str = Field(..., description="下一页起始位置") maxKeys: int = Field(..., description="本次查询包含的最大结果集数量") + +class MetadataFilter(BaseModel): + operator: str = Field(..., description="操作符名称。==:等于,in:在数组中,not_in:不在数组中") + field: str = Field(None, description="字段名,目前支持doc_id") + value: Union[str, list[str]] = Field( + ..., description="字段值,如果是in操作符,value为数组" + ) + +class MetadataFilters(BaseModel): + filters: list[MetadataFilter] = Field(..., description="过滤条件") + condition: str = Field(..., description="文档组合条件。and:与,or:或") + +class PreRankingConfig(BaseModel): + bm25_weight: float = Field( + None, description="粗排bm25比重,取值范围在 [0, 1],默认0.75" + ) + vec_weight: float = Field( + None, description="粗排向量余弦分比重,取值范围在 [0, 1],默认0.25" + ) + bm25_b: float = Field( + None, description="控制文档长度对评分影响的参数,取值范围在 [0, 1],默认0.75" + ) + bm25_k1: float = Field( + None, + description="词频饱和因子,控制词频(TF)对评分的影响,常取值范围在 [1.2, 2.0],默认1.5", + ) + bm25_max_score: float = Field( + None, description="得分归一化参数,不建议修改,默认50" + ) + + +class ElasticSearchRetrieveConfig(BaseModel): + name: str = Field(..., description="配置名称") + type: str = Field(None, description="elastic_search标志,该节点为es全文检索") + threshold: float = Field(None, description="得分阈值,默认0.1") + top: int = Field(None, description="召回数量,默认400") + +class RankingConfig(BaseModel): + name: str = Field(..., description="配置名称") + type: str = Field(None, description="ranking标志,该节点为ranking节点") + inputs: list[str] = Field( + ..., + description='输入的节点名,如es检索配置的名称为pipeline_001,则该inputs为["pipeline_001"]', + ) + model_name: str = Field(None, description="ranking模型名(当前仅一种,暂不生效)") + top: int = Field(None, description="取切片top进行排序,默认20,最大400") + +class QueryPipelineConfig(BaseModel): + id: str = Field( + None, description="配置唯一标识,如果用这个id,则引用已经配置好的QueryPipeline" + ) + pipeline: list[Union[ElasticSearchRetrieveConfig, RankingConfig]] = Field( + None, description="配置的Pipeline,如果没有用id,可以用这个对象指定一个新的配置" + ) + + +class QueryKnowledgeBaseRequest(BaseModel): + query: str = Field(..., description="检索query") + type: str = Field(None, description="检索策略的枚举, fulltext:全文检索") + top: int = Field(None, description="返回结果数量") + skip: int = Field( + None, + description="跳过多少条记录, 通过top和skip可以实现类似分页的效果,比如top 10 skip 0,取第一页的10个,top 10 skip 10,取第二页的10个", + ) + metadata_fileters: MetadataFilters = Field(None, description="元数据过滤条件") + pipeline_config: QueryPipelineConfig = Field(None, description="检索配置") + +class RowLine(BaseModel): + key: str = Field(..., description="列名") + index: int = Field(..., description="列号") + value: str = Field(..., description="列值") + enable_indexing: bool = Field(..., description="是否索引") + enable_response: bool = Field( + ..., + description="是否参与问答(即该列数据是否对大模型可见)。当前值固定为true。", + ) + +class ChunkLocation(BaseModel): + paget_num : list[int] = Field(..., description="页面") + box: list[list[int]] = Field( + ..., + description="文本内容位置,在视觉上是文本框,格式是长度为4的int数组,含义是[x, y, width, height]", + ) + +class Chunk(BaseModel): + chunk_id: str = Field(..., description="切片ID") + knowledgebase_id: str = Field(..., description="知识库ID") + document_id: str = Field(..., description="文档ID") + document_name: str = Field(None, description="文档名称") + meta: dict = Field(None, description="文档元数据") + type: str = Field(..., description="切片类型") + content: str = Field(..., description="切片内容") + create_time: datetime = Field(..., description="创建时间") + update_time: datetime = Field(..., description="更新时间") + retrieval_score: float = Field(..., description="粗检索得分") + rank_score: float = Field(..., description="rerank得分") + locations: list[ChunkLocation] = Field(None, description="切片位置") + children: List[Chunk] = Field(None, description="子切片") + +class QueryKnowledgeBaseResponse(BaseModel): + requestId: str = Field(..., description="请求ID") + code: str = Field(None, description="状态码") + message: str = Field(None, description="状态信息") + chunks: list[Chunk] = Field(..., description="切片列表") + total_count: int = Field(..., description="切片总数") diff --git a/python/core/console/knowledge_base/knowledge_base.py b/python/core/console/knowledge_base/knowledge_base.py index b2f6b1775..999e17e6b 100644 --- a/python/core/console/knowledge_base/knowledge_base.py +++ b/python/core/console/knowledge_base/knowledge_base.py @@ -15,9 +15,6 @@ import os import json import uuid -from pydantic import BaseModel -from pydantic import Field -from typing import Union from typing import Optional from appbuilder.core._client import HTTPClient from appbuilder.core.console.knowledge_base import data_class @@ -897,3 +894,31 @@ def get_all_documents(self, knowledge_base_id: Optional[str] = None) -> list: doc_list.extend(response_per_time.data) return doc_list + + def query_knowledge_base( + self, + request: data_class.QueryKnowledgeBaseRequest + ) -> data_class.QueryKnowledgeBaseResponse: + """ + 检索知识库 + + Args: + request (data_class.QueryKnowledgeBaseRequest): 检索知识库的请求对象 + + Returns: + data_class.QueryKnowledgeBaseResponse: 检索知识库的响应对象 + """ + headers = self.http_client.auth_header_v2() + headers["content-type"] = "application/json" + + url = self.http_client.service_url_v2("/v2/knowledgebases/query") + response = self.http_client.session.post( + url=url, headers=headers, json=request.model_dump(exclude_none=True) + ) + + self.http_client.check_response_header(response) + self.http_client.check_console_response(response) + data = response.json() + + resp = data_class.QueryKnowledgeBaseResponse(**data) + return resp