Skip to content

Commit

Permalink
Merge pull request #321 from epoch8/add-qdrant-indexes
Browse files Browse the repository at this point in the history
feat: added index_schema to QdrantStore
  • Loading branch information
elephantum authored Jun 26, 2024
2 parents b048f08 + 22bd889 commit 085ac5c
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# 0.13.12

* Add processing of an empty response in `QdrantStore`
* Add optional `index_schema` to `QdrantStore`
* Add redis cluster mode support in `RedisStore`

# 0.13.11
Expand Down
98 changes: 85 additions & 13 deletions datapipe/store/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,27 @@ class QdrantStore(TableStore):
Args:
name (str): name of the Qdrant collection
url (str): url of the Qdrant server (if using with api_key,
you should explicitly specify port 443, by default qdrant uses 6333)
schema (DataSchema): Describes data that will be stored in the Qdrant collection
pk_field (str): name of the primary key field in the schema, used to identify records
embedding_field (str): name of the field in the schema that contains the vector representation of the record
collection_params (CollectionParams): parameters for creating a collection in Qdrant
url (str): url of the Qdrant server (if using with api_key, you should
explicitly specify port 443, by default qdrant uses 6333)
schema (DataSchema): Describes data that will be stored in the Qdrant
collection
pk_field (str): name of the primary key field in the schema, used to
identify records
embedding_field (str): name of the field in the schema that contains the
vector representation of the record
collection_params (CollectionParams): parameters for creating a
collection in Qdrant
index_schema (dict): {field_name: field_schema} - field(s) in payload
that will be used to create an index on. For data types and field
schema, check
https://qdrant.tech/documentation/concepts/indexing/#payload-index
api_key (Optional[str]): api_key for Qdrant server
"""

Expand All @@ -39,6 +54,7 @@ def __init__(
pk_field: str,
embedding_field: str,
collection_params: CollectionParams,
index_schema: Optional[dict] = None,
api_key: Optional[str] = None,
):
super().__init__()
Expand All @@ -55,13 +71,23 @@ def __init__(
pk_columns = [column for column in self.schema if column.primary_key]

if len(pk_columns) != 1 and pk_columns[0].name != pk_field:
raise ValueError("Incorrect prymary key columns in schema")
raise ValueError("Incorrect primary key columns in schema")

self.paylods_filelds = [
self.payloads_filelds = [
column.name for column in self.schema if column.name != self.embedding_field
]

def __init(self):
self.index_field = {}
if index_schema:
# check if index field is present in schema
for field, field_schema in index_schema.items():
if field not in self.payloads_filelds:
raise ValueError(
f"Index field `{field}` ({field_schema}) not found in payload schema"
)
self.index_field = index_schema

def __init_collection(self):
self.client = QdrantClient(url=self.url, api_key=self._api_key)
try:
self.client.get_collection(self.name)
Expand All @@ -71,9 +97,25 @@ def __init(self):
collection_name=self.name, create_collection=self.collection_params
)

def __init_indexes(self):
"""
Checks on collection's payload indexes and adds them from index_field, if necessary.
Schema checks are not performed.
"""
payload_schema = self.client.get_collection(self.name).payload_schema
for field, field_schema in self.index_field.items():
if field not in payload_schema.keys():
self.client.create_payload_index(
collection_name=self.name,
field_name=field,
field_schema=field_schema,
)

def __check_init(self):
if not self.inited:
self.__init()
self.__init_collection()
if self.index_field:
self.__init_indexes()
self.inited = True

def __get_ids(self, df):
Expand Down Expand Up @@ -107,7 +149,7 @@ def insert_rows(self, df: DataDF) -> None:
vectors=df[self.embedding_field].apply(list).to_list(),
payloads=cast(
List[Dict[str, Any]],
df[self.paylods_filelds].to_dict(orient="records"),
df[self.payloads_filelds].to_dict(orient="records"),
),
),
wait=True,
Expand Down Expand Up @@ -172,6 +214,8 @@ class QdrantShardedStore(TableStore):
schema (DataSchema): Describes data that will be stored in the Qdrant collection
embedding_field (str): name of the field in the schema that contains the vector representation of the record
collection_params (CollectionParams): parameters for creating a collection in Qdrant
index_schema (dict): {field_name: field_schema} - field(s) in payload that will be used to create an index on.
For data types and field schema, check https://qdrant.tech/documentation/concepts/indexing/#payload-index
api_key (Optional[str]): api_key for Qdrant server
"""

