Skip to content

Commit

Permalink
#12 new unit test system
Browse files Browse the repository at this point in the history
  • Loading branch information
PhoenixNazarov committed Aug 8, 2024
1 parent a71e9f0 commit ff46118
Show file tree
Hide file tree
Showing 13 changed files with 244 additions and 125 deletions.
14 changes: 6 additions & 8 deletions server/promptadmin_server/api/job/unit_test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,27 @@
from promptadmin_server.commons.dto import ViewParamsBuilder, ViewParamsFilter
from promptadmin_server.commons.fastapi.background_task import BackgroundTask
from promptadmin_server.data.entity.sync_data import SyncData
from promptadmin_server.data.service.mapping_entity_service import MappingEntityService
from promptadmin_server.data.service.sync_data_service import SyncDataService


class UnitTestJob(BackgroundTask):
def __init__(self,
sync_data_service: SyncDataService = None,
prompt_unit_test_service: PromptUnitTestService = None
prompt_unit_test_service: PromptUnitTestService = None,
mapping_entity_service: MappingEntityService = None
):
self.sync_data_service = sync_data_service or SyncDataService()
self.prompt_unit_test_service = prompt_unit_test_service or PromptUnitTestService()
self.mapping_entity_service = mapping_entity_service or MappingEntityService()

async def start(self):
await asyncio.sleep(60 * 10)
while True:
view_params = (
ViewParamsBuilder()
.filter(ViewParamsFilter(field=SyncData.test_status, value='wait'))
.build()
)
sync_datas = await self.sync_data_service.find_by_view_params(view_params)
sync_datas = await self.sync_data_service.find_all()
for i in sync_datas:
sync_data = await self.sync_data_service.find_by_id(i.id)
await self.prompt_unit_test_service.process(sync_data)
await self.prompt_unit_test_service.process_sync_data(sync_data, delay=20)
await asyncio.sleep(20)

await asyncio.sleep(60 * 10)
Expand Down
2 changes: 2 additions & 0 deletions server/promptadmin_server/api/routers/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .prompt_audit import router as prompt_audit_router
from .account import router as account_router
from .sync_data import router as sync_data_router
from .unit_test import router as unit_test_router

router = APIRouter()

Expand All @@ -21,3 +22,4 @@
router.include_router(prompt_audit_router, prefix='/prompt_audit')
router.include_router(account_router, prefix='/account')
router.include_router(sync_data_router, prefix='/sync_data')
router.include_router(unit_test_router, prefix='/unit_test')
11 changes: 11 additions & 0 deletions server/promptadmin_server/api/routers/config/unit_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from fastapi import APIRouter

from promptadmin_server.api.routers.config.base_config_router_factory import bind_view
from promptadmin_server.data.entity.unit_test import UnitTest
from promptadmin_server.data.service.unit_test_service import UnitTestService

router = APIRouter()

unit_test_service = UnitTestService()

bind_view(router, UnitTest, UnitTest, unit_test_service)
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from typing import Any

import jinja2
from promptadmin.output.parser_output_service import ParserOutputService
Expand Down Expand Up @@ -46,7 +47,7 @@ async def preview_prompt(self, prompt: Prompt, context: dict[str, str] = None) -
return self.preview(prompt.value, data)

@staticmethod
def preview(prompt: str, context: dict[str, str]):
def preview(prompt: str, context: dict[str, Any]):
environment = jinja2.Environment()
template = environment.from_string(prompt)

Expand Down
95 changes: 55 additions & 40 deletions server/promptadmin_server/api/service/prompt_load_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from promptadmin_server.api.service.user_data import UserData
from promptadmin_server.commons.dto import ViewParamsBuilder, ViewParamsFilter
from promptadmin_server.data.entity.mapping import Mapping
from promptadmin_server.data.entity.mapping_entity import MappingEntity, MappingEntityData
from promptadmin_server.data.entity.mapping_entity import MappingEntity
from promptadmin_server.data.entity.prompt_audit import PromptAudit
from promptadmin_server.data.service.mapping_entity_service import MappingEntityService
from promptadmin_server.data.service.mapping_service import MappingService
from promptadmin_server.data.service.prompt_audit_service import PromptAuditService
from promptadmin_server.data.service.sync_data_service import SyncDataService
from promptadmin_server.data.service.unit_test_service import UnitTestService
from settings import SETTINGS

logger = logging.getLogger(__name__)
Expand All @@ -23,62 +24,72 @@ def __init__(self,
mapping_service: MappingService = None,
prompt_audit_service: PromptAuditService = None,
mapping_entity_service: MappingEntityService = None,
sync_data_service: SyncDataService = None
sync_data_service: SyncDataService = None,
unit_test_service: UnitTestService = None
):
self.mapping_service = mapping_service or MappingService()
self.prompt_audit_service = prompt_audit_service or PromptAuditService()
self.mapping_entity_service = mapping_entity_service or MappingEntityService()
self.sync_data_service = sync_data_service or SyncDataService()
self.unit_test_service = unit_test_service or UnitTestService()

async def load_all(self):
async def load_mapping(mapping: Mapping) -> list[Prompt]:
try:
conn = await asyncpg.connect(SETTINGS.connections[mapping.connection_name])
except Exception as e:
logger.error('Error connection database', exc_info=e)
return []

mapping_name = f', {mapping.field_name}' if mapping.field_name else ''
order_name = f', {mapping.field_order}' if mapping.field_order else ''

