From 7f79614724c1fa251fbf317911a4486543d8ea60 Mon Sep 17 00:00:00 2001 From: Mathijs de Bruin Date: Wed, 16 Oct 2024 15:00:29 +0100 Subject: [PATCH] Refine tests and resolve related security issues. --- backend/chainlit/server.py | 18 +++---- backend/tests/conftest.py | 40 ++++++++++------ backend/tests/test_server.py | 92 ++++++++++++++++++++---------------- 3 files changed, 86 insertions(+), 64 deletions(-) diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index 11f9178efc..7c4a824b68 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -41,6 +41,7 @@ APIRouter, Depends, FastAPI, + File, Form, HTTPException, Query, @@ -839,11 +840,9 @@ async def delete_thread( @router.post("/project/file") async def upload_file( + current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)], session_id: str, file: UploadFile, - current_user: Annotated[ - Union[None, User, PersistedUser], Depends(get_current_user) - ], ): """Upload a file to the session files directory.""" @@ -868,20 +867,21 @@ async def upload_file( content = await file.read() + assert file.filename, "No filename for uploaded file" + assert file.content_type, "No content type for uploaded file" + file_response = await session.persist_file( name=file.filename, content=content, mime=file.content_type ) - return JSONResponse(file_response) + return JSONResponse(content=file_response) @router.get("/project/file/{file_id}") async def get_file( file_id: str, session_id: str, - current_user: Annotated[ - Union[None, User, PersistedUser], Depends(get_current_user) - ], + current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)], ): """Get a file from the session files directory.""" @@ -891,8 +891,8 @@ async def get_file( if not session: raise HTTPException( - status_code=404, - detail="Session not found", + status_code=401, + detail="Unauthorized", ) if current_user: diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 96e1e56957..7cc4266371 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,5 +1,6 @@ import datetime from contextlib import asynccontextmanager +from typing import Callable from unittest.mock import AsyncMock, Mock import pytest @@ -20,21 +21,30 @@ def persisted_test_user(): @pytest.fixture -def mock_session(persisted_test_user: PersistedUser): - mock = Mock(spec=WebsocketSession) - mock.user = persisted_test_user - mock.id = "test_session_id" - mock.user_env = {"test_env": "value"} - mock.chat_settings = {} - mock.chat_profile = None - mock.http_referer = None - mock.client_type = "webapp" - mock.languages = ["en"] - mock.thread_id = "test_thread_id" - mock.emit = AsyncMock() - mock.has_first_interaction = True - - return mock +def mock_session_factory(persisted_test_user: PersistedUser) -> Callable[..., Mock]: + def create_mock_session(**kwargs) -> Mock: + mock = Mock(spec=WebsocketSession) + mock.user = kwargs.get("user", persisted_test_user) + mock.id = kwargs.get("id", "test_session_id") + mock.user_env = kwargs.get("user_env", {"test_env": "value"}) + mock.chat_settings = kwargs.get("chat_settings", {}) + mock.chat_profile = kwargs.get("chat_profile", None) + mock.http_referer = kwargs.get("http_referer", None) + mock.client_type = kwargs.get("client_type", "webapp") + mock.languages = kwargs.get("languages", ["en"]) + mock.thread_id = kwargs.get("thread_id", "test_thread_id") + mock.emit = AsyncMock() + mock.has_first_interaction = kwargs.get("has_first_interaction", True) + mock.files = kwargs.get("files", {}) + + return mock + + return create_mock_session + + +@pytest.fixture +def mock_session(mock_session_factory) -> Mock: + return mock_session_factory() @asynccontextmanager diff --git a/backend/tests/test_server.py b/backend/tests/test_server.py index 103255917a..23dc4decdc 100644 --- a/backend/tests/test_server.py +++ b/backend/tests/test_server.py @@ -1,7 +1,8 @@ import os from pathlib import Path import pathlib -from unittest.mock import Mock, create_autospec, mock_open +from typing import Callable +from unittest.mock import AsyncMock, Mock, create_autospec, mock_open import datetime # Added import for datetime import pytest @@ -11,6 +12,7 @@ from chainlit.config import APP_ROOT, ChainlitConfig, load_config from chainlit.server import app from fastapi.testclient import TestClient +from chainlit.types import FileReference from chainlit.user import PersistedUser # Added import for PersistedUser @@ -296,18 +298,14 @@ def test_get_file_success( assert response.headers["content-type"].startswith("text/plain") -def test_get_file_not_existing( - test_client: TestClient, monkeypatch: pytest.MonkeyPatch +def test_get_file_not_existent_file( + test_client: TestClient, + mock_session_get_by_id_patched: Mock, ): """ Test retrieval of a non-existing file from a session. """ - # Mock the WebsocketSession.get_by_id method to return None - monkeypatch.setattr( - "chainlit.session.WebsocketSession.get_by_id", lambda session_id: None - ) - # Make the GET request to retrieve the file response = test_client.get("/project/file/test_file_id?session_id=test_session_id") @@ -315,7 +313,7 @@ def test_get_file_not_existing( assert response.status_code == 404 -def test_get_file_unauthorized( +def test_get_file_non_existing_session( test_client: TestClient, tmp_path: pathlib.Path, mock_session_get_by_id_patched: Mock, @@ -348,25 +346,37 @@ def test_upload_file_success( "file": ("test_upload.txt", file_content, "text/plain"), } + # Mock the persist_file method to return a known value + expected_file_id = "mocked_file_id" + mock_session_get_by_id_patched.persist_file = AsyncMock( + return_value={ + "id": expected_file_id, + "name": "test_upload.txt", + "type": "text/plain", + "size": len(file_content), + } + ) + # Make the POST request to upload the file response = test_client.post( "/project/file", files=files, - data={"session_id": mock_session_get_by_id_patched.id}, + params={"session_id": mock_session_get_by_id_patched.id}, ) # Verify the response assert response.status_code == 200 response_data = response.json() assert "id" in response_data - file_id = response_data["id"] - - # Verify that the file is stored in the session - assert file_id in mock_session_get_by_id_patched.files - uploaded_file = mock_session_get_by_id_patched.files[file_id] - assert uploaded_file["name"] == "test_upload.txt" - assert uploaded_file["type"] == "text/plain" - assert uploaded_file["size"] == len(file_content) + assert response_data["id"] == expected_file_id + assert response_data["name"] == "test_upload.txt" + assert response_data["type"] == "text/plain" + assert response_data["size"] == len(file_content) + + # Verify that persist_file was called with the correct arguments + mock_session_get_by_id_patched.persist_file.assert_called_once_with( + name="test_upload.txt", content=file_content, mime="text/plain" + ) def test_file_access_by_different_user( @@ -374,6 +384,7 @@ def test_file_access_by_different_user( mock_session_get_by_id_patched: Mock, persisted_test_user: PersistedUser, tmp_path: pathlib.Path, + mock_session_factory: Callable[..., Mock], ): """Test that a file uploaded by one user cannot be accessed by another user.""" @@ -383,30 +394,44 @@ def test_file_access_by_different_user( "file": ("test_upload.txt", file_content, "text/plain"), } + # Mock the persist_file method to return a known value + expected_file_id = "mocked_file_id" + mock_session_get_by_id_patched.persist_file = AsyncMock( + return_value={ + "id": expected_file_id, + "name": "test_upload.txt", + "type": "text/plain", + "size": len(file_content), + } + ) + # Make the POST request to upload the file response = test_client.post( "/project/file", files=files, - data={"session_id": mock_session_get_by_id_patched.id}, + params={"session_id": mock_session_get_by_id_patched.id}, ) # Verify the response assert response.status_code == 200 + response_data = response.json() assert "id" in response_data file_id = response_data["id"] - # Create a second mock session with a different user - mock_session_get_by_id_patched.user = PersistedUser( - id="another_user_id", - createdAt=datetime.datetime.now().isoformat(), - identifier="another_user_identifier", + # Create a second session with a different user + second_session = mock_session_factory( + id="another_session_id", + user=PersistedUser( + id="another_user_id", + createdAt=datetime.datetime.now().isoformat(), + identifier="another_user_identifier", + ), ) - mock_session_get_by_id_patched.id = "another_session_id" # Attempt to access the uploaded file using the second user's session response = test_client.get( - f"/project/file/{file_id}?session_id={mock_session_get_by_id_patched.id}" + f"/project/file/{file_id}?session_id={second_session.id}" ) # Verify that the access attempt fails @@ -475,20 +500,7 @@ def test_upload_file_unauthorized( data={"session_id": mock_session_get_by_id_patched.id}, ) - # Assuming that unauthorized upload is allowed if user is not required - # If authorization is required, adjust the expected status code and response accordingly - - assert response.status_code == 200 # Change this if authorization is enforced - response_data = response.json() - assert "id" in response_data - file_id = response_data["id"] - - # Verify that the file is stored in the session - assert file_id in mock_session_get_by_id_patched.files - uploaded_file = mock_session_get_by_id_patched.files[file_id] - assert uploaded_file["name"] == "test_upload.txt" - assert uploaded_file["type"] == "text/plain" - assert uploaded_file["size"] == len(file_content) + assert response.status_code == 422 def test_project_translations_file_path_traversal(