Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Added support for WebSocket in TestClient
Browse files Browse the repository at this point in the history
Randomneo committed Mar 15, 2024
1 parent 026897f commit c363263
Showing 4 changed files with 144 additions and 45 deletions.
21 changes: 21 additions & 0 deletions blacksheep/testing/client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
from typing import Optional

from blacksheep.contents import Content
from blacksheep.server.application import Application
from blacksheep.server.responses import Response
from blacksheep.testing.simulator import AbstractTestSimulator, TestSimulator
from blacksheep.testing.websocket import TestWebSocket

from .helpers import CookiesType, HeadersType, QueryType

@@ -157,3 +159,22 @@ async def trace(
content=None,
cookies=cookies,
)

def websocket_connect(
self,
path: str,
headers: HeadersType = None,
query: QueryType = None,
cookies: CookiesType = None,
) -> TestWebSocket:
return self._test_simulator.websocket_connect(
path=path,
headers=headers,
query=query,
content=None,
cookies=cookies,
)

async def websocket_all_closed(self):
await asyncio.gather(*self._test_simulator.websocket_tasks)
self._test_simulator.websocket_tasks = []
38 changes: 38 additions & 0 deletions blacksheep/testing/simulator.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import abc
import asyncio
from typing import Dict, Optional

from blacksheep.contents import Content
from blacksheep.messages import Request
from blacksheep.server.application import Application
from blacksheep.server.responses import Response
from blacksheep.testing.helpers import get_example_scope
from blacksheep.testing.websocket import TestWebSocket

from .helpers import CookiesType, HeadersType, QueryType

@@ -47,6 +49,17 @@ async def send_request(
Then you can define an own TestClient, with the custom logic.
"""

@abc.abstractmethod
async def websocket_connect(
self,
path,
headers: HeadersType = None,
query: QueryType = None,
content: Optional[Content] = None,
cookies: CookiesType = None,
) -> TestWebSocket:
"""Entrypoint for WebSocket"""


class TestSimulator(AbstractTestSimulator):
"""Base Test simulator class
@@ -57,6 +70,7 @@ class TestSimulator(AbstractTestSimulator):

def __init__(self, app: Application):
self.app = app
self.websocket_tasks = []
self._is_started_app()

async def send_request(
@@ -90,6 +104,30 @@ async def send_request(

return response

def websocket_connect(
self,
path: str,
headers: HeadersType = None,
query: QueryType = None,
content: Optional[Content] = None,
cookies: CookiesType = None,
) -> TestWebSocket:
scope = _create_scope("GET_WS", path, headers, query, cookies=cookies)
scope["type"] = "websocket"
test_websocket = TestWebSocket()

self.websocket_tasks.append(
asyncio.create_task(
self.app(
scope,
test_websocket._receive,
test_websocket._send,
),
),
)

return test_websocket

def _is_started_app(self):
assert (
self.app.started
37 changes: 37 additions & 0 deletions blacksheep/testing/websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

import asyncio
from typing import Any


class TestWebSocket:
def __init__(self):
self.send_queue = asyncio.Queue()
self.receive_queue = asyncio.Queue()

async def _send(self, data: Any) -> None:
await self.send_queue.put(data)

async def _receive(self) -> Any:
return await self.receive_queue.get()

async def send(self, data: Any) -> None:
await self.receive_queue.put(data)

async def receive(self) -> Any:
return await self.send_queue.get()

async def __aenter__(self) -> TestWebSocket:
await self.send({"type": "websocket.connect"})
received = await self.receive()
assert received.get("type") == "websocket.accept"
return self

async def __aexit__(self, exc_type, exc_value, exc_tb) -> None:
await self.send(
{
"type": "websocket.disconnect",
"code": 1000,
"reason": "TestWebSocket context closed",
},
)
93 changes: 48 additions & 45 deletions tests/test_websocket.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
WebSocketState,
format_reason,
)
from blacksheep.testing import TestClient
from blacksheep.testing.messages import MockReceive, MockSend
from tests.utils.application import FakeApplication

@@ -322,18 +323,19 @@ async def test_websocket_raises_for_receive_when_closed_by_client(example_scope)


@pytest.mark.asyncio
async def test_application_handling_websocket_request_not_found(example_scope):
async def test_application_handling_websocket_request_not_found():
"""
If a client tries to open a WebSocket connection on an endpoint that is not handled,
the application returns an ASGI message to close the connection.
"""
app = FakeApplication()
mock_send = MockSend()
mock_receive = MockReceive()
await app.start()

await app(example_scope, mock_receive, mock_send)
client = TestClient(app)
test_websocket = client.websocket_connect("/ws")
await test_websocket.send({"type": "websocket.connect"})
close_message = await test_websocket.receive()

close_message = mock_send.messages[0]
assert close_message == {"type": "websocket.close", "reason": None, "code": 1000}


@@ -344,8 +346,6 @@ async def test_application_handling_proper_websocket_request():
the application websocket handler is called.
"""
app = FakeApplication()
mock_send = MockSend()
mock_receive = MockReceive([{"type": "websocket.connect"}])

@app.router.ws("/ws/{foo}")
async def websocket_handler(websocket, foo):
@@ -358,21 +358,17 @@ async def websocket_handler(websocket, foo):
await websocket.accept()

await app.start()
await app(
{"type": "websocket", "path": "/ws/001", "query_string": "", "headers": []},
mock_receive,
mock_send,
)
client = TestClient(app)
async with client.websocket_connect("/ws/001"):
pass


@pytest.mark.asyncio
async def test_application_handling_proper_websocket_request_with_query():
app = FakeApplication()
mock_send = MockSend()
mock_receive = MockReceive([{"type": "websocket.connect"}])

@app.router.ws("/ws/{foo}")
async def websocket_handler(websocket, foo, from_query: int):
async def websocket_handler(websocket: WebSocket, foo, from_query: int):
assert isinstance(websocket, WebSocket)
assert websocket.application_state == WebSocketState.CONNECTING
assert websocket.client_state == WebSocketState.CONNECTING
@@ -383,41 +379,27 @@ async def websocket_handler(websocket, foo, from_query: int):
await websocket.accept()

await app.start()
await app(
{
"type": "websocket",
"path": "/ws/001",
"query_string": b"from_query=200",
"headers": [],
},
mock_receive,
mock_send,
)
client = TestClient(app)
async with client.websocket_connect("/ws/001", query="from_query=200"):
pass


@pytest.mark.asyncio
async def test_application_handling_proper_websocket_request_header_binding(
example_scope,
):
async def test_application_handling_proper_websocket_request_header_binding():
app = FakeApplication()
mock_send = MockSend()
mock_receive = MockReceive([{"type": "websocket.connect"}])

class UpgradeHeader(FromHeader[str]):
name = "Upgrade"

called = False

@app.router.ws("/ws")
async def websocket_handler(connect_header: UpgradeHeader):
async def websocket_handler(websocket: WebSocket, connect_header: UpgradeHeader):
assert connect_header.value == "websocket"

nonlocal called
called = True
await websocket.accept()

await app.start()
await app(example_scope, mock_receive, mock_send)
assert called is True
client = TestClient(app)
async with client.websocket_connect("/ws", headers={"upgrade": "websocket"}):
pass


@pytest.mark.asyncio
@@ -426,8 +408,6 @@ async def test_application_websocket_binding_by_type_annotation():
This test verifies that the WebSocketBinder can bind a WebSocket by type annotation.
"""
app = FakeApplication()
mock_send = MockSend()
mock_receive = MockReceive([{"type": "websocket.connect"}])

@app.router.ws("/ws")
async def websocket_handler(my_ws: WebSocket):
@@ -438,11 +418,9 @@ async def websocket_handler(my_ws: WebSocket):
await my_ws.accept()

await app.start()
await app(
{"type": "websocket", "path": "/ws", "query_string": "", "headers": []},
mock_receive,
mock_send,
)
client = TestClient(app)
async with client.websocket_connect("/ws"):
pass


@pytest.mark.asyncio
@@ -466,6 +444,31 @@ async def websocket_handler(my_ws: WebSocket):
assert await route.handler(...) is None


@pytest.mark.asyncio
async def test_testwebsocket_closing():
"""
This test verifies that websocket.disconnect is sent by TestWebSocket
"""
app = FakeApplication()
disconnected = False

@app.router.ws("/ws")
async def websocket_handler(my_ws: WebSocket):
await my_ws.accept()
try:
await my_ws.receive()
except WebSocketDisconnectError:
nonlocal disconnected
disconnected = True

await app.start()
client = TestClient(app)
async with client.websocket_connect("/ws"):
pass
await client.websocket_all_closed()
assert disconnected is True


LONG_REASON = "WRY" * 41
QIN = "秦" # Qyn dynasty in Chinese, 3 bytes.
TOO_LONG_REASON = QIN * 42

0 comments on commit c363263

Please sign in to comment.