row = await conn.fetch(f'SELECT id, {mapping.field} {mapping_name} {order_name} FROM {mapping.table}')

result = []
for i in row:
result.append(
Prompt(
mapping_id=mapping.id,
table=mapping.table,
field=mapping.field,
id=i['id'],
value=i[mapping.field],
name=i.get(mapping.field_name),
sort_value=i.get(mapping.field_order)
)
@staticmethod
async def load_mapping(mapping: Mapping) -> list[Prompt]:
try:
conn = await asyncpg.connect(SETTINGS.connections[mapping.connection_name])
except Exception as e:
logger.error('Error connection database', exc_info=e)
return []

mapping_name = f', {mapping.field_name}' if mapping.field_name else ''
order_name = f', {mapping.field_order}' if mapping.field_order else ''

row = await conn.fetch(f'SELECT id, {mapping.field} {mapping_name} {order_name} FROM {mapping.table}')

result = []
for i in row:
result.append(
Prompt(
mapping_id=mapping.id,
table=mapping.table,
field=mapping.field,
id=i['id'],
value=i[mapping.field],
name=i.get(mapping.field_name),
sort_value=i.get(mapping.field_order)
)
return result
)
return result

async def load_all(self):
mappings = await self.mapping_service.find_all()

res = await asyncio.gather(*[load_mapping(m) for m in mappings])

res = await asyncio.gather(*[self.load_mapping(m) for m in mappings])
out = []
for i in res:
out += i

return out

async def load(self, mapping: Mapping, name: str) -> str:
async def load_mapping_name(self, mapping: Mapping, name: str) -> str:
return await self.load(
mapping.connection_name,
mapping.field,
mapping.table,
mapping.field_name,
name,
)

@staticmethod
async def load(connection_name: str, field: str, table: str, field_name: str, name: str) -> str:
try:
conn = await asyncpg.connect(SETTINGS.connections[mapping.connection_name])
conn = await asyncpg.connect(SETTINGS.connections[connection_name])
except Exception as e:
logger.error('Error connection database', exc_info=e)
return ''

name = f'WHERE {mapping.field_name} = \'{name}\''
name = f'WHERE {field_name} = \'{name}\''

row = await conn.fetch(f'SELECT {mapping.field} FROM {mapping.table} {name}')
return row[0].get(mapping.field, '')
row = await conn.fetch(f'SELECT {field} FROM {table} {name}')
return row[0].get(field, '')

async def save(self, prompt: Prompt, user_data: UserData):
mapping = await self.mapping_service.find_by_table_field(prompt.table, prompt.field)
Expand All @@ -102,13 +113,17 @@ async def save(self, prompt: Prompt, user_data: UserData):
.filter(ViewParamsFilter(field=MappingEntity.connection_name, value=mapping.connection_name))
.filter(ViewParamsFilter(field=MappingEntity.table, value=prompt.table))
.filter(ViewParamsFilter(field=MappingEntity.field, value=prompt.field))
.filter(ViewParamsFilter(field=MappingEntity.name, value=prompt.name))
.filter(ViewParamsFilter(field=MappingEntity.entity, value='sync_data'))
.build()
)
mapping_entities = await self.mapping_entity_service.find_by_view_params(view_params)
if len(mapping_entities) <= 0:
return
sync_data = await self.sync_data_service.find_by_id(mapping_entities[0].entity_id)
sync_data.test_status = 'wait'
await self.sync_data_service.save(sync_data)
if sync_data is None:
return
unit_test = self.unit_test_service.find_by_sync_data_name(sync_data.id, prompt.name)
if unit_test is None:
return
unit_test.test_status = 'wait'
await self.unit_test_service.save(unit_test)
38 changes: 20 additions & 18 deletions server/promptadmin_server/api/service/prompt_sync_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,29 @@ async def sync_endpoint(self, endpoint: str, secret: str):
async with httpx.AsyncClient() as client:
try:
r = await client.get(endpoint, headers={'Prompt-Admin-Secret': secret})
result = r.json()
app = result['app']
prompt_service_info = result['prompt_service_info']
for i in prompt_service_info:
await self.sync(
app=app,
table=i['table'],
field=i['field'],
field_name=i['field_name'],
name=i['name'],
service_model_info=i['service_model_info'],
template_context_type=i['template_context_type'],
template_context_default=i['template_context_default'],
history_context_default=i['history_context_default'],
parsed_model_type=i['parsed_model_type'],
parsed_model_default=i['parsed_model_default'],
fail_parse_model_strategy=i['fail_parse_model_strategy']
)
await self.sync_json(r.json())
except Exception as e:
logger.error('Sync exception', exc_info=e)

async def sync_json(self, result: dict):
app = result['app']
prompt_service_info = result['prompt_service_info']
for i in prompt_service_info:
await self.sync(
app=app,
table=i['table'],
field=i['field'],
field_name=i['field_name'],
name=i['name'],
service_model_info=i['service_model_info'],
template_context_type=i['template_context_type'],
template_context_default=i['template_context_default'],
history_context_default=i['history_context_default'],
parsed_model_type=i['parsed_model_type'],
parsed_model_default=i['parsed_model_default'],
fail_parse_model_strategy=i['fail_parse_model_strategy']
)

async def sync(
self,
app: str,
Expand Down
Loading

0 comments on commit ff46118

Please sign in to comment.