Skip to content

Commit

Permalink
Refine tests and resolve related security issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
dokterbob committed Oct 17, 2024
1 parent 07fc6b1 commit 7f79614
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 64 deletions.
18 changes: 9 additions & 9 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
APIRouter,
Depends,
FastAPI,
File,
Form,
HTTPException,
Query,
Expand Down Expand Up @@ -839,11 +840,9 @@ async def delete_thread(

@router.post("/project/file")
async def upload_file(
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
session_id: str,
file: UploadFile,
current_user: Annotated[
Union[None, User, PersistedUser], Depends(get_current_user)
],
):
"""Upload a file to the session files directory."""

Expand All @@ -868,20 +867,21 @@ async def upload_file(

content = await file.read()

assert file.filename, "No filename for uploaded file"
assert file.content_type, "No content type for uploaded file"

file_response = await session.persist_file(
name=file.filename, content=content, mime=file.content_type
)

return JSONResponse(file_response)
return JSONResponse(content=file_response)


@router.get("/project/file/{file_id}")
async def get_file(
file_id: str,
session_id: str,
current_user: Annotated[
Union[None, User, PersistedUser], Depends(get_current_user)
],
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
):
"""Get a file from the session files directory."""

Expand All @@ -891,8 +891,8 @@ async def get_file(

if not session:
raise HTTPException(
status_code=404,
detail="Session not found",
status_code=401,
detail="Unauthorized",
)

if current_user:
Expand Down
40 changes: 25 additions & 15 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
from contextlib import asynccontextmanager
from typing import Callable
from unittest.mock import AsyncMock, Mock

import pytest
Expand All @@ -20,21 +21,30 @@ def persisted_test_user():


@pytest.fixture
def mock_session(persisted_test_user: PersistedUser):
mock = Mock(spec=WebsocketSession)
mock.user = persisted_test_user
mock.id = "test_session_id"
mock.user_env = {"test_env": "value"}
mock.chat_settings = {}
mock.chat_profile = None
mock.http_referer = None
mock.client_type = "webapp"
mock.languages = ["en"]
mock.thread_id = "test_thread_id"
mock.emit = AsyncMock()
mock.has_first_interaction = True

return mock
def mock_session_factory(persisted_test_user: PersistedUser) -> Callable[..., Mock]:
def create_mock_session(**kwargs) -> Mock:
mock = Mock(spec=WebsocketSession)
mock.user = kwargs.get("user", persisted_test_user)
mock.id = kwargs.get("id", "test_session_id")
mock.user_env = kwargs.get("user_env", {"test_env": "value"})
mock.chat_settings = kwargs.get("chat_settings", {})
mock.chat_profile = kwargs.get("chat_profile", None)
mock.http_referer = kwargs.get("http_referer", None)
mock.client_type = kwargs.get("client_type", "webapp")
mock.languages = kwargs.get("languages", ["en"])
mock.thread_id = kwargs.get("thread_id", "test_thread_id")
mock.emit = AsyncMock()
mock.has_first_interaction = kwargs.get("has_first_interaction", True)
mock.files = kwargs.get("files", {})

return mock

return create_mock_session


@pytest.fixture
def mock_session(mock_session_factory) -> Mock:
return mock_session_factory()


@asynccontextmanager
Expand Down
92 changes: 52 additions & 40 deletions backend/tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
from pathlib import Path
import pathlib
from unittest.mock import Mock, create_autospec, mock_open
from typing import Callable
from unittest.mock import AsyncMock, Mock, create_autospec, mock_open
import datetime # Added import for datetime

import pytest
Expand All @@ -11,6 +12,7 @@
from chainlit.config import APP_ROOT, ChainlitConfig, load_config
from chainlit.server import app
from fastapi.testclient import TestClient
from chainlit.types import FileReference
from chainlit.user import PersistedUser # Added import for PersistedUser


Expand Down Expand Up @@ -296,26 +298,22 @@ def test_get_file_success(
assert response.headers["content-type"].startswith("text/plain")


def test_get_file_not_existing(
test_client: TestClient, monkeypatch: pytest.MonkeyPatch
def test_get_file_not_existent_file(
test_client: TestClient,
mock_session_get_by_id_patched: Mock,
):
"""
Test retrieval of a non-existing file from a session.
"""

# Mock the WebsocketSession.get_by_id method to return None
monkeypatch.setattr(
"chainlit.session.WebsocketSession.get_by_id", lambda session_id: None
)

# Make the GET request to retrieve the file
response = test_client.get("/project/file/test_file_id?session_id=test_session_id")

# Verify the response
assert response.status_code == 404


def test_get_file_unauthorized(
def test_get_file_non_existing_session(
test_client: TestClient,
tmp_path: pathlib.Path,
mock_session_get_by_id_patched: Mock,
Expand Down Expand Up @@ -348,32 +346,45 @@ def test_upload_file_success(
"file": ("test_upload.txt", file_content, "text/plain"),
}

# Mock the persist_file method to return a known value
expected_file_id = "mocked_file_id"
mock_session_get_by_id_patched.persist_file = AsyncMock(
return_value={
"id": expected_file_id,
"name": "test_upload.txt",
"type": "text/plain",
"size": len(file_content),
}
)

# Make the POST request to upload the file
response = test_client.post(
"/project/file",
files=files,
data={"session_id": mock_session_get_by_id_patched.id},
params={"session_id": mock_session_get_by_id_patched.id},
)

# Verify the response
assert response.status_code == 200
response_data = response.json()
assert "id" in response_data
file_id = response_data["id"]

# Verify that the file is stored in the session
assert file_id in mock_session_get_by_id_patched.files
uploaded_file = mock_session_get_by_id_patched.files[file_id]
assert uploaded_file["name"] == "test_upload.txt"
assert uploaded_file["type"] == "text/plain"
assert uploaded_file["size"] == len(file_content)
assert response_data["id"] == expected_file_id
assert response_data["name"] == "test_upload.txt"
assert response_data["type"] == "text/plain"
assert response_data["size"] == len(file_content)

# Verify that persist_file was called with the correct arguments
mock_session_get_by_id_patched.persist_file.assert_called_once_with(
name="test_upload.txt", content=file_content, mime="text/plain"
)


def test_file_access_by_different_user(
test_client: TestClient,
mock_session_get_by_id_patched: Mock,
persisted_test_user: PersistedUser,
tmp_path: pathlib.Path,
mock_session_factory: Callable[..., Mock],
):
"""Test that a file uploaded by one user cannot be accessed by another user."""

Expand All @@ -383,30 +394,44 @@ def test_file_access_by_different_user(
"file": ("test_upload.txt", file_content, "text/plain"),
}

# Mock the persist_file method to return a known value
expected_file_id = "mocked_file_id"
mock_session_get_by_id_patched.persist_file = AsyncMock(
return_value={
"id": expected_file_id,
"name": "test_upload.txt",
"type": "text/plain",
"size": len(file_content),
}
)

# Make the POST request to upload the file
response = test_client.post(
"/project/file",
files=files,
data={"session_id": mock_session_get_by_id_patched.id},
params={"session_id": mock_session_get_by_id_patched.id},
)

# Verify the response
assert response.status_code == 200

response_data = response.json()
assert "id" in response_data
file_id = response_data["id"]

# Create a second mock session with a different user
mock_session_get_by_id_patched.user = PersistedUser(
id="another_user_id",
createdAt=datetime.datetime.now().isoformat(),
identifier="another_user_identifier",
# Create a second session with a different user
second_session = mock_session_factory(
id="another_session_id",
user=PersistedUser(
id="another_user_id",
createdAt=datetime.datetime.now().isoformat(),
identifier="another_user_identifier",
),
)
mock_session_get_by_id_patched.id = "another_session_id"

# Attempt to access the uploaded file using the second user's session
response = test_client.get(
f"/project/file/{file_id}?session_id={mock_session_get_by_id_patched.id}"
f"/project/file/{file_id}?session_id={second_session.id}"
)

# Verify that the access attempt fails
Expand Down Expand Up @@ -475,20 +500,7 @@ def test_upload_file_unauthorized(
data={"session_id": mock_session_get_by_id_patched.id},
)

# Assuming that unauthorized upload is allowed if user is not required
# If authorization is required, adjust the expected status code and response accordingly

assert response.status_code == 200 # Change this if authorization is enforced
response_data = response.json()
assert "id" in response_data
file_id = response_data["id"]

# Verify that the file is stored in the session
assert file_id in mock_session_get_by_id_patched.files
uploaded_file = mock_session_get_by_id_patched.files[file_id]
assert uploaded_file["name"] == "test_upload.txt"
assert uploaded_file["type"] == "text/plain"
assert uploaded_file["size"] == len(file_content)
assert response.status_code == 422


def test_project_translations_file_path_traversal(
Expand Down

0 comments on commit 7f79614

Please sign in to comment.