Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a retry mechanism to vector_db service #22

Merged
merged 2 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pylint pytest
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install -r requirements.dev.txt
- name: Test the code with pytest
run: |
pytest .
Expand Down
3 changes: 2 additions & 1 deletion app/Controllers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ async def update_image(image_id: Annotated[UUID, params.Path(description="The id


@admin_router.post("/upload",
description="Upload image to server. The image will be indexed and stored in the database. If local is set to true, the image will be uploaded to local storage.")
description="Upload image to server. The image will be indexed and stored in the database. If "
"local is set to true, the image will be uploaded to local storage.")
async def upload_image(image_file: Annotated[UploadFile, File(description="The image to be uploaded.")],
model: Annotated[UploadImageModel, Depends()]):
# generate an ID for the image
Expand Down
23 changes: 23 additions & 0 deletions app/Services/vector_db_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Optional

import numpy
from grpc.aio import AioRpcError
from httpx import HTTPError
from loguru import logger
from qdrant_client import AsyncQdrantClient
from qdrant_client.http import models
Expand All @@ -11,6 +13,7 @@
from app.Models.query_params import FilterParams
from app.Models.search_result import SearchResult
from app.config import config
from app.util.retry_deco_async import wrap_object, retry_async


class PointNotFoundError(ValueError):
Expand All @@ -28,6 +31,8 @@ def __init__(self):
self._client = AsyncQdrantClient(host=config.qdrant.host, port=config.qdrant.port,
grpc_port=config.qdrant.grpc_port, api_key=config.qdrant.api_key,
prefer_grpc=config.qdrant.prefer_grpc)

wrap_object(self._client, retry_async((AioRpcError, HTTPError)))
self.collection_name = config.qdrant.coll

async def retrieve_by_id(self, image_id: str, with_vectors=False) -> ImageData:
Expand Down Expand Up @@ -177,6 +182,24 @@ async def get_counts(self, exact: bool) -> int:
resp = await self._client.count(collection_name=self.collection_name, exact=exact)
return resp.count

async def check_collection(self) -> bool:
    """Report whether ``self.collection_name`` is among the client's collections.

    Fetches the collection listing from ``self._client`` and compares each
    entry's ``name`` with the configured collection name.

    :return: ``True`` when a collection with the configured name is listed,
        ``False`` otherwise.
    """
    listing = await self._client.get_collections()
    return any(entry.name == self.collection_name for entry in listing.collections)

async def initialize_collection(self):
    """Create the configured collection unless it is already present.

    When ``check_collection()`` reports that the collection exists, a
    warning is logged and nothing else happens. Otherwise a collection named
    ``self.collection_name`` is created with two named vectors
    (``self.IMG_VECTOR`` and ``self.TEXT_VECTOR``), each of size 768 using
    cosine distance.
    """
    already_present = await self.check_collection()
    if already_present:
        logger.warning("Collection already exists. Skip initialization.")
        return

    logger.info("Initializing database, collection name: {}", self.collection_name)
    # Both named vectors share the same parameters.
    named_vectors = {
        vector_name: models.VectorParams(size=768, distance=models.Distance.COSINE)
        for vector_name in (self.IMG_VECTOR, self.TEXT_VECTOR)
    }
    await self._client.create_collection(
        collection_name=self.collection_name,
        vectors_config=named_vectors,
    )
    logger.success("Collection created!")

@classmethod
def _get_vector_from_img_data(cls, img_data: ImageData) -> models.PointVectors:
vector = {}
Expand Down
33 changes: 33 additions & 0 deletions app/util/retry_deco_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import asyncio
import functools
from typing import Callable

from loguru import logger


def retry_async(exceptions=Exception, tries=3, delay=0) -> Callable[[Callable], Callable]:
    """Build a decorator that re-invokes an awaitable-returning callable on failure.

    The wrapped callable is awaited up to ``tries`` times in total. Every
    attempt except the last one catches ``exceptions``, logs a warning and,
    when ``delay`` is positive, sleeps for ``delay`` seconds before trying
    again. The last attempt is not guarded, so its exception propagates to
    the caller. Exceptions that do not match ``exceptions`` propagate
    immediately.

    :param exceptions: Exception class (or tuple of classes) that triggers a retry.
    :param tries: Total number of attempts; values of 1 or less mean a single call.
    :param delay: Seconds to wait between attempts; no sleep when not positive.
    :return: A decorator producing an ``async`` wrapper around the target.
    """

    def deco_retry(f):
        @functools.wraps(f)
        async def f_retry(*args, **kwargs):
            attempts_left, wait_seconds = tries, delay
            # Guarded attempts: all but the final one.
            while attempts_left > 1:
                attempts_left -= 1
                try:
                    outcome = await f(*args, **kwargs)
                except exceptions as e:
                    logger.warning(f"{e}, Retrying in {wait_seconds} seconds...")
                    if wait_seconds > 0:
                        await asyncio.sleep(wait_seconds)
                else:
                    return outcome
            # Final attempt: any exception goes straight to the caller.
            return await f(*args, **kwargs)

        return f_retry

    return deco_retry


def wrap_object(obj: object, deco: Callable[[Callable], Callable]):
    """Replace every public coroutine-function attribute of ``obj`` with a decorated one.

    Attributes whose names start with an underscore are left alone, as are
    attributes that are not callable or are not coroutine functions. The
    decorated callable is stored on ``obj`` itself via ``setattr``.

    :param obj: Object whose attributes are rewritten in place.
    :param deco: Decorator applied to each matching attribute.
    """
    for name in dir(obj):
        if name.startswith('_'):
            continue
        member = getattr(obj, name)
        if not callable(member):
            continue
        if not asyncio.iscoroutinefunction(member):
            continue
        setattr(obj, name, deco(getattr(obj, name)))
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def parse_args():
from scripts import qdrant_create_collection
from app.config import config

qdrant_create_collection.create_coll(config.qdrant.host, config.qdrant.port, config.qdrant.coll)
asyncio.run(qdrant_create_collection.main())

elif args.migrate_from_version is not None:
from scripts import db_migrations
Expand Down
12 changes: 3 additions & 9 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ image search.

## ✈️ Deployment

### Local Deployment
### 🖥️ Local Deployment

#### Deploy Qdrant Database

Expand Down Expand Up @@ -95,15 +95,9 @@ the [online service provided by Qdrant](https://qdrant.tech/documentation/cloud/
is a simple web front-end application for this project. If you want to deploy it, please refer to
its [deployment documentation](https://github.com/hv0905/NekoImageGallery.App).

### Docker Compose Containerized Deployment
### 🐋 Docker Deployment

> [!WARNING]
> Docker Compose support is in an alpha state and may not work for everyone (especially with CUDA acceleration).
> Please make sure you are familiar with [Docker documentation](https://docs.docker.com/) before using this deployment
> method.
> If you encounter any problems during deployment, please submit an issue.

#### Prepare `nvidia-container-runtime`
#### Prepare `nvidia-container-runtime` (CUDA users only)

If you want to use CUDA acceleration, you need to install `nvidia-container-runtime` on your system. Please refer to
the [official documentation](https://docs.docker.com/config/containers/resource_constraints/#gpu) for installation.
Expand Down
1 change: 1 addition & 0 deletions requirements.dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Requirements for development and testing

pytest
pytest-asyncio
pylint
16 changes: 4 additions & 12 deletions scripts/qdrant_create_collection.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
from qdrant_client import qdrant_client, models
from app.Services.vector_db_context import VectorDbContext


def create_coll(host, port, name):
    """Create a Qdrant collection with the two 768-dimensional cosine vectors.

    :param host: Host passed to ``QdrantClient``.
    :param port: Port passed to ``QdrantClient``.
    :param name: Name of the collection to create.
    """
    qdrant = qdrant_client.QdrantClient(host=host, port=port)
    print("Creating collection")
    qdrant.create_collection(
        collection_name=name,
        # Both named vectors share the same parameters.
        vectors_config={
            vector_name: models.VectorParams(size=768, distance=models.Distance.COSINE)
            for vector_name in ("image_vector", "text_contain_vector")
        },
    )
    print("Collection created")
async def main():
    """Instantiate ``VectorDbContext`` and await its ``initialize_collection()``."""
    await VectorDbContext().initialize_collection()
45 changes: 45 additions & 0 deletions tests/unit/test_retry_deco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import asyncio

import pytest

from app.util.retry_deco_async import retry_async, wrap_object


class TestRetryDeco:
    """Tests for ``retry_async`` and ``wrap_object``."""

    class ExampleClass:
        """Fixture with one method that succeeds on its third call and one that always fails."""

        def __init__(self):
            self.counter = 0
            self.counter2 = 0

        async def example_method(self):
            """Raise ``ValueError`` on the first two calls, then return the call count."""
            await asyncio.sleep(0)
            self.counter += 1
            if self.counter >= 3:
                return self.counter
            raise ValueError("Counter is less than 3")

        async def example_method_must_raise(self):
            """Count the call in ``counter2`` and always raise ``NotImplementedError``."""
            await asyncio.sleep(0)
            self.counter2 += 1
            raise NotImplementedError("This method must raise an exception.")

    @pytest.mark.asyncio
    async def test_decorator(self):
        """Three tries are enough for ``example_method`` to succeed."""
        sample = self.ExampleClass()

        # A plain function returning a coroutine; the decorator awaits its result.
        def caller():
            return sample.example_method()

        caller = retry_async(tries=3)(caller)

        result = await caller()
        assert result == 3

    @pytest.mark.asyncio
    async def test_object_wrapper(self):
        """Wrapped methods retry only on the configured exception type."""
        sample = self.ExampleClass()
        wrap_object(sample, retry_async(ValueError, tries=2))

        # Two tries use up calls 1 and 2, both failing, so the error surfaces.
        with pytest.raises(ValueError):
            await sample.example_method()

        # The next invocation is call 3, which succeeds.
        assert await sample.example_method() == 3

        # NotImplementedError is not retried, so the method runs exactly once.
        with pytest.raises(NotImplementedError):
            await sample.example_method_must_raise()
        assert sample.counter2 == 1
Loading