Skip to content

Commit

Permalink
Merge branch 'graphql-transport-ws'
Browse files Browse the repository at this point in the history
  • Loading branch information
SmileyChris committed Dec 16, 2021
2 parents 7ef25ec + 2f0eb21 commit dc86b24
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 38 deletions.
24 changes: 21 additions & 3 deletions graphql_ws/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
from graphql import format_error, graphql

from .constants import (
GQL_COMPLETE,
GQL_CONNECTION_ERROR,
GQL_CONNECTION_INIT,
GQL_CONNECTION_TERMINATE,
GQL_DATA,
GQL_ERROR,
GQL_NEXT,
GQL_START,
GQL_STOP,
GQL_SUBSCRIBE,
TRANSPORT_WS_PROTOCOL,
)


Expand All @@ -19,10 +23,15 @@ class ConnectionClosedException(Exception):


class BaseConnectionContext(object):
transport_ws_protocol = False

def __init__(self, ws, request_context=None):
self.ws = ws
self.operations = {}
self.request_context = request_context
self.transport_ws_protocol = request_context and TRANSPORT_WS_PROTOCOL in (
request_context.get("subprotocols") or []
)

def has_operation(self, op_id):
return op_id in self.operations
Expand Down Expand Up @@ -84,12 +93,16 @@ def process_message(self, connection_context, parsed_message):
elif op_type == GQL_CONNECTION_TERMINATE:
return self.on_connection_terminate(connection_context, op_id)

elif op_type == GQL_START:
elif op_type == (
GQL_SUBSCRIBE if connection_context.transport_ws_protocol else GQL_START
):
assert isinstance(payload, dict), "The payload must be a dict"
params = self.get_graphql_params(connection_context, payload)
return self.on_start(connection_context, op_id, params)

elif op_type == GQL_STOP:
elif op_type == (
GQL_COMPLETE if connection_context.transport_ws_protocol else GQL_STOP
):
return self.on_stop(connection_context, op_id)

else:
Expand Down Expand Up @@ -142,7 +155,12 @@ def build_message(self, id, op_type, payload):

def send_execution_result(self, connection_context, op_id, execution_result):
result = self.execution_result_to_dict(execution_result)
return self.send_message(connection_context, op_id, GQL_DATA, result)
return self.send_message(
connection_context,
op_id,
GQL_NEXT if connection_context.transport_ws_protocol else GQL_DATA,
result,
)

def execution_result_to_dict(self, execution_result):
result = OrderedDict()
Expand Down
10 changes: 7 additions & 3 deletions graphql_ws/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
GRAPHQL_WS = "graphql-ws"
WS_PROTOCOL = GRAPHQL_WS
TRANSPORT_WS_PROTOCOL = "graphql-transport-ws"

GQL_CONNECTION_INIT = "connection_init" # Client -> Server
GQL_CONNECTION_ACK = "connection_ack" # Server -> Client
Expand All @@ -8,8 +9,11 @@
# NOTE: This one here don't follow the standard due to connection optimization
GQL_CONNECTION_TERMINATE = "connection_terminate" # Client -> Server
GQL_CONNECTION_KEEP_ALIVE = "ka" # Server -> Client
GQL_START = "start" # Client -> Server
GQL_DATA = "data" # Server -> Client
GQL_START = "start" # Client -> Server (graphql-ws)
GQL_SUBSCRIBE = "subscribe" # Client -> Server (graphql-transport-ws START equivalent)
GQL_DATA = "data" # Server -> Client (graphql-ws)
GQL_NEXT = "next" # Server -> Client (graphql-transport-ws DATA equivalent)
GQL_ERROR = "error" # Server -> Client
GQL_COMPLETE = "complete" # Server -> Client
GQL_STOP = "stop" # Client -> Server
# (and Client -> Server for graphql-transport-ws STOP equivalent)
GQL_STOP = "stop" # Client -> Server (graphql-ws only)
19 changes: 12 additions & 7 deletions graphql_ws/django/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,25 @@

