Skip to content

Commit

Permalink
Merge pull request #24 from hv0905/itest
Browse files Browse the repository at this point in the history
Add integrate test
  • Loading branch information
hv0905 authored May 10, 2024
2 parents d54c52f + 9f488ec commit c46319d
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 27 deletions.
27 changes: 14 additions & 13 deletions app/Services/vector_db_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@ def __init__(self):
grpc_port=config.qdrant.grpc_port, api_key=config.qdrant.api_key,
prefer_grpc=config.qdrant.prefer_grpc)
wrap_object(self._client, retry_async((AioRpcError, HTTPError)))
self._local = False
case QdrantMode.LOCAL:
self._client = AsyncQdrantClient(path=config.qdrant.local_path)
self._local = True
case QdrantMode.MEMORY:
logger.warning("Using in-memory Qdrant client. Data will be lost after application restart. "
"This should only be used for testing and debugging.")
self._client = AsyncQdrantClient(":memory:")
self._local = True
case _:
raise ValueError("Invalid Qdrant mode.")
self.collection_name = config.qdrant.coll
Expand Down Expand Up @@ -234,24 +237,22 @@ def _get_point_from_img_data(cls, img_data: ImageData) -> models.PointStruct:
vector=cls._get_vector_from_img_data(img_data).vector
)

@classmethod
def _get_img_data_from_point(cls, point: AVAILABLE_POINT_TYPES) -> ImageData:
def _get_img_data_from_point(self, point: AVAILABLE_POINT_TYPES) -> ImageData:
return (ImageData
.from_payload(point.id,
point.payload,
image_vector=numpy.array(point.vector[cls.IMG_VECTOR], dtype=numpy.float32)
if point.vector and cls.IMG_VECTOR in point.vector else None,
text_contain_vector=numpy.array(point.vector[cls.TEXT_VECTOR], dtype=numpy.float32)
if point.vector and cls.TEXT_VECTOR in point.vector else None
# workaround: https://github.com/qdrant/qdrant-client/issues/624
point.payload.copy() if self._local else point.payload,
image_vector=numpy.array(point.vector[self.IMG_VECTOR], dtype=numpy.float32)
if point.vector and self.IMG_VECTOR in point.vector else None,
text_contain_vector=numpy.array(point.vector[self.TEXT_VECTOR], dtype=numpy.float32)
if point.vector and self.TEXT_VECTOR in point.vector else None
))

@classmethod
def _get_img_data_from_points(cls, points: list[AVAILABLE_POINT_TYPES]) -> list[ImageData]:
return [cls._get_img_data_from_point(t) for t in points]
def _get_img_data_from_points(self, points: list[AVAILABLE_POINT_TYPES]) -> list[ImageData]:
return [self._get_img_data_from_point(t) for t in points]

@classmethod
def _get_search_result_from_scored_point(cls, point: models.ScoredPoint) -> SearchResult:
return SearchResult(img=cls._get_img_data_from_point(point), score=point.score)
def _get_search_result_from_scored_point(self, point: models.ScoredPoint) -> SearchResult:
return SearchResult(img=self._get_img_data_from_point(point), score=point.score)

@classmethod
def getVectorByBasis(cls, basis: SearchBasisEnum) -> str:
Expand Down
26 changes: 26 additions & 0 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest
from fastapi.testclient import TestClient

from app import config

TEST_ACCESS_TOKEN = 'test_token'
TEST_ADMIN_TOKEN = 'test_admin_token'

config.config.qdrant.mode = "memory"
config.config.admin_api_enable = True
config.config.access_protected = True
config.config.access_token = TEST_ACCESS_TOKEN
config.config.admin_token = TEST_ADMIN_TOKEN
config.config.storage.method = config.StorageMode.LOCAL


@pytest.fixture(scope="session")
def test_client(tmp_path_factory) -> TestClient:
# Modify the configuration for testing
config.config.storage.local.path = tmp_path_factory.mktemp("static_files")

from app.webapp import app
# Start the application

with TestClient(app) as client:
yield client
68 changes: 68 additions & 0 deletions tests/api/integrate_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import asyncio
from pathlib import Path

import pytest

from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN

assets_path = Path(__file__).parent / '..' / 'assets'

