diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index b4bf0dd989..7c4a824b68 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -41,6 +41,7 @@ APIRouter, Depends, FastAPI, + File, Form, HTTPException, Query, @@ -839,11 +840,9 @@ async def delete_thread( @router.post("/project/file") async def upload_file( + current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)], session_id: str, file: UploadFile, - current_user: Annotated[ - Union[None, User, PersistedUser], Depends(get_current_user) - ], ): """Upload a file to the session files directory.""" @@ -868,17 +867,21 @@ async def upload_file( content = await file.read() + assert file.filename, "No filename for uploaded file" + assert file.content_type, "No content type for uploaded file" + file_response = await session.persist_file( name=file.filename, content=content, mime=file.content_type ) - return JSONResponse(file_response) + return JSONResponse(content=file_response) @router.get("/project/file/{file_id}") async def get_file( file_id: str, - session_id: Optional[str] = None, + session_id: str, + current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)], ): """Get a file from the session files directory.""" @@ -888,10 +891,17 @@ async def get_file( if not session: raise HTTPException( - status_code=404, - detail="Session not found", + status_code=401, + detail="Unauthorized", ) + if current_user: + if not session.user or session.user.identifier != current_user.identifier: + raise HTTPException( + status_code=401, + detail="You are not authorized to download files from this session", + ) + if file_id in session.files: file = session.files[file_id] return FileResponse(file["path"], media_type=file["type"]) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 94ae4a596a..7cc4266371 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,5 +1,6 @@ import datetime from contextlib import asynccontextmanager +from typing import Callable from unittest.mock import AsyncMock, Mock import pytest @@ -20,20 +21,30 @@ def persisted_test_user(): @pytest.fixture -def mock_session(): - mock = Mock(spec=WebsocketSession) - mock.id = "test_session_id" - mock.user_env = {"test_env": "value"} - mock.chat_settings = {} - mock.chat_profile = None - mock.http_referer = None - mock.client_type = "webapp" - mock.languages = ["en"] - mock.thread_id = "test_thread_id" - mock.emit = AsyncMock() - mock.has_first_interaction = True - - return mock +def mock_session_factory(persisted_test_user: PersistedUser) -> Callable[..., Mock]: + def create_mock_session(**kwargs) -> Mock: + mock = Mock(spec=WebsocketSession) + mock.user = kwargs.get("user", persisted_test_user) + mock.id = kwargs.get("id", "test_session_id") + mock.user_env = kwargs.get("user_env", {"test_env": "value"}) + mock.chat_settings = kwargs.get("chat_settings", {}) + mock.chat_profile = kwargs.get("chat_profile", None) + mock.http_referer = kwargs.get("http_referer", None) + mock.client_type = kwargs.get("client_type", "webapp") + mock.languages = kwargs.get("languages", ["en"]) + mock.thread_id = kwargs.get("thread_id", "test_thread_id") + mock.emit = AsyncMock() + mock.has_first_interaction = kwargs.get("has_first_interaction", True) + mock.files = kwargs.get("files", {}) + + return mock + + return create_mock_session + + +@pytest.fixture +def mock_session(mock_session_factory) -> Mock: + return mock_session_factory() @asynccontextmanager diff --git a/backend/tests/test_server.py b/backend/tests/test_server.py index 6aa9c16d5a..36c65124d6 100644 --- a/backend/tests/test_server.py +++ b/backend/tests/test_server.py @@ -1,12 +1,19 @@ import os from pathlib import Path -from unittest.mock import Mock, create_autospec, mock_open +import pathlib +from typing import Callable +from unittest.mock import AsyncMock, Mock, create_autospec, mock_open +import datetime # Added import for datetime import pytest +import tempfile +from chainlit.session import WebsocketSession from chainlit.auth import get_current_user from chainlit.config import APP_ROOT, ChainlitConfig, load_config from chainlit.server import app from fastapi.testclient import TestClient +from chainlit.types import FileReference +from chainlit.user import PersistedUser # Added import for PersistedUser @pytest.fixture @@ -219,7 +226,7 @@ def test_get_avatar_non_existent_favicon( def test_avatar_path_traversal( - test_client: TestClient, monkeypatch: pytest.MonkeyPatch, tmp_path + test_client: TestClient, monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path ): """Test to prevent potential path traversal in avatar route on Windows.""" @@ -240,6 +247,268 @@ def test_avatar_path_traversal( assert response.status_code == 400 +@pytest.fixture +def mock_session_get_by_id_patched(mock_session: Mock, monkeypatch: pytest.MonkeyPatch): + test_session_id = "test_session_id" + + # Mock the WebsocketSession.get_by_id method to return the mock session + monkeypatch.setattr( + "chainlit.session.WebsocketSession.get_by_id", + lambda session_id: mock_session if session_id == test_session_id else None, + ) + + return mock_session + + +def test_get_file_success( + test_client: TestClient, + mock_session_get_by_id_patched: Mock, + tmp_path: pathlib.Path, + mock_get_current_user: Mock, +): + """ + Test successful retrieval of a file from a session. + """ + # Set current_user to match session.user + mock_get_current_user.return_value = mock_session_get_by_id_patched.user + + # Create test data + test_content = b"Test file content" + test_file_id = "test_file_id" + + # Create a temporary file with the test content + test_file = tmp_path / "test_file" + test_file.write_bytes(test_content) + + mock_session_get_by_id_patched.files = { + test_file_id: { + "id": test_file_id, + "path": test_file, + "name": "test.txt", + "type": "text/plain", + "size": len(test_content), + } + } + + # Make the GET request to retrieve the file + response = test_client.get( + f"/project/file/{test_file_id}?session_id={mock_session_get_by_id_patched.id}" + ) + + # Verify the response + assert response.status_code == 200 + assert response.content == test_content + assert response.headers["content-type"].startswith("text/plain") + + +def test_get_file_not_existent_file( + test_client: TestClient, + mock_session_get_by_id_patched: Mock, + mock_get_current_user: Mock, +): + """ + Test retrieval of a non-existing file from a session. + """ + # Set current_user to match session.user + mock_get_current_user.return_value = mock_session_get_by_id_patched.user + + # Make the GET request to retrieve the file + response = test_client.get("/project/file/test_file_id?session_id=test_session_id") + + # Verify the response + assert response.status_code == 404 + + +def test_get_file_non_existing_session( + test_client: TestClient, + tmp_path: pathlib.Path, + mock_session_get_by_id_patched: Mock, + mock_session: Mock, + monkeypatch: pytest.MonkeyPatch, +): + """ + Test that an unauthenticated user cannot retrieve a file uploaded by an authenticated user. + """ + + # Attempt to access the file without authentication by providing an invalid session_id + response = test_client.get( + f"/project/file/nonexistent?session_id=unauthenticated_session_id" + ) + + # Verify the response + assert response.status_code == 401 # Unauthorized + + +def test_upload_file_success( + test_client: TestClient, + test_config: ChainlitConfig, + mock_session_get_by_id_patched: Mock, +): + """Test successful file upload.""" + + # Prepare the files to upload + file_content = b"Sample file content" + files = { + "file": ("test_upload.txt", file_content, "text/plain"), + } + + # Mock the persist_file method to return a known value + expected_file_id = "mocked_file_id" + mock_session_get_by_id_patched.persist_file = AsyncMock( + return_value={ + "id": expected_file_id, + "name": "test_upload.txt", + "type": "text/plain", + "size": len(file_content), + } + ) + + # Make the POST request to upload the file + response = test_client.post( + "/project/file", + files=files, + params={"session_id": mock_session_get_by_id_patched.id}, + ) + + # Verify the response + assert response.status_code == 200 + response_data = response.json() + assert "id" in response_data + assert response_data["id"] == expected_file_id + assert response_data["name"] == "test_upload.txt" + assert response_data["type"] == "text/plain" + assert response_data["size"] == len(file_content) + + # Verify that persist_file was called with the correct arguments + mock_session_get_by_id_patched.persist_file.assert_called_once_with( + name="test_upload.txt", content=file_content, mime="text/plain" + ) + + +def test_file_access_by_different_user( + test_client: TestClient, + mock_session_get_by_id_patched: Mock, + persisted_test_user: PersistedUser, + tmp_path: pathlib.Path, + mock_session_factory: Callable[..., Mock], +): + """Test that a file uploaded by one user cannot be accessed by another user.""" + + # Prepare the files to upload + file_content = b"Sample file content" + files = { + "file": ("test_upload.txt", file_content, "text/plain"), + } + + # Mock the persist_file method to return a known value + expected_file_id = "mocked_file_id" + mock_session_get_by_id_patched.persist_file = AsyncMock( + return_value={ + "id": expected_file_id, + "name": "test_upload.txt", + "type": "text/plain", + "size": len(file_content), + } + ) + + # Make the POST request to upload the file + response = test_client.post( + "/project/file", + files=files, + params={"session_id": mock_session_get_by_id_patched.id}, + ) + + # Verify the response + assert response.status_code == 200 + + response_data = response.json() + assert "id" in response_data + file_id = response_data["id"] + + # Create a second session with a different user + second_session = mock_session_factory( + id="another_session_id", + user=PersistedUser( + id="another_user_id", + createdAt=datetime.datetime.now().isoformat(), + identifier="another_user_identifier", + ), + ) + + # Attempt to access the uploaded file using the second user's session + response = test_client.get( + f"/project/file/{file_id}?session_id={second_session.id}" + ) + + # Verify that the access attempt fails + assert response.status_code == 401 # Unauthorized + + +def test_upload_file_missing_file( + test_client: TestClient, + mock_session: Mock, +): + """Test file upload with missing file in the request.""" + + # Make the POST request without a file + response = test_client.post( + "/project/file", + data={"session_id": mock_session.id}, + ) + + # Verify the response + assert response.status_code == 422 # Unprocessable Entity + assert "detail" in response.json() + + +def test_upload_file_invalid_session( + test_client: TestClient, +): + """Test file upload with an invalid session.""" + + # Prepare the files to upload + file_content = b"Sample file content" + files = { + "file": ("test_upload.txt", file_content, "text/plain"), + } + + # Make the POST request with an invalid session_id + response = test_client.post( + "/project/file", + files=files, + data={"session_id": "invalid_session_id"}, + ) + + # Verify the response + assert response.status_code == 422 + + +def test_upload_file_unauthorized( + test_client: TestClient, + test_config: ChainlitConfig, + mock_session_get_by_id_patched: Mock, +): + """Test file upload without proper authorization.""" + + # Mock the upload_file_session to have no user + mock_session_get_by_id_patched.user = None + + # Prepare the files to upload + file_content = b"Sample file content" + files = { + "file": ("test_upload.txt", file_content, "text/plain"), + } + + # Make the POST request to upload the file + response = test_client.post( + "/project/file", + files=files, + data={"session_id": mock_session_get_by_id_patched.id}, + ) + + assert response.status_code == 422 + + def test_project_translations_file_path_traversal( test_client: TestClient, monkeypatch: pytest.MonkeyPatch ):