From cdda9d4c6a57db88392959cd0dab5de47651558f Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Sat, 1 May 2021 12:31:44 +1200 Subject: [PATCH 1/6] Support graphql-transport-ws websocket subprotocol --- graphql_ws/base.py | 24 ++++++++++++++++++++---- graphql_ws/constants.py | 11 +++++++---- graphql_ws/django/consumers.py | 20 ++++++++++++-------- 3 files changed, 39 insertions(+), 16 deletions(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 31ad657..7bd2f10 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -4,13 +4,17 @@ from graphql import format_error, graphql from .constants import ( + GQL_COMPLETE, GQL_CONNECTION_ERROR, GQL_CONNECTION_INIT, GQL_CONNECTION_TERMINATE, GQL_DATA, GQL_ERROR, + GQL_NEXT, GQL_START, GQL_STOP, + GQL_SUBSCRIBE, + TRANSPORT_WS_PROTOCOL, ) @@ -23,6 +27,9 @@ def __init__(self, ws, request_context=None): self.ws = ws self.operations = {} self.request_context = request_context + self.transport_ws_protocol = request_context and TRANSPORT_WS_PROTOCOL in ( + request_context.get("subprotocols") or [] + ) def has_operation(self, op_id): return op_id in self.operations @@ -41,7 +48,7 @@ def remove_operation(self, op_id): def unsubscribe(self, op_id): async_iterator = self.remove_operation(op_id) - if hasattr(async_iterator, 'dispose'): + if hasattr(async_iterator, "dispose"): async_iterator.dispose() return async_iterator @@ -84,12 +91,16 @@ def process_message(self, connection_context, parsed_message): elif op_type == GQL_CONNECTION_TERMINATE: return self.on_connection_terminate(connection_context, op_id) - elif op_type == GQL_START: + elif op_type == ( + GQL_SUBSCRIBE if connection_context.transport_ws_protocol else GQL_START + ): assert isinstance(payload, dict), "The payload must be a dict" params = self.get_graphql_params(connection_context, payload) return self.on_start(connection_context, op_id, params) - elif op_type == GQL_STOP: + elif op_type == ( + GQL_COMPLETE if connection_context.transport_ws_protocol else GQL_STOP + ): return self.on_stop(connection_context, op_id) else: @@ -142,7 +153,12 @@ def build_message(self, id, op_type, payload): def send_execution_result(self, connection_context, op_id, execution_result): result = self.execution_result_to_dict(execution_result) - return self.send_message(connection_context, op_id, GQL_DATA, result) + return self.send_message( + connection_context, + op_id, + GQL_NEXT if connection_context.transport_ws_protocol else GQL_DATA, + result, + ) def execution_result_to_dict(self, execution_result): result = OrderedDict() diff --git a/graphql_ws/constants.py b/graphql_ws/constants.py index 8b57a60..8813446 100644 --- a/graphql_ws/constants.py +++ b/graphql_ws/constants.py @@ -1,5 +1,6 @@ GRAPHQL_WS = "graphql-ws" WS_PROTOCOL = GRAPHQL_WS +TRANSPORT_WS_PROTOCOL = "graphql-transport-ws" GQL_CONNECTION_INIT = "connection_init" # Client -> Server GQL_CONNECTION_ACK = "connection_ack" # Server -> Client @@ -8,8 +9,10 @@ # NOTE: This one here don't follow the standard due to connection optimization GQL_CONNECTION_TERMINATE = "connection_terminate" # Client -> Server GQL_CONNECTION_KEEP_ALIVE = "ka" # Server -> Client -GQL_START = "start" # Client -> Server -GQL_DATA = "data" # Server -> Client +GQL_START = "start" # Client -> Server (graphql-ws) +GQL_SUBSCRIBE = "subscribe" # Client -> Server (graphql-transport-ws START equivalent) +GQL_DATA = "data" # Server -> Client (graphql-ws) +GQL_NEXT = "next" # Server -> Client (graphql-transport-ws DATA equivalent) GQL_ERROR = "error" # Server -> Client -GQL_COMPLETE = "complete" # Server -> Client -GQL_STOP = "stop" # Client -> Server +GQL_COMPLETE = "complete" # Server -> Client (and Client -> Server for graphql-transport-ws STOP equivalent) +GQL_STOP = "stop" # Client -> Server (graphql-ws only) diff --git a/graphql_ws/django/consumers.py b/graphql_ws/django/consumers.py index b1c64d1..1dc10ab 100644 --- a/graphql_ws/django/consumers.py +++ b/graphql_ws/django/consumers.py @@ -2,21 +2,25 @@ from channels.generic.websocket import AsyncJsonWebsocketConsumer -from ..constants import WS_PROTOCOL +from ..constants import TRANSPORT_WS_PROTOCOL, WS_PROTOCOL from .subscriptions import subscription_server class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer): - async def connect(self): self.connection_context = None - if WS_PROTOCOL in self.scope["subprotocols"]: - self.connection_context = await subscription_server.handle( - ws=self, request_context=self.scope - ) - await self.accept(subprotocol=WS_PROTOCOL) - else: + found_protocol = None + for protocol in [WS_PROTOCOL, TRANSPORT_WS_PROTOCOL]: + if protocol in self.scope["subprotocols"]: + found_protocol = protocol + break + if not found_protocol: await self.close() + return + self.connection_context = await subscription_server.handle( + ws=self, request_context=self.scope + ) + await self.accept(subprotocol=found_protocol) async def disconnect(self, code): if self.connection_context: From 50f43570e45acf9339fce7e949069bddad49a270 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 3 May 2021 09:49:42 +1200 Subject: [PATCH 2/6] Fix flake line length --- graphql_ws/constants.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/graphql_ws/constants.py b/graphql_ws/constants.py index 8813446..2952296 100644 --- a/graphql_ws/constants.py +++ b/graphql_ws/constants.py @@ -14,5 +14,6 @@ GQL_DATA = "data" # Server -> Client (graphql-ws) GQL_NEXT = "next" # Server -> Client (graphql-transport-ws DATA equivalent) GQL_ERROR = "error" # Server -> Client -GQL_COMPLETE = "complete" # Server -> Client (and Client -> Server for graphql-transport-ws STOP equivalent) +GQL_COMPLETE = "complete" # Server -> Client +# (and Client -> Server for graphql-transport-ws STOP equivalent) GQL_STOP = "stop" # Client -> Server (graphql-ws only) From 9620f3aacab2bc4ef58ed23ff861d7ddc51d009a Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 3 May 2021 09:49:59 +1200 Subject: [PATCH 3/6] Fix async tests --- graphql_ws/base.py | 2 ++ setup.cfg | 3 +-- tests/test_base_async.py | 40 ++++++++++++++++++++++------------------ tests/test_graphql_ws.py | 2 +- 4 files changed, 26 insertions(+), 21 deletions(-) diff --git a/graphql_ws/base.py b/graphql_ws/base.py index 7bd2f10..e69f12e 100644 --- a/graphql_ws/base.py +++ b/graphql_ws/base.py @@ -23,6 +23,8 @@ class ConnectionClosedException(Exception): class BaseConnectionContext(object): + transport_ws_protocol = False + def __init__(self, ws, request_context=None): self.ws = ws self.operations = {} diff --git a/setup.cfg b/setup.cfg index 3d07a80..ded02e8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,9 +50,8 @@ test = pytest-asyncio; python_version>="3.4" graphene>=2.0,<3 gevent - graphene>=2.0 graphene_django - mock; python_version<"3" + mock; python_version<"3.8" django==1.11.*; python_version<"3" channels==1.*; python_version<"3" django==2.*; python_version>="3" diff --git a/tests/test_base_async.py b/tests/test_base_async.py index d62eda5..50c4309 100644 --- a/tests/test_base_async.py +++ b/tests/test_base_async.py @@ -10,9 +10,10 @@ pytestmark = pytest.mark.asyncio -class AsyncMock(mock.MagicMock): - async def __call__(self, *args, **kwargs): - return super().__call__(*args, **kwargs) +try: + from unittest.mock import AsyncMock # Python 3.8+ +except ImportError: + from mock import AsyncMock class TstServer(base_async.BaseAsyncSubscriptionServer): @@ -26,75 +27,78 @@ def server(): async def test_terminate(server: TstServer): - context = AsyncMock() + context = AsyncMock(spec=base_async.BaseAsyncConnectionContext) await server.on_connection_terminate(connection_context=context, op_id=1) context.close.assert_called_with(1011) async def test_send_error(server: TstServer): - context = AsyncMock() - context.has_operation = mock.Mock() + context = AsyncMock(spec=base_async.BaseAsyncConnectionContext) await server.send_error(connection_context=context, op_id=1, error="test error") context.send.assert_called_with( {"id": 1, "type": "error", "payload": {"message": "test error"}} ) -async def test_message(server): +async def test_message(server: TstServer): server.process_message = AsyncMock() - context = AsyncMock() + context = AsyncMock(spec=base_async.BaseAsyncConnectionContext) msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} await server.on_message(context, msg) server.process_message.assert_called_with(context, msg) -async def test_message_str(server): +async def test_message_str(server: TstServer): server.process_message = AsyncMock() - context = AsyncMock() + context = AsyncMock(spec=base_async.BaseAsyncConnectionContext) msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""} await server.on_message(context, json.dumps(msg)) server.process_message.assert_called_with(context, msg) -async def test_message_invalid(server): +async def test_message_invalid(server: TstServer): server.send_error = AsyncMock() - await server.on_message(connection_context=None, message="'not-json") + context = AsyncMock(spec=base_async.BaseAsyncConnectionContext) + await server.on_message(context, message="'not-json") assert server.send_error.called -async def test_resolver(server): +async def test_resolver(server: TstServer): server.send_message = AsyncMock() + context = AsyncMock(spec=base_async.BaseAsyncConnectionContext) result = mock.Mock() result.data = {"test": [1, 2]} result.errors = None await server.send_execution_result( - connection_context=None, op_id=1, execution_result=result + context, op_id=1, execution_result=result ) assert server.send_message.called @pytest.mark.asyncio -async def test_resolver_with_promise(server): +async def test_resolver_with_promise(server: TstServer): server.send_message = AsyncMock() + context = AsyncMock(spec=base_async.BaseAsyncConnectionContext) result = mock.Mock() result.data = {"test": [1, promise.Promise(lambda resolve, reject: resolve(2))]} result.errors = None await server.send_execution_result( - connection_context=None, op_id=1, execution_result=result + context, op_id=1, execution_result=result ) assert server.send_message.called assert result.data == {'test': [1, 2]} -async def test_resolver_with_nested_promise(server): +async def test_resolver_with_nested_promise(server: TstServer): server.send_message = AsyncMock() + context = AsyncMock(spec=base_async.BaseAsyncConnectionContext) result = mock.Mock() inner = promise.Promise(lambda resolve, reject: resolve(2)) outer = promise.Promise(lambda resolve, reject: resolve({'in': inner})) result.data = {"test": [1, outer]} result.errors = None await server.send_execution_result( - connection_context=None, op_id=1, execution_result=result + context, op_id=1, execution_result=result ) assert server.send_message.called assert result.data == {'test': [1, {'in': 2}]} diff --git a/tests/test_graphql_ws.py b/tests/test_graphql_ws.py index 3b85c49..6302cb1 100644 --- a/tests/test_graphql_ws.py +++ b/tests/test_graphql_ws.py @@ -165,7 +165,7 @@ def test_build_message_partial(ss): ss.build_message(id=None, op_type=None, payload=None) -def test_send_execution_result(ss): +def test_send_execution_result(ss, cc): ss.execution_result_to_dict = mock.Mock() ss.execution_result_to_dict.return_value = {"res": "ult"} ss.send_message = mock.Mock() From 2f0eb21ca2e5df5dcf832619455ca37c71aa21a9 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Mon, 3 May 2021 11:14:25 +1200 Subject: [PATCH 4/6] Extend tests for transport_ws_protocol GQL types --- tests/test_graphql_ws.py | 45 ++++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/tests/test_graphql_ws.py b/tests/test_graphql_ws.py index 6302cb1..bb30d4a 100644 --- a/tests/test_graphql_ws.py +++ b/tests/test_graphql_ws.py @@ -77,15 +77,20 @@ def test_terminate(self, ss, cc): ss.process_message(cc, {"id": "1", "type": constants.GQL_CONNECTION_TERMINATE}) ss.on_connection_terminate.assert_called_with(cc, "1") - def test_start(self, ss, cc): + @pytest.mark.parametrize( + "transport_ws_protocol,expected_type", + ((False, constants.GQL_START), (True, constants.GQL_SUBSCRIBE)), + ) + def test_start(self, ss, cc, transport_ws_protocol, expected_type): ss.get_graphql_params = mock.Mock() ss.get_graphql_params.return_value = {"params": True} cc.has_operation = mock.Mock() cc.has_operation.return_value = False + cc.transport_ws_protocol = transport_ws_protocol ss.unsubscribe = mock.Mock() ss.on_start = mock.Mock() ss.process_message( - cc, {"id": "1", "type": constants.GQL_START, "payload": {"a": "b"}} + cc, {"id": "1", "type": expected_type, "payload": {"a": "b"}} ) assert not ss.unsubscribe.called ss.on_start.assert_called_with(cc, "1", {"params": True}) @@ -117,9 +122,32 @@ def test_start_bad_graphql_params(self, ss, cc): assert isinstance(ss.send_error.call_args[0][2], Exception) assert not ss.on_start.called - def test_stop(self, ss, cc): + @pytest.mark.parametrize( + "transport_ws_protocol,stop_type,invalid_stop_type", + ( + (False, constants.GQL_STOP, constants.GQL_COMPLETE), + (True, constants.GQL_COMPLETE, constants.GQL_STOP), + ), + ) + def test_stop( + self, + ss, + cc, + transport_ws_protocol, + stop_type, + invalid_stop_type, + ): ss.on_stop = mock.Mock() - ss.process_message(cc, {"id": "1", "type": constants.GQL_STOP}) + ss.send_error = mock.Mock() + cc.transport_ws_protocol = transport_ws_protocol + + ss.process_message(cc, {"id": "1", "type": invalid_stop_type}) + assert ss.send_error.called + assert ss.send_error.call_args[0][:2] == (cc, "1") + assert isinstance(ss.send_error.call_args[0][2], Exception) + assert not ss.on_stop.called + + ss.process_message(cc, {"id": "1", "type": stop_type}) ss.on_stop.assert_called_with(cc, "1") def test_invalid(self, ss, cc): @@ -165,13 +193,18 @@ def test_build_message_partial(ss): ss.build_message(id=None, op_type=None, payload=None) -def test_send_execution_result(ss, cc): +@pytest.mark.parametrize( + "transport_ws_protocol,expected_type", + ((False, constants.GQL_DATA), (True, constants.GQL_NEXT)), +) +def test_send_execution_result(ss, cc, transport_ws_protocol, expected_type): + cc.transport_ws_protocol = transport_ws_protocol ss.execution_result_to_dict = mock.Mock() ss.execution_result_to_dict.return_value = {"res": "ult"} ss.send_message = mock.Mock() ss.send_message.return_value = "returned" assert "returned" == ss.send_execution_result(cc, "1", "result") - ss.send_message.assert_called_with(cc, "1", constants.GQL_DATA, {"res": "ult"}) + ss.send_message.assert_called_with(cc, "1", expected_type, {"res": "ult"}) def test_execution_result_to_dict(ss): From 625e82ba9ba8449639fff4bf8b2293eb01eb9791 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Thu, 16 Dec 2021 16:47:15 +1300 Subject: [PATCH 5/6] Add middleware to the graphql params from the graphene django settings --- graphql_ws/django/subscriptions.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py index 086445f..372360f 100644 --- a/graphql_ws/django/subscriptions.py +++ b/graphql_ws/django/subscriptions.py @@ -7,7 +7,7 @@ class ChannelsConnectionContext(BaseAsyncConnectionContext): def __init__(self, *args, **kwargs): - super(ChannelsConnectionContext, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.socket_closed = False async def send(self, data): @@ -35,5 +35,11 @@ async def handle(self, ws, request_context=None): await self.on_open(connection_context) return connection_context + def get_graphql_params(self, connection_context, payload): + params = super().get_graphql_params(connection_context, payload) + if graphene_settings.MIDDLEWARE: + params["middleware"] = graphene_settings.MIDDLEWARE + return params + subscription_server = ChannelsSubscriptionServer(schema=graphene_settings.SCHEMA) From 50980b3df5f53cc0a958f46236074e8b5008ebb5 Mon Sep 17 00:00:00 2001 From: Chris Beaven Date: Fri, 17 Dec 2021 17:06:43 +1300 Subject: [PATCH 6/6] Avoid using promises for the middleware wrappers --- graphql_ws/django/subscriptions.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/graphql_ws/django/subscriptions.py b/graphql_ws/django/subscriptions.py index 372360f..7aa8c45 100644 --- a/graphql_ws/django/subscriptions.py +++ b/graphql_ws/django/subscriptions.py @@ -1,5 +1,8 @@ from graphene_django.settings import graphene_settings -from ..base_async import BaseAsyncConnectionContext, BaseAsyncSubscriptionServer +from graphql import MiddlewareManager + +from ..base_async import (BaseAsyncConnectionContext, + BaseAsyncSubscriptionServer) from ..observable_aiter import setup_observable_extension setup_observable_extension() @@ -37,8 +40,13 @@ async def handle(self, ws, request_context=None): def get_graphql_params(self, connection_context, payload): params = super().get_graphql_params(connection_context, payload) - if graphene_settings.MIDDLEWARE: - params["middleware"] = graphene_settings.MIDDLEWARE + middleware = graphene_settings.MIDDLEWARE + if middleware: + if not isinstance(middleware, MiddlewareManager): + middleware = MiddlewareManager( + *middleware, wrap_in_promise=False + ) + params["middleware"] = middleware return params