Skip to content

Commit

Permalink
Revert "Test and resolve security vulnerability with get_file and upl…
Browse files Browse the repository at this point in the history
…oad_file (Chainlit#1441)"

This reverts commit e65f191.
  • Loading branch information
wenboown committed Oct 20, 2024
1 parent b00427b commit e69e9cc
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 313 deletions.
24 changes: 7 additions & 17 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
APIRouter,
Depends,
FastAPI,
File,
Form,
HTTPException,
Query,
Expand Down Expand Up @@ -840,9 +839,11 @@ async def delete_thread(

@router.post("/project/file")
async def upload_file(
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
session_id: str,
file: UploadFile,
current_user: Annotated[
Union[None, User, PersistedUser], Depends(get_current_user)
],
):
"""Upload a file to the session files directory."""

Expand All @@ -867,21 +868,17 @@ async def upload_file(

content = await file.read()

assert file.filename, "No filename for uploaded file"
assert file.content_type, "No content type for uploaded file"

file_response = await session.persist_file(
name=file.filename, content=content, mime=file.content_type
)

return JSONResponse(content=file_response)
return JSONResponse(file_response)


@router.get("/project/file/{file_id}")
async def get_file(
file_id: str,
session_id: str,
current_user: Annotated[Union[User, PersistedUser], Depends(get_current_user)],
session_id: Optional[str] = None,
):
"""Get a file from the session files directory."""

Expand All @@ -891,17 +888,10 @@ async def get_file(

if not session:
raise HTTPException(
status_code=401,
detail="Unauthorized",
status_code=404,
detail="Session not found",
)

if current_user:
if not session.user or session.user.identifier != current_user.identifier:
raise HTTPException(
status_code=401,
detail="You are not authorized to download files from this session",
)

if file_id in session.files:
file = session.files[file_id]
return FileResponse(file["path"], media_type=file["type"])
Expand Down
39 changes: 14 additions & 25 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import datetime
from contextlib import asynccontextmanager
from typing import Callable
from unittest.mock import AsyncMock, Mock

import pytest
Expand All @@ -21,30 +20,20 @@ def persisted_test_user():


@pytest.fixture
def mock_session_factory(persisted_test_user: PersistedUser) -> Callable[..., Mock]:
def create_mock_session(**kwargs) -> Mock:
mock = Mock(spec=WebsocketSession)
mock.user = kwargs.get("user", persisted_test_user)
mock.id = kwargs.get("id", "test_session_id")
mock.user_env = kwargs.get("user_env", {"test_env": "value"})
mock.chat_settings = kwargs.get("chat_settings", {})
mock.chat_profile = kwargs.get("chat_profile", None)
mock.http_referer = kwargs.get("http_referer", None)
mock.client_type = kwargs.get("client_type", "webapp")
mock.languages = kwargs.get("languages", ["en"])
mock.thread_id = kwargs.get("thread_id", "test_thread_id")
mock.emit = AsyncMock()
mock.has_first_interaction = kwargs.get("has_first_interaction", True)
mock.files = kwargs.get("files", {})

return mock

return create_mock_session


@pytest.fixture
def mock_session(mock_session_factory) -> Mock:
return mock_session_factory()
def mock_session():
mock = Mock(spec=WebsocketSession)
mock.id = "test_session_id"
mock.user_env = {"test_env": "value"}
mock.chat_settings = {}
mock.chat_profile = None
mock.http_referer = None
mock.client_type = "webapp"
mock.languages = ["en"]
mock.thread_id = "test_thread_id"
mock.emit = AsyncMock()
mock.has_first_interaction = True

return mock


@asynccontextmanager
Expand Down
Loading

0 comments on commit e69e9cc

Please sign in to comment.