-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'dev' into feat/integrate-llm
- Loading branch information
Showing
27 changed files
with
591 additions
and
768 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,149 +1,157 @@ | ||
import asyncio | ||
import os | ||
|
||
import pytest | ||
from fastapi import BackgroundTasks, HTTPException | ||
|
||
from chatsky_ui.api.api_v1.endpoints.bot import ( | ||
_check_process_status, | ||
_stop_process, | ||
check_build_processes, | ||
check_run_processes, | ||
get_build_logs, | ||
get_run_logs, | ||
start_build, | ||
start_run, | ||
) | ||
from dotenv import load_dotenv | ||
from fastapi import status | ||
from httpx import AsyncClient | ||
from httpx._transports.asgi import ASGITransport | ||
|
||
from chatsky_ui.api.deps import get_build_manager, get_run_manager | ||
from chatsky_ui.core.logger_config import get_logger | ||
from chatsky_ui.main import app | ||
from chatsky_ui.schemas.process_status import Status | ||
from chatsky_ui.services.process_manager import RunManager | ||
|
||
PROCESS_ID = 0 | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_stop_process_success(mocker): | ||
process_manager = mocker.MagicMock() | ||
process_manager.stop = mocker.AsyncMock() | ||
|
||
# Call the function under test | ||
await _stop_process(PROCESS_ID, process_manager) | ||
|
||
# Assert the stop method was called once with the correct id | ||
process_manager.stop.assert_awaited_once_with(PROCESS_ID) | ||
|
||
|
||
# TODO: take into consideration the errors when process type is build | ||
@pytest.mark.parametrize("error_type", [RuntimeError, ProcessLookupError]) | ||
@pytest.mark.asyncio | ||
async def test_stop_process_error(mocker, error_type): | ||
mock_stop = mocker.AsyncMock(side_effect=error_type) | ||
mocker.patch.object(RunManager, "stop", mock_stop) | ||
|
||
process_type = "run" | ||
|
||
with pytest.raises(HTTPException) as exc_info: | ||
await _stop_process(PROCESS_ID, RunManager(), process_type) | ||
|
||
# Assert the stop method was called once with the correct id | ||
assert exc_info.value.status_code == 404 | ||
mock_stop.assert_awaited_once_with(PROCESS_ID) | ||
|
||
|
||
# TODO: check the errors | ||
@pytest.mark.asyncio | ||
async def test_check_process_status(mocker): | ||
mocked_process_manager = mocker.MagicMock() | ||
mocker.patch.object(mocked_process_manager, "processes", {PROCESS_ID: mocker.MagicMock()}) | ||
mocker.patch.object(mocked_process_manager, "get_status", mocker.AsyncMock(return_value=Status.ALIVE)) | ||
|
||
response = await _check_process_status(PROCESS_ID, mocked_process_manager) | ||
load_dotenv() | ||
|
||
assert response == {"status": "alive"} | ||
mocked_process_manager.get_status.assert_awaited_once_with(0) | ||
BUILD_COMPLETION_TIMEOUT = float(os.getenv("BUILD_COMPLETION_TIMEOUT", 10)) | ||
RUN_RUNNING_TIMEOUT = float(os.getenv("RUN_RUNNING_TIMEOUT", 5)) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_start_build(mocker, dummy_build_id): | ||
build_manager = mocker.MagicMock() | ||
preset = mocker.MagicMock() | ||
|
||
start = mocker.AsyncMock(return_value=dummy_build_id) | ||
mocker.patch.multiple(build_manager, start=start, check_status=mocker.AsyncMock()) | ||
mocker.patch.multiple(preset, wait_time=0, end_status="loop") | ||
|
||
response = await start_build(preset, background_tasks=BackgroundTasks(), build_manager=build_manager) | ||
start.assert_awaited_once_with(preset) | ||
assert response == {"status": "ok", "build_id": dummy_build_id} | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_check_build_processes_some_info(mocker, pagination, dummy_build_id): | ||
build_manager = mocker.AsyncMock() | ||
run_manager = mocker.AsyncMock() | ||
|
||
await check_build_processes(dummy_build_id, build_manager, run_manager, pagination) | ||
|
||
build_manager.get_build_info.assert_awaited_once_with(dummy_build_id, run_manager) | ||
@pytest.mark.parametrize( | ||
"preset_status, expected_status", | ||
[("failure", Status.FAILED), ("loop", Status.RUNNING), ("success", Status.COMPLETED)], | ||
) | ||
async def test_start_build(mocker, override_dependency, preset_status, expected_status, start_build_endpoint): | ||
logger = get_logger(__name__) | ||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as async_client: | ||
async with override_dependency(get_build_manager) as process_manager: | ||
process_manager.save_built_script_to_git = mocker.MagicMock() | ||
process_manager.is_changed_graph = mocker.MagicMock(return_value=True) | ||
|
||
response = await async_client.post( | ||
start_build_endpoint, | ||
json={"wait_time": 0.1, "end_status": preset_status}, | ||
) | ||
|
||
assert response.json().get("status") == "ok", "Start process response status is not 'ok'" | ||
|
||
process_id = process_manager.last_id | ||
process = process_manager.processes[process_id] | ||
|
||
try: | ||
await asyncio.wait_for(process.process.wait(), timeout=BUILD_COMPLETION_TIMEOUT) | ||
except asyncio.exceptions.TimeoutError as exc: | ||
if preset_status == "loop": | ||
logger.debug("Loop process timed out. Expected behavior.") | ||
assert True | ||
await process.stop() | ||
return | ||
else: | ||
raise Exception( | ||
f"Process with expected end status '{preset_status}' timed out with " | ||
f"return code '{process.process.returncode}'." | ||
) from exc | ||
|
||
current_status = await process_manager.get_status(process_id) | ||
assert ( | ||
current_status == expected_status | ||
), f"Current process status '{current_status}' did not match the expected '{expected_status}'" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_check_build_processes_all_info(mocker, pagination): | ||
build_id = None | ||
build_manager = mocker.AsyncMock() | ||
run_manager = mocker.AsyncMock() | ||
async def test_stop_build(override_dependency, start_build_endpoint, stop_build_endpoint): | ||
logger = get_logger(__name__) | ||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as async_client: | ||
async with override_dependency(get_build_manager) as manager: | ||
response = await async_client.post( | ||
start_build_endpoint, | ||
json={"wait_time": 0.1, "end_status": "success"}, | ||
) | ||
|
||
await check_build_processes(build_id, build_manager, run_manager, pagination) | ||
assert response.status_code == 201 | ||
logger.debug("Processes: %s", manager.processes) | ||
|
||
build_manager.get_full_info_with_runs_info.assert_awaited_once_with( | ||
run_manager, offset=pagination.offset(), limit=pagination.limit | ||
) | ||
last_id = manager.get_last_id() | ||
logger.debug("Last id: %s, type: %s", last_id, type(last_id)) | ||
logger.debug("Process status %s", await manager.get_status(last_id)) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_get_build_logs(mocker, pagination, dummy_build_id): | ||
build_manager = mocker.AsyncMock() | ||
|
||
await get_build_logs(dummy_build_id, build_manager, pagination) | ||
|
||
build_manager.fetch_build_logs.assert_awaited_once_with(dummy_build_id, pagination.offset(), pagination.limit) | ||
stop_response = await async_client.get(stop_build_endpoint(str(last_id))) | ||
assert stop_response.status_code == 200 | ||
assert stop_response.json() == {"status": "ok"} | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_start_run(mocker, dummy_build_id, dummy_run_id): | ||
run_manager = mocker.MagicMock() | ||
preset = mocker.MagicMock() | ||
|
||
start = mocker.AsyncMock(return_value=dummy_run_id) | ||
mocker.patch.multiple(run_manager, start=start, check_status=mocker.AsyncMock()) | ||
mocker.patch.multiple(preset, wait_time=0, end_status="loop") | ||
|
||
response = await start_run( | ||
build_id=dummy_build_id, preset=preset, background_tasks=BackgroundTasks(), run_manager=run_manager | ||
) | ||
start.assert_awaited_once_with(dummy_build_id, preset) | ||
assert response == {"status": "ok", "run_id": dummy_run_id} | ||
async def test_stop_build_bad_id( | ||
override_dependency, start_run_endpoint, set_working_directory, dummy_build_id, stop_build_endpoint, inexistent_id | ||
): | ||
logger = get_logger(__name__) | ||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as async_client: | ||
async with override_dependency(get_run_manager) as manager: | ||
response = await async_client.post( | ||
start_run_endpoint(dummy_build_id), | ||
json={"wait_time": 0.1, "end_status": "success"}, | ||
) | ||
|
||
assert response.status_code == 201 | ||
logger.debug("Processes: %s", manager.processes) | ||
|
||
stop_response = await async_client.get(stop_build_endpoint(inexistent_id)) | ||
assert stop_response.status_code == status.HTTP_404_NOT_FOUND | ||
assert stop_response.json() == { | ||
"detail": "Process not found. It may have already exited or not started yet. Please check logs." | ||
} | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_check_run_processes_some_info(mocker, pagination, dummy_run_id): | ||
run_manager = mocker.AsyncMock() | ||
|
||
await check_run_processes(dummy_run_id, run_manager, pagination) | ||
|
||
run_manager.get_run_info.assert_awaited_once_with(dummy_run_id) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_check_run_processes_all_info(mocker, pagination): | ||
run_id = None | ||
run_manager = mocker.AsyncMock() | ||
|
||
await check_run_processes(run_id, run_manager, pagination) | ||
|
||
run_manager.get_full_info.assert_awaited_once_with(offset=pagination.offset(), limit=pagination.limit) | ||
@pytest.mark.parametrize( | ||
"preset_status, expected_status", [("failure", Status.FAILED), ("loop", Status.RUNNING), ("success", Status.ALIVE)] | ||
) | ||
async def test_start_run(override_dependency, preset_status, expected_status, start_run_endpoint, dummy_build_id): | ||
logger = get_logger(__name__) | ||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as async_client: | ||
async with override_dependency(get_run_manager) as process_manager: | ||
response = await async_client.post( | ||
start_run_endpoint(dummy_build_id), | ||
json={"wait_time": 0.1, "end_status": preset_status}, | ||
) | ||
|
||
assert response.json().get("status") == "ok", "Start process response status is not 'ok'" | ||
|
||
process_id = process_manager.last_id | ||
process = process_manager.processes[process_id] | ||
|
||
try: | ||
await asyncio.wait_for(process.process.wait(), timeout=RUN_RUNNING_TIMEOUT) | ||
except asyncio.exceptions.TimeoutError as exc: | ||
if preset_status == "success": | ||
logger.debug("Success run process timed out. Expected behavior.") | ||
|
||
current_status = await process_manager.get_status(process_id) | ||
assert ( | ||
current_status == expected_status | ||
), f"Current process status '{current_status}' did not match the expected '{expected_status}'" | ||
await process.stop() | ||
elif preset_status == "loop": | ||
logger.debug("Loop process timed out. Expected behavior.") | ||
assert True | ||
await process.stop() | ||
else: | ||
raise Exception( | ||
f"Process with expected end status '{preset_status}' timed out with " | ||
f"return code '{process.process.returncode}'." | ||
) from exc | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_get_run_logs(mocker, pagination, dummy_run_id): | ||
run_manager = mocker.AsyncMock() | ||
async def test_get_run_logs(run_process, dummy_run_id): | ||
process = await run_process("echo Hello") | ||
process.logger.info("test log") | ||
await process.update_db_info() | ||
|
||
await get_run_logs(dummy_run_id, run_manager, pagination) | ||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as async_client: | ||
get_response = await async_client.get(f"/api/v1/bot/runs/logs/{dummy_run_id}") | ||
|
||
run_manager.fetch_run_logs.assert_awaited_once_with(dummy_run_id, pagination.offset(), pagination.limit) | ||
assert get_response.status_code == 200 | ||
assert any(["test log" in log for log in get_response.json()]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import pytest | ||
from httpx import AsyncClient | ||
from httpx._transports.asgi import ASGITransport | ||
|
||
from chatsky_ui import __version__ | ||
from chatsky_ui.main import app | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_get_version(): | ||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as async_client: | ||
response = await async_client.get("/api/v1/config/version") | ||
assert response.status_code == 200 | ||
assert response.json() == __version__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,20 @@ | ||
# create test flows function here | ||
import pytest | ||
from omegaconf import OmegaConf | ||
from httpx import AsyncClient | ||
from httpx._transports.asgi import ASGITransport | ||
|
||
from chatsky_ui.api.api_v1.endpoints.flows import flows_get, flows_post | ||
from chatsky_ui.main import app | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_flows_get(mocker): | ||
mocker.patch("chatsky_ui.api.api_v1.endpoints.flows.read_conf", return_value=OmegaConf.create({"foo": "bar"})) | ||
response = await flows_get() | ||
assert response["status"] == "ok" | ||
assert response["data"] == {"foo": "bar"} | ||
async def test_flows(dummy_build_id): # noqa: F811 | ||
async with AsyncClient( | ||
transport=ASGITransport(app=app), base_url="http://test", follow_redirects=True | ||
) as async_client: | ||
get_response = await async_client.get("/api/v1/flows", params={"build_id": dummy_build_id}) | ||
print("gettttt", get_response) | ||
assert get_response.status_code == 200 | ||
data = get_response.json()["data"] | ||
assert "flows" in data | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_flows_post(mocker): | ||
mocker.patch("chatsky_ui.api.api_v1.endpoints.flows.write_conf", return_value={}) | ||
response = await flows_post({"foo": "bar"}) | ||
assert response["status"] == "ok" | ||
response = await async_client.post("/api/v1/flows", json=data) | ||
assert response.status_code == 200 |
Oops, something went wrong.