From fa1d8969ebf6713dab21df6110b676029fc119c6 Mon Sep 17 00:00:00 2001 From: EdgeNeko Date: Wed, 8 May 2024 01:21:25 +0800 Subject: [PATCH 1/2] Add a retry mechanism for vector database --- .github/workflows/test_lint.yml | 2 +- app/Controllers/admin.py | 3 ++- app/Services/vector_db_context.py | 5 ++++ app/util/retry_deco_async.py | 33 +++++++++++++++++++++++ readme.md | 12 +++------ requirements.dev.txt | 1 + tests/unit/test_retry_deco.py | 45 +++++++++++++++++++++++++++++++ 7 files changed, 90 insertions(+), 11 deletions(-) create mode 100644 app/util/retry_deco_async.py create mode 100644 tests/unit/test_retry_deco.py diff --git a/.github/workflows/test_lint.yml b/.github/workflows/test_lint.yml index 9818d3c..48cf96f 100644 --- a/.github/workflows/test_lint.yml +++ b/.github/workflows/test_lint.yml @@ -23,9 +23,9 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pylint pytest pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install -r requirements.txt + pip install -r requirements.dev.txt - name: Test the code with pytest run: | pytest . diff --git a/app/Controllers/admin.py b/app/Controllers/admin.py index 42efa92..1a024b9 100644 --- a/app/Controllers/admin.py +++ b/app/Controllers/admin.py @@ -85,7 +85,8 @@ async def update_image(image_id: Annotated[UUID, params.Path(description="The id @admin_router.post("/upload", - description="Upload image to server. The image will be indexed and stored in the database. If local is set to true, the image will be uploaded to local storage.") + description="Upload image to server. The image will be indexed and stored in the database. If " + "local is set to true, the image will be uploaded to local storage.") async def upload_image(image_file: Annotated[UploadFile, File(description="The image to be uploaded.")], model: Annotated[UploadImageModel, Depends()]): # generate an ID for the image diff --git a/app/Services/vector_db_context.py b/app/Services/vector_db_context.py index 4ff428b..21c392a 100644 --- a/app/Services/vector_db_context.py +++ b/app/Services/vector_db_context.py @@ -1,6 +1,8 @@ from typing import Optional import numpy +from grpc.aio import AioRpcError +from httpx import HTTPError from loguru import logger from qdrant_client import AsyncQdrantClient from qdrant_client.http import models @@ -11,6 +13,7 @@ from app.Models.query_params import FilterParams from app.Models.search_result import SearchResult from app.config import config +from app.util.retry_deco_async import wrap_object, retry_async class PointNotFoundError(ValueError): @@ -28,6 +31,8 @@ def __init__(self): self._client = AsyncQdrantClient(host=config.qdrant.host, port=config.qdrant.port, grpc_port=config.qdrant.grpc_port, api_key=config.qdrant.api_key, prefer_grpc=config.qdrant.prefer_grpc) + + wrap_object(self._client, retry_async((AioRpcError, HTTPError))) self.collection_name = config.qdrant.coll async def retrieve_by_id(self, image_id: str, with_vectors=False) -> ImageData: diff --git a/app/util/retry_deco_async.py b/app/util/retry_deco_async.py new file mode 100644 index 0000000..c818bed --- /dev/null +++ b/app/util/retry_deco_async.py @@ -0,0 +1,33 @@ +import asyncio +import functools +from typing import Callable + +from loguru import logger + + +def retry_async(exceptions=Exception, tries=3, delay=0) -> Callable[[Callable], Callable]: + def deco_retry(f): + @functools.wraps(f) + async def f_retry(*args, **kwargs): + m_tries, m_delay = tries, delay + while m_tries > 1: + try: + return await f(*args, **kwargs) + except exceptions as e: + logger.warning(f"{e}, Retrying in {m_delay} seconds...") + if m_delay > 0: + await asyncio.sleep(m_delay) + m_tries -= 1 + return await f(*args, **kwargs) + + return f_retry + + return deco_retry + + +def wrap_object(obj: object, deco: Callable[[Callable], Callable]): + for attr in dir(obj): + if not attr.startswith('_'): + attr_val = getattr(obj, attr) + if callable(attr_val) and asyncio.iscoroutinefunction(attr_val): + setattr(obj, attr, deco(getattr(obj, attr))) diff --git a/readme.md b/readme.md index ebbacb5..f3a909c 100644 --- a/readme.md +++ b/readme.md @@ -32,7 +32,7 @@ image search. ## ✈️ Deployment -### Local Deployment +### 🖥️ Local Deployment #### Deploy Qdrant Database @@ -95,15 +95,9 @@ the [online service provided by Qdrant](https://qdrant.tech/documentation/cloud/ is a simple web front-end application for this project. If you want to deploy it, please refer to its [deployment documentation](https://github.com/hv0905/NekoImageGallery.App). -### Docker Compose Containerized Deployment +### 🐋 Docker Deployment -> [!WARNING] -> Docker compose support is in an alpha state, and may not work for everyone(especially CUDA acceleration). -> Please make sure you are familiar with [Docker documentation](https://docs.docker.com/) before using this deployment -> method. -> If you encounter any problems during deployment, please submit an issue. - -#### Prepare `nvidia-container-runtime` +#### Prepare `nvidia-container-runtime` (CUDA users only) If you want to use CUDA acceleration, you need to install `nvidia-container-runtime` on your system. Please refer to the [official documentation](https://docs.docker.com/config/containers/resource_constraints/#gpu) for installation. diff --git a/requirements.dev.txt b/requirements.dev.txt index 04bb894..8be78ef 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -1,4 +1,5 @@ # Requirements for development and testing pytest +pytest-asyncio pylint \ No newline at end of file diff --git a/tests/unit/test_retry_deco.py b/tests/unit/test_retry_deco.py new file mode 100644 index 0000000..f9d47e7 --- /dev/null +++ b/tests/unit/test_retry_deco.py @@ -0,0 +1,45 @@ +import asyncio + +import pytest + +from app.util.retry_deco_async import retry_async, wrap_object + + +class TestRetryDeco: + class ExampleClass: + def __init__(self): + self.counter = 0 + self.counter2 = 0 + + async def example_method(self): + await asyncio.sleep(0) + self.counter += 1 + if self.counter < 3: + raise ValueError("Counter is less than 3") + return self.counter + + async def example_method_must_raise(self): + await asyncio.sleep(0) + self.counter2 += 1 + raise NotImplementedError("This method must raise an exception.") + + @pytest.mark.asyncio + async def test_decorator(self): + obj = self.ExampleClass() + + @retry_async(tries=3) + def caller(): + return obj.example_method() + + assert await caller() == 3 + + @pytest.mark.asyncio + async def test_object_wrapper(self): + obj = self.ExampleClass() + wrap_object(obj, retry_async(ValueError, tries=2)) + with pytest.raises(ValueError): + await obj.example_method() + assert await obj.example_method() == 3 + with pytest.raises(NotImplementedError): + await obj.example_method_must_raise() + assert obj.counter2 == 1 From 7579792cb35575eb3bc6b941255399bb929d4775 Mon Sep 17 00:00:00 2001 From: EdgeNeko Date: Wed, 8 May 2024 01:22:13 +0800 Subject: [PATCH 2/2] Move 'create_collection' to vector db service --- app/Services/vector_db_context.py | 18 ++++++++++++++++++ main.py | 2 +- scripts/qdrant_create_collection.py | 16 ++++------------ 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/app/Services/vector_db_context.py b/app/Services/vector_db_context.py index 21c392a..36fc7f8 100644 --- a/app/Services/vector_db_context.py +++ b/app/Services/vector_db_context.py @@ -182,6 +182,24 @@ async def get_counts(self, exact: bool) -> int: resp = await self._client.count(collection_name=self.collection_name, exact=exact) return resp.count + async def check_collection(self) -> bool: + resp = await self._client.get_collections() + resp = [t.name for t in resp.collections] + return self.collection_name in resp + + async def initialize_collection(self): + if await self.check_collection(): + logger.warning("Collection already exists. Skip initialization.") + return + logger.info("Initializing database, collection name: {}", self.collection_name) + vectors_config = { + self.IMG_VECTOR: models.VectorParams(size=768, distance=models.Distance.COSINE), + self.TEXT_VECTOR: models.VectorParams(size=768, distance=models.Distance.COSINE) + } + await self._client.create_collection(collection_name=self.collection_name, + vectors_config=vectors_config) + logger.success("Collection created!") + @classmethod def _get_vector_from_img_data(cls, img_data: ImageData) -> models.PointVectors: vector = {} diff --git a/main.py b/main.py index c1532a3..84c60ac 100644 --- a/main.py +++ b/main.py @@ -50,7 +50,7 @@ def parse_args(): from scripts import qdrant_create_collection from app.config import config - qdrant_create_collection.create_coll(config.qdrant.host, config.qdrant.port, config.qdrant.coll) + asyncio.run(qdrant_create_collection.main()) elif args.migrate_from_version is not None: from scripts import db_migrations diff --git a/scripts/qdrant_create_collection.py b/scripts/qdrant_create_collection.py index df7d89c..3095485 100644 --- a/scripts/qdrant_create_collection.py +++ b/scripts/qdrant_create_collection.py @@ -1,14 +1,6 @@ -from qdrant_client import qdrant_client, models +from app.Services.vector_db_context import VectorDbContext -def create_coll(host, port, name): - client = qdrant_client.QdrantClient(host=host, port=port) - # create or update - print("Creating collection") - vectors_config = { - "image_vector": models.VectorParams(size=768, distance=models.Distance.COSINE), - "text_contain_vector": models.VectorParams(size=768, distance=models.Distance.COSINE) - } - client.create_collection(collection_name=name, - vectors_config=vectors_config) - print("Collection created") +async def main(): + context = VectorDbContext() + await context.initialize_collection()