Expand All @@ -182,6 +226,7 @@ def __init__(
schema: DataSchema,
embedding_field: str,
collection_params: CollectionParams,
index_schema: Optional[dict] = None,
api_key: Optional[str] = None,
):
super().__init__()
Expand All @@ -196,9 +241,20 @@ def __init__(
self.client: Optional[QdrantClient] = None

self.pk_fields = [column.name for column in self.schema if column.primary_key]
self.paylods_filelds = [
self.payloads_filelds = [
column.name for column in self.schema if column.name != self.embedding_field
]

self.index_field = {}
if index_schema:
# check if index field is present in schema
for field, field_schema in index_schema.items():
if field not in self.payloads_filelds:
raise ValueError(
f"Index field `{field}` ({field_schema}) not found in payload schema"
)
self.index_field = index_schema

self.name_params = re.findall(r"\{([^/]+?)\}", self.name_pattern)

if not len(self.pk_fields):
Expand All @@ -216,12 +272,28 @@ def __init_collection(self, name):
collection_name=name, create_collection=self.collection_params
)

def __init_indexes(self, name):
"""
Checks on collection's payload indexes and adds them from index_field, if necessary.
Schema checks are not performed.
"""
payload_schema = self.client.get_collection(name).payload_schema
for field, field_schema in self.index_field.items():
if field not in payload_schema.keys():
self.client.create_payload_index(
collection_name=name,
field_name=field,
field_schema=field_schema,
)

def __check_init(self, name):
if not self.client:
self.client = QdrantClient(url=self.url, api_key=self._api_key)

if name not in self.inited_collections:
self.__init_collection(name)
if self.index_field:
self.__init_indexes(name)
self.inited_collections.add(name)

def __get_ids(self, df):
Expand Down Expand Up @@ -267,7 +339,7 @@ def insert_rows(self, df: DataDF) -> None:
vectors=gdf[self.embedding_field].apply(list).to_list(),
payloads=cast(
List[Dict[str, Any]],
df[self.paylods_filelds].to_dict(orient="records"),
df[self.payloads_filelds].to_dict(orient="records"),
),
),
wait=True,
Expand Down
16 changes: 14 additions & 2 deletions tests/test_qdrant_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pandas as pd
from qdrant_client.models import Distance, VectorParams
from sqlalchemy import ARRAY, Float, Integer
from sqlalchemy import ARRAY, Float, Integer, String
from sqlalchemy.sql.schema import Column

from datapipe.compute import Catalog, Pipeline, Table, build_compute, run_steps
Expand All @@ -20,7 +20,9 @@ def extract_id(df: pd.DataFrame) -> pd.DataFrame:


def generate_data() -> Generator[pd.DataFrame, None, None]:
yield pd.DataFrame({"id": [1], "embedding": [[0.1]]})
yield pd.DataFrame(
{"id": [1], "embedding": [[0.1]], "str_payload": ["foo"], "int_payload": [42]}
)


def test_qdrant_table_to_json(dbconn: DBConn, tmp_dir: Path) -> None:
Expand All @@ -34,6 +36,8 @@ def test_qdrant_table_to_json(dbconn: DBConn, tmp_dir: Path) -> None:
schema=[
Column("id", Integer, primary_key=True),
Column("embedding", ARRAY(Float, dimensions=1)),
Column("str_payload", String),
Column("int_payload", Integer),
],
collection_params=CollectionParams(
vectors=VectorParams(
Expand All @@ -43,6 +47,14 @@ def test_qdrant_table_to_json(dbconn: DBConn, tmp_dir: Path) -> None:
),
pk_field="id",
embedding_field="embedding",
index_schema={
"str_payload": "keyword",
"int_payload": {
"type": "integer",
"lookup": False,
"range": True,
},
},
)
),
"output": Table(
Expand Down

0 comments on commit 085ac5c

Please sign in to comment.