diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index 7c4a824b68..b4bf0dd989 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -41,7 +41,6 @@ APIRouter, Depends, FastAPI, - File, Form, HTTPException, Query, @@ -840,9 +839,11 @@ async def delete_thread( @router.post("/project/file") async def upload_file( - current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)], session_id: str, file: UploadFile, + current_user: Annotated[ + Union[None, User, PersistedUser], Depends(get_current_user) + ], ): """Upload a file to the session files directory.""" @@ -867,21 +868,17 @@ async def upload_file( content = await file.read() - assert file.filename, "No filename for uploaded file" - assert file.content_type, "No content type for uploaded file" - file_response = await session.persist_file( name=file.filename, content=content, mime=file.content_type ) - return JSONResponse(content=file_response) + return JSONResponse(file_response) @router.get("/project/file/{file_id}") async def get_file( file_id: str, - session_id: str, - current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)], + session_id: Optional[str] = None, ): """Get a file from the session files directory.""" @@ -891,17 +888,10 @@ async def get_file( if not session: raise HTTPException( - status_code=401, - detail="Unauthorized", + status_code=404, + detail="Session not found", ) - if current_user: - if not session.user or session.user.identifier != current_user.identifier: - raise HTTPException( - status_code=401, - detail="You are not authorized to download files from this session", - ) - if file_id in session.files: file = session.files[file_id] return FileResponse(file["path"], media_type=file["type"]) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 7cc4266371..94ae4a596a 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,6 +1,5 @@ import datetime from contextlib import asynccontextmanager -from typing import Callable from unittest.mock import AsyncMock, Mock import pytest @@ -21,30 +20,20 @@ def persisted_test_user(): @pytest.fixture -def mock_session_factory(persisted_test_user: PersistedUser) -> Callable[..., Mock]: - def create_mock_session(**kwargs) -> Mock: - mock = Mock(spec=WebsocketSession) - mock.user = kwargs.get("user", persisted_test_user) - mock.id = kwargs.get("id", "test_session_id") - mock.user_env = kwargs.get("user_env", {"test_env": "value"}) - mock.chat_settings = kwargs.get("chat_settings", {}) - mock.chat_profile = kwargs.get("chat_profile", None) - mock.http_referer = kwargs.get("http_referer", None) - mock.client_type = kwargs.get("client_type", "webapp") - mock.languages = kwargs.get("languages", ["en"]) - mock.thread_id = kwargs.get("thread_id", "test_thread_id") - mock.emit = AsyncMock() - mock.has_first_interaction = kwargs.get("has_first_interaction", True) - mock.files = kwargs.get("files", {}) - - return mock - - return create_mock_session - - -@pytest.fixture -def mock_session(mock_session_factory) -> Mock: - return mock_session_factory() +def mock_session(): + mock = Mock(spec=WebsocketSession) + mock.id = "test_session_id" + mock.user_env = {"test_env": "value"} + mock.chat_settings = {} + mock.chat_profile = None + mock.http_referer = None + mock.client_type = "webapp" + mock.languages = ["en"] + mock.thread_id = "test_thread_id" + mock.emit = AsyncMock() + mock.has_first_interaction = True + + return mock @asynccontextmanager diff --git a/backend/tests/test_server.py b/backend/tests/test_server.py index 36c65124d6..6aa9c16d5a 100644 --- a/backend/tests/test_server.py +++ b/backend/tests/test_server.py @@ -1,19 +1,12 @@ import os from pathlib import Path -import pathlib -from typing import Callable -from unittest.mock import AsyncMock, Mock, create_autospec, mock_open -import datetime # Added import for datetime +from unittest.mock import Mock, create_autospec, mock_open import pytest -import tempfile -from chainlit.session import WebsocketSession from chainlit.auth import get_current_user from chainlit.config import APP_ROOT, ChainlitConfig, load_config from chainlit.server import app from fastapi.testclient import TestClient -from chainlit.types import FileReference -from chainlit.user import PersistedUser # Added import for PersistedUser @pytest.fixture @@ -226,7 +219,7 @@ def test_get_avatar_non_existent_favicon( def test_avatar_path_traversal( - test_client: TestClient, monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path + test_client: TestClient, monkeypatch: pytest.MonkeyPatch, tmp_path ): """Test to prevent potential path traversal in avatar route on Windows.""" @@ -247,268 +240,6 @@ def test_avatar_path_traversal( assert response.status_code == 400 -@pytest.fixture -def mock_session_get_by_id_patched(mock_session: Mock, monkeypatch: pytest.MonkeyPatch): - test_session_id = "test_session_id" - - # Mock the WebsocketSession.get_by_id method to return the mock session - monkeypatch.setattr( - "chainlit.session.WebsocketSession.get_by_id", - lambda session_id: mock_session if session_id == test_session_id else None, - ) - - return mock_session - - -def test_get_file_success( - test_client: TestClient, - mock_session_get_by_id_patched: Mock, - tmp_path: pathlib.Path, - mock_get_current_user: Mock, -): - """ - Test successful retrieval of a file from a session. - """ - # Set current_user to match session.user - mock_get_current_user.return_value = mock_session_get_by_id_patched.user - - # Create test data - test_content = b"Test file content" - test_file_id = "test_file_id" - - # Create a temporary file with the test content - test_file = tmp_path / "test_file" - test_file.write_bytes(test_content) - - mock_session_get_by_id_patched.files = { - test_file_id: { - "id": test_file_id, - "path": test_file, - "name": "test.txt", - "type": "text/plain", - "size": len(test_content), - } - } - - # Make the GET request to retrieve the file - response = test_client.get( - f"/project/file/{test_file_id}?session_id={mock_session_get_by_id_patched.id}" - ) - - # Verify the response - assert response.status_code == 200 - assert response.content == test_content - assert response.headers["content-type"].startswith("text/plain") - - -def test_get_file_not_existent_file( - test_client: TestClient, - mock_session_get_by_id_patched: Mock, - mock_get_current_user: Mock, -): - """ - Test retrieval of a non-existing file from a session. - """ - # Set current_user to match session.user - mock_get_current_user.return_value = mock_session_get_by_id_patched.user - - # Make the GET request to retrieve the file - response = test_client.get("/project/file/test_file_id?session_id=test_session_id") - - # Verify the response - assert response.status_code == 404 - - -def test_get_file_non_existing_session( - test_client: TestClient, - tmp_path: pathlib.Path, - mock_session_get_by_id_patched: Mock, - mock_session: Mock, - monkeypatch: pytest.MonkeyPatch, -): - """ - Test that an unauthenticated user cannot retrieve a file uploaded by an authenticated user. - """ - - # Attempt to access the file without authentication by providing an invalid session_id - response = test_client.get( - f"/project/file/nonexistent?session_id=unauthenticated_session_id" - ) - - # Verify the response - assert response.status_code == 401 # Unauthorized - - -def test_upload_file_success( - test_client: TestClient, - test_config: ChainlitConfig, - mock_session_get_by_id_patched: Mock, -): - """Test successful file upload.""" - - # Prepare the files to upload - file_content = b"Sample file content" - files = { - "file": ("test_upload.txt", file_content, "text/plain"), - } - - # Mock the persist_file method to return a known value - expected_file_id = "mocked_file_id" - mock_session_get_by_id_patched.persist_file = AsyncMock( - return_value={ - "id": expected_file_id, - "name": "test_upload.txt", - "type": "text/plain", - "size": len(file_content), - } - ) - - # Make the POST request to upload the file - response = test_client.post( - "/project/file", - files=files, - params={"session_id": mock_session_get_by_id_patched.id}, - ) - - # Verify the response - assert response.status_code == 200 - response_data = response.json() - assert "id" in response_data - assert response_data["id"] == expected_file_id - assert response_data["name"] == "test_upload.txt" - assert response_data["type"] == "text/plain" - assert response_data["size"] == len(file_content) - - # Verify that persist_file was called with the correct arguments - mock_session_get_by_id_patched.persist_file.assert_called_once_with( - name="test_upload.txt", content=file_content, mime="text/plain" - ) - - -def test_file_access_by_different_user( - test_client: TestClient, - mock_session_get_by_id_patched: Mock, - persisted_test_user: PersistedUser, - tmp_path: pathlib.Path, - mock_session_factory: Callable[..., Mock], -): - """Test that a file uploaded by one user cannot be accessed by another user.""" - - # Prepare the files to upload - file_content = b"Sample file content" - files = { - "file": ("test_upload.txt", file_content, "text/plain"), - } - - # Mock the persist_file method to return a known value - expected_file_id = "mocked_file_id" - mock_session_get_by_id_patched.persist_file = AsyncMock( - return_value={ - "id": expected_file_id, - "name": "test_upload.txt", - "type": "text/plain", - "size": len(file_content), - } - ) - - # Make the POST request to upload the file - response = test_client.post( - "/project/file", - files=files, - params={"session_id": mock_session_get_by_id_patched.id}, - ) - - # Verify the response - assert response.status_code == 200 - - response_data = response.json() - assert "id" in response_data - file_id = response_data["id"] - - # Create a second session with a different user - second_session = mock_session_factory( - id="another_session_id", - user=PersistedUser( - id="another_user_id", - createdAt=datetime.datetime.now().isoformat(), - identifier="another_user_identifier", - ), - ) - - # Attempt to access the uploaded file using the second user's session - response = test_client.get( - f"/project/file/{file_id}?session_id={second_session.id}" - ) - - # Verify that the access attempt fails - assert response.status_code == 401 # Unauthorized - - -def test_upload_file_missing_file( - test_client: TestClient, - mock_session: Mock, -): - """Test file upload with missing file in the request.""" - - # Make the POST request without a file - response = test_client.post( - "/project/file", - data={"session_id": mock_session.id}, - ) - - # Verify the response - assert response.status_code == 422 # Unprocessable Entity - assert "detail" in response.json() - - -def test_upload_file_invalid_session( - test_client: TestClient, -): - """Test file upload with an invalid session.""" - - # Prepare the files to upload - file_content = b"Sample file content" - files = { - "file": ("test_upload.txt", file_content, "text/plain"), - } - - # Make the POST request with an invalid session_id - response = test_client.post( - "/project/file", - files=files, - data={"session_id": "invalid_session_id"}, - ) - - # Verify the response - assert response.status_code == 422 - - -def test_upload_file_unauthorized( - test_client: TestClient, - test_config: ChainlitConfig, - mock_session_get_by_id_patched: Mock, -): - """Test file upload without proper authorization.""" - - # Mock the upload_file_session to have no user - mock_session_get_by_id_patched.user = None - - # Prepare the files to upload - file_content = b"Sample file content" - files = { - "file": ("test_upload.txt", file_content, "text/plain"), - } - - # Make the POST request to upload the file - response = test_client.post( - "/project/file", - files=files, - data={"session_id": mock_session_get_by_id_patched.id}, - ) - - assert response.status_code == 422 - - def test_project_translations_file_path_traversal( test_client: TestClient, monkeypatch: pytest.MonkeyPatch ):