test_images = {'bsn': ['bsn_0.jpg', 'bsn_1.jpg', 'bsn_2.jpg'],
'cat': ['cat_0.jpg', 'cat_1.jpg'],
'cg': ['cg_0.jpg', 'cg_1.png']}


@pytest.mark.asyncio
async def test_integrate(test_client):
credentials = {'x-admin-token': TEST_ADMIN_TOKEN, 'x-access-token': TEST_ACCESS_TOKEN}
resp = test_client.get("/", headers=credentials)
assert resp.status_code == 200
img_ids = dict()
for img_cls in test_images:
img_ids[img_cls] = []
for image in test_images[img_cls]:
print(f'upload image {image}...')
resp = test_client.post('/admin/upload',
files={'image_file': open(assets_path / 'test_images' / image, 'rb')},
headers=credentials,
params={'local': True})
assert resp.status_code == 200
img_ids[img_cls].append(resp.json()['image_id'])

print('Waiting for images to be processed...')

while True:
resp = test_client.get('/admin/server_info', headers=credentials)
if resp.json()['image_count'] >= 7:
break
await asyncio.sleep(1)

resp = test_client.get('/search/text/hatsune+miku',
headers=credentials)
assert resp.status_code == 200
assert resp.json()['result'][0]['img']['id'] in img_ids['cg']

resp = test_client.post('/search/image',
files={'image': open(assets_path / 'test_images' / test_images['cat'][0], 'rb')},
headers=credentials)

assert resp.status_code == 200
assert resp.json()['result'][0]['img']['id'] in img_ids['cat']

resp = test_client.get(f"/search/similar/{img_ids['bsn'][0]}",
headers=credentials)

assert resp.status_code == 200
assert resp.json()['result'][0]['img']['id'] in img_ids['bsn']

resp = test_client.put(f"/admin/update_opt/{img_ids['bsn'][0]}", json={'categories': ['bsn'], 'starred': True},
headers=credentials)
assert resp.status_code == 200

resp = test_client.get(f"/search/text/cat", params={'categories': 'bsn'}, headers=credentials)
assert resp.status_code == 200
assert resp.json()['result'][0]['img']['id'] in img_ids['bsn']

resp = test_client.get(f"/search/text/cat", params={'starred': True}, headers=credentials)
assert resp.status_code == 200
assert resp.json()['result'][0]['img']['id'] in img_ids['bsn']
20 changes: 6 additions & 14 deletions tests/api/test_home.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
from fastapi.testclient import TestClient

from app.config import config
from app.webapp import app

client = TestClient(app)
Expand All @@ -14,29 +13,22 @@ def anyio_backend():

class TestHome:

# noinspection PyMethodMayBeStatic
def setup_class(self):
config.admin_api_enable = True
config.access_protected = True
config.access_token = 'test_token'
config.admin_token = 'test_admin_token'

def test_get_home_no_tokens(self):
response = client.get("/")
def test_get_home_no_tokens(self, test_client):
response = test_client.get("/")
assert response.status_code == 200
assert response.json()['authorization']['required']
assert not response.json()['authorization']['passed']
assert response.json()['admin_api']['available']
assert not response.json()['admin_api']['passed']

def test_get_home_access_token(self):
response = client.get("/", headers={'x-access-token': 'test_token'})
def test_get_home_access_token(self, test_client):
response = test_client.get("/", headers={'x-access-token': 'test_token'})
assert response.status_code == 200
assert response.json()['authorization']['required']
assert response.json()['authorization']['passed']

def test_get_home_admin_token(self):
response = client.get("/", headers={'x-admin-token': 'test_admin_token', 'x-access-token': 'test_token'})
def test_get_home_admin_token(self, test_client):
response = test_client.get("/", headers={'x-admin-token': 'test_admin_token', 'x-access-token': 'test_token'})
assert response.status_code == 200
assert response.json()['admin_api']['available']
assert response.json()['admin_api']['passed']
Expand Down
Binary file added tests/assets/test_images/bsn_0.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/assets/test_images/bsn_1.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/assets/test_images/bsn_2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/assets/test_images/cg_0.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/assets/test_images/cg_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit c46319d

Please sign in to comment.