Skip to content

Commit

Permalink
Merge pull request #22 from hv0905/vector_db_retry
Browse files Browse the repository at this point in the history
Add a retry mechanism to vector_db service
  • Loading branch information
hv0905 authored May 7, 2024
2 parents edf605a + 7579792 commit b506d31
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pylint pytest
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install -r requirements.dev.txt
- name: Test the code with pytest
run: |
pytest .
Expand Down
3 changes: 2 additions & 1 deletion app/Controllers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ async def update_image(image_id: Annotated[UUID, params.Path(description="The id


@admin_router.post("/upload",
description="Upload image to server. The image will be indexed and stored in the database. If local is set to true, the image will be uploaded to local storage.")
description="Upload image to server. The image will be indexed and stored in the database. If "
"local is set to true, the image will be uploaded to local storage.")
async def upload_image(image_file: Annotated[UploadFile, File(description="The image to be uploaded.")],
model: Annotated[UploadImageModel, Depends()]):
# generate an ID for the image
Expand Down
23 changes: 23 additions & 0 deletions app/Services/vector_db_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Optional

import numpy
from grpc.aio import AioRpcError
from httpx import HTTPError
from loguru import logger
from qdrant_client import AsyncQdrantClient
from qdrant_client.http import models
Expand All @@ -11,6 +13,7 @@
from app.Models.query_params import FilterParams
from app.Models.search_result import SearchResult
from app.config import config
from app.util.retry_deco_async import wrap_object, retry_async


class PointNotFoundError(ValueError):
Expand All @@ -28,6 +31,8 @@ def __init__(self):
self._client = AsyncQdrantClient(host=config.qdrant.host, port=config.qdrant.port,
grpc_port=config.qdrant.grpc_port, api_key=config.qdrant.api_key,
prefer_grpc=config.qdrant.prefer_grpc)

wrap_object(self._client, retry_async((AioRpcError, HTTPError)))
self.collection_name = config.qdrant.coll

async def retrieve_by_id(self, image_id: str, with_vectors=False) -> ImageData:
Expand Down Expand Up @@ -177,6 +182,24 @@ async def get_counts(self, exact: bool) -> int:
resp = await self._client.count(collection_name=self.collection_name, exact=exact)
return resp.count

async def check_collection(self) -> bool:
resp = await self._client.get_collections()
resp = [t.name for t in resp.collections]
return self.collection_name in resp

async def initialize_collection(self):
if await self.check_collection():
logger.warning("Collection already exists. Skip initialization.")
return
logger.info("Initializing database, collection name: {}", self.collection_name)
vectors_config = {
self.IMG_VECTOR: models.VectorParams(size=768, distance=models.Distance.COSINE),
self.TEXT_VECTOR: models.VectorParams(size=768, distance=models.Distance.COSINE)
}
await self._client.create_collection(collection_name=self.collection_name,
vectors_config=vectors_config)
logger.success("Collection created!")

@classmethod
def _get_vector_from_img_data(cls, img_data: ImageData) -> models.PointVectors:
vector = {}
Expand Down
33 changes: 33 additions & 0 deletions app/util/retry_deco_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import asyncio
import functools
from typing import Callable

from loguru import logger


def retry_async(exceptions=Exception, tries=3, delay=0) -> Callable[[Callable], Callable]:
def deco_retry(f):
@functools.wraps(f)
async def f_retry(*args, **kwargs):
m_tries, m_delay = tries, delay
while m_tries > 1:
try:
return await f(*args, **kwargs)
except exceptions as e:
logger.warning(f"{e}, Retrying in {m_delay} seconds...")
if m_delay > 0:
await asyncio.sleep(m_delay)
m_tries -= 1
return await f(*args, **kwargs)

return f_retry

return deco_retry


def wrap_object(obj: object, deco: Callable[[Callable], Callable]):
for attr in dir(obj):
if not attr.startswith('_'):
attr_val = getattr(obj, attr)
if callable(attr_val) and asyncio.iscoroutinefunction(attr_val):
setattr(obj, attr, deco(getattr(obj, attr)))
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def parse_args():
from scripts import qdrant_create_collection
from app.config import config

qdrant_create_collection.create_coll(config.qdrant.host, config.qdrant.port, config.qdrant.coll)
asyncio.run(qdrant_create_collection.main())

elif args.migrate_from_version is not None:
from scripts import db_migrations
Expand Down
12 changes: 3 additions & 9 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ image search.
## ✈️ Deployment

### Local Deployment
### 🖥️ Local Deployment

#### Deploy Qdrant Database

Expand Down Expand Up @@ -95,15 +95,9 @@ the [online service provided by Qdrant](https://qdrant.tech/documentation/cloud/
is a simple web front-end application for this project. If you want to deploy it, please refer to
its [deployment documentation](https://github.com/hv0905/NekoImageGallery.App).
### Docker Compose Containerized Deployment
### 🐋 Docker Deployment
> [!WARNING]
> Docker compose support is in an alpha state, and may not work for everyone(especially CUDA acceleration).
> Please make sure you are familiar with [Docker documentation](https://docs.docker.com/) before using this deployment
> method.
> If you encounter any problems during deployment, please submit an issue.
#### Prepare `nvidia-container-runtime`
#### Prepare `nvidia-container-runtime` (CUDA users only)
If you want to use CUDA acceleration, you need to install `nvidia-container-runtime` on your system. Please refer to
the [official documentation](https://docs.docker.com/config/containers/resource_constraints/#gpu) for installation.
Expand Down
1 change: 1 addition & 0 deletions requirements.dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Requirements for development and testing

pytest
pytest-asyncio
pylint
16 changes: 4 additions & 12 deletions scripts/qdrant_create_collection.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
from qdrant_client import qdrant_client, models
from app.Services.vector_db_context import VectorDbContext


def create_coll(host, port, name):
client = qdrant_client.QdrantClient(host=host, port=port)
# create or update
print("Creating collection")
vectors_config = {
"image_vector": models.VectorParams(size=768, distance=models.Distance.COSINE),
"text_contain_vector": models.VectorParams(size=768, distance=models.Distance.COSINE)
}
client.create_collection(collection_name=name,
vectors_config=vectors_config)
print("Collection created")
async def main():
context = VectorDbContext()
await context.initialize_collection()
45 changes: 45 additions & 0 deletions tests/unit/test_retry_deco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import asyncio

import pytest

from app.util.retry_deco_async import retry_async, wrap_object


class TestRetryDeco:
class ExampleClass:
def __init__(self):
self.counter = 0
self.counter2 = 0

async def example_method(self):
await asyncio.sleep(0)
self.counter += 1
if self.counter < 3:
raise ValueError("Counter is less than 3")
return self.counter

async def example_method_must_raise(self):
await asyncio.sleep(0)
self.counter2 += 1
raise NotImplementedError("This method must raise an exception.")

@pytest.mark.asyncio
async def test_decorator(self):
obj = self.ExampleClass()

@retry_async(tries=3)
def caller():
return obj.example_method()

assert await caller() == 3

@pytest.mark.asyncio
async def test_object_wrapper(self):
obj = self.ExampleClass()
wrap_object(obj, retry_async(ValueError, tries=2))
with pytest.raises(ValueError):
await obj.example_method()
assert await obj.example_method() == 3
with pytest.raises(NotImplementedError):
await obj.example_method_must_raise()
assert obj.counter2 == 1

0 comments on commit b506d31

Please sign in to comment.