from channels.generic.websocket import AsyncJsonWebsocketConsumer

from ..constants import WS_PROTOCOL
from ..constants import TRANSPORT_WS_PROTOCOL, WS_PROTOCOL
from .subscriptions import subscription_server


class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer):
async def connect(self):
self.connection_context = None
if WS_PROTOCOL in self.scope["subprotocols"]:
self.connection_context = await subscription_server.handle(
ws=self, request_context=self.scope
)
await self.accept(subprotocol=WS_PROTOCOL)
else:
found_protocol = None
for protocol in [WS_PROTOCOL, TRANSPORT_WS_PROTOCOL]:
if protocol in self.scope["subprotocols"]:
found_protocol = protocol
break
if not found_protocol:
await self.close()
return
self.connection_context = await subscription_server.handle(
ws=self, request_context=self.scope
)
await self.accept(subprotocol=found_protocol)

async def disconnect(self, code):
if self.connection_context:
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ test =
graphene>=2.0,<3
gevent
graphene_django
mock; python_version<"3"
mock; python_version<"3.8"
django==1.11.*; python_version<"3"
channels==1.*; python_version<"3"
django==3.*; python_version>="3"
Expand Down
40 changes: 22 additions & 18 deletions tests/test_base_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
pytestmark = pytest.mark.asyncio


class AsyncMock(mock.MagicMock):
async def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs)
try:
from unittest.mock import AsyncMock # Python 3.8+
except ImportError:
from mock import AsyncMock


class TstServer(base_async.BaseAsyncSubscriptionServer):
Expand All @@ -26,75 +27,78 @@ def server():


async def test_terminate(server: TstServer):
context = AsyncMock()
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
await server.on_connection_terminate(connection_context=context, op_id=1)
context.close.assert_called_with(1011)


async def test_send_error(server: TstServer):
context = AsyncMock()
context.has_operation = mock.Mock()
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
await server.send_error(connection_context=context, op_id=1, error="test error")
context.send.assert_called_with(
{"id": 1, "type": "error", "payload": {"message": "test error"}}
)


async def test_message(server):
async def test_message(server: TstServer):
server.process_message = AsyncMock()
context = AsyncMock()
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""}
await server.on_message(context, msg)
server.process_message.assert_called_with(context, msg)


async def test_message_str(server):
async def test_message_str(server: TstServer):
server.process_message = AsyncMock()
context = AsyncMock()
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""}
await server.on_message(context, json.dumps(msg))
server.process_message.assert_called_with(context, msg)


async def test_message_invalid(server):
async def test_message_invalid(server: TstServer):
server.send_error = AsyncMock()
await server.on_message(connection_context=None, message="'not-json")
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
await server.on_message(context, message="'not-json")
assert server.send_error.called


async def test_resolver(server):
async def test_resolver(server: TstServer):
server.send_message = AsyncMock()
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
result = mock.Mock()
result.data = {"test": [1, 2]}
result.errors = None
await server.send_execution_result(
connection_context=None, op_id=1, execution_result=result
context, op_id=1, execution_result=result
)
assert server.send_message.called


@pytest.mark.asyncio
async def test_resolver_with_promise(server):
async def test_resolver_with_promise(server: TstServer):
server.send_message = AsyncMock()
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
result = mock.Mock()
result.data = {"test": [1, promise.Promise(lambda resolve, reject: resolve(2))]}
result.errors = None
await server.send_execution_result(
connection_context=None, op_id=1, execution_result=result
context, op_id=1, execution_result=result
)
assert server.send_message.called
assert result.data == {"test": [1, 2]}


async def test_resolver_with_nested_promise(server):
async def test_resolver_with_nested_promise(server: TstServer):
server.send_message = AsyncMock()
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
result = mock.Mock()
inner = promise.Promise(lambda resolve, reject: resolve(2))
outer = promise.Promise(lambda resolve, reject: resolve({"in": inner}))
result.data = {"test": [1, outer]}
result.errors = None
await server.send_execution_result(
connection_context=None, op_id=1, execution_result=result
context, op_id=1, execution_result=result
)
assert server.send_message.called
assert result.data == {"test": [1, {"in": 2}]}
45 changes: 39 additions & 6 deletions tests/test_graphql_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,20 @@ def test_terminate(self, ss, cc):
ss.process_message(cc, {"id": "1", "type": constants.GQL_CONNECTION_TERMINATE})
ss.on_connection_terminate.assert_called_with(cc, "1")

def test_start(self, ss, cc):
@pytest.mark.parametrize(
"transport_ws_protocol,expected_type",
((False, constants.GQL_START), (True, constants.GQL_SUBSCRIBE)),
)
def test_start(self, ss, cc, transport_ws_protocol, expected_type):
ss.get_graphql_params = mock.Mock()
ss.get_graphql_params.return_value = {"params": True}
cc.has_operation = mock.Mock()
cc.has_operation.return_value = False
cc.transport_ws_protocol = transport_ws_protocol
ss.unsubscribe = mock.Mock()
ss.on_start = mock.Mock()
ss.process_message(
cc, {"id": "1", "type": constants.GQL_START, "payload": {"a": "b"}}
cc, {"id": "1", "type": expected_type, "payload": {"a": "b"}}
)
assert not ss.unsubscribe.called
ss.on_start.assert_called_with(cc, "1", {"params": True})
Expand Down Expand Up @@ -117,9 +122,32 @@ def test_start_bad_graphql_params(self, ss, cc):
assert isinstance(ss.send_error.call_args[0][2], Exception)
assert not ss.on_start.called

def test_stop(self, ss, cc):
@pytest.mark.parametrize(
"transport_ws_protocol,stop_type,invalid_stop_type",
(
(False, constants.GQL_STOP, constants.GQL_COMPLETE),
(True, constants.GQL_COMPLETE, constants.GQL_STOP),
),
)
def test_stop(
self,
ss,
cc,
transport_ws_protocol,
stop_type,
invalid_stop_type,
):
ss.on_stop = mock.Mock()
ss.process_message(cc, {"id": "1", "type": constants.GQL_STOP})
ss.send_error = mock.Mock()
cc.transport_ws_protocol = transport_ws_protocol

ss.process_message(cc, {"id": "1", "type": invalid_stop_type})
assert ss.send_error.called
assert ss.send_error.call_args[0][:2] == (cc, "1")
assert isinstance(ss.send_error.call_args[0][2], Exception)
assert not ss.on_stop.called

ss.process_message(cc, {"id": "1", "type": stop_type})
ss.on_stop.assert_called_with(cc, "1")

def test_invalid(self, ss, cc):
Expand Down Expand Up @@ -165,13 +193,18 @@ def test_build_message_partial(ss):
ss.build_message(id=None, op_type=None, payload=None)


def test_send_execution_result(ss):
@pytest.mark.parametrize(
"transport_ws_protocol,expected_type",
((False, constants.GQL_DATA), (True, constants.GQL_NEXT)),
)
def test_send_execution_result(ss, cc, transport_ws_protocol, expected_type):
cc.transport_ws_protocol = transport_ws_protocol
ss.execution_result_to_dict = mock.Mock()
ss.execution_result_to_dict.return_value = {"res": "ult"}
ss.send_message = mock.Mock()
ss.send_message.return_value = "returned"
assert "returned" == ss.send_execution_result(cc, "1", "result")
ss.send_message.assert_called_with(cc, "1", constants.GQL_DATA, {"res": "ult"})
ss.send_message.assert_called_with(cc, "1", expected_type, {"res": "ult"})


def test_execution_result_to_dict(ss):
Expand Down

0 comments on commit dc86b24

Please sign in to comment.