Skip to content

Commit

Permalink
partial support for the test client
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Jul 27, 2015
1 parent 883e73e commit 8591ef4
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 176 deletions.
2 changes: 1 addition & 1 deletion example/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def disconnect_request():


@socketio.on('connect', namespace='/test')
def test_connect(env):
def test_connect():
emit('my response', {'data': 'Connected', 'count': 0})


Expand Down
87 changes: 48 additions & 39 deletions flask_socketio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from werkzeug.serving import run_with_reloader
from werkzeug._internal import _log

#from test_client import SocketIOTestClient
from .test_client import SocketIOTestClient


class SocketIO(object):
Expand Down Expand Up @@ -61,43 +61,36 @@ def handle_my_custom_event(json):
"""
namespace = namespace or '/'

if namespace in self.exception_handlers or \
self.default_exception_handler is not None:
def decorator(handler):
def _handler(sid, *args):
with self.app.request_context(self.server.environ[sid]):
if 'saved_session' in self.server.environ[sid]:
self._copy_session(self.server.environ[sid]['saved_session'], flask.session)
flask.request.sid = sid
flask.request.namespace = namespace
try:
def decorator(handler):
def _handler(sid, *args):
with self.app.request_context(self.server.environ[sid]):
if 'saved_session' in self.server.environ[sid]:
self._copy_session(
self.server.environ[sid]['saved_session'],
flask.session)
flask.request.sid = sid
flask.request.namespace = namespace
try:
if message == 'connect':
ret = handler()
else:
ret = handler(*args)
except:
err_handler = self.exception_handlers.get(
namespace, self.default_exception_handler)
type, value, traceback = sys.exc_info()
return err_handler(value)
self.server.environ[sid]['saved_session'] = {}
self._copy_session(flask.session, self.server.environ[sid]['saved_session'])
return ret
self.server.on(message, _handler, namespace=namespace)
return decorator
else:
def decorator(handler):
def _handler(sid, *args):
with self.app.request_context(self.server.environ[sid]):
if 'saved_session' in self.server.environ[sid]:
self._copy_session(self.server.environ[sid]['saved_session'], flask.session)
flask.request.sid = sid
flask.request.namespace = namespace
ret = handler(*args)
self.server.environ[sid]['saved_session'] = {}
self._copy_session(flask.session, self.server.environ[sid]['saved_session'])
return ret
self.server.on(message, _handler, namespace=namespace)
return decorator

def on_error(self, namespace=''):
except:
err_handler = self.exception_handlers.get(
namespace, self.default_exception_handler)
if err_handler is None:
raise
type, value, traceback = sys.exc_info()
return err_handler(value)
self.server.environ[sid]['saved_session'] = {}
self._copy_session(
flask.session,
self.server.environ[sid]['saved_session'])
return ret
self.server.on(message, _handler, namespace=namespace)
return decorator

def on_error(self, namespace=None):
"""Decorator to define a custom error handler for SocketIO events.
This decorator can be applied to a function that acts as an error
Expand All @@ -112,6 +105,7 @@ def chat_error_handler(e):
:param namespace: The namespace for which to register the error
handler. Defaults to the global namespace.
"""
namespace = namespace or '/'
def decorator(exception_handler):
if not callable(exception_handler):
raise ValueError('exception_handler must be callable')
Expand Down Expand Up @@ -153,13 +147,18 @@ def ping():
:param room: Send the message to all the users in the given room. If
this parameter is not included, the event is sent to
all connected users.
:param callback: If given, this function will be called to acknowledge
that the client has received the message. The
arguments that will be passed to the function are
those provided by the client. Callback functions can
only be used when addressing an individual client.
"""
# TODO: handle skip_sid
self.server.emit(event, *args, namespace=kwargs.get('namespace', '/'),
room=kwargs.get('room'),
callback=kwargs.get('callback'))

def send(self, data, json=False, namespace=None, room=None):
def send(self, data, json=False, namespace=None, room=None, callback=None):
"""Send a server-generated SocketIO message.
This function sends a simple SocketIO message to one or more connected
Expand All @@ -176,8 +175,18 @@ def send(self, data, json=False, namespace=None, room=None):
:param room: Send the message only to the users in the given room. If
this parameter is not included, the message is sent to
all connected users.
:param callback: If given, this function will be called to acknowledge
that the client has received the message. The
arguments that will be passed to the function are
those provided by the client. Callback functions can
only be used when addressing an individual client.
"""
self.server.send(data, namespace=namespace, room=room)
if json:
self.emit('json', data, namespace=namespace, room=room,
callback=callback)
else:
self.emit('message', data, namespace=namespace, room=room,
callback=callback)

def close_room(self, room, namespace='/'):
"""Close a room.
Expand Down
136 changes: 42 additions & 94 deletions flask_socketio/test_client.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,117 +3,65 @@
unit tests.
"""

import uuid

class TestServer(object):
counter = 0
from socketio import packet
from werkzeug.test import EnvironBuilder

def __init__(self):
self.sockets = {}

def new_socket(self):
socket = TestSocket(self, self.counter)
self.sockets[self.counter] = socket
self.counter += 1
return socket
_queue = {}

def remove_socket(self, socket):
for id, s in self.sockets.items():
if s == socket:
del self.sockets[id]
return


class TestSocket(object):
def __init__(self, server, sessid):
self.server = server
self.sessid = sessid
self.active_ns = {}

def __getitem__(self, ns_name):
return self.active_ns[ns_name]


class TestBaseNamespace(object):
def __init__(self, ns_name, socket, request=None):
from werkzeug.test import EnvironBuilder
self.environ = EnvironBuilder().get_environ()
self.ns_name = ns_name
self.socket = socket
self.request = request
self.session = {}
self.received = []
self.initialize()

def initialize(self):
pass

def recv_connect(self):
pass

def recv_disconnect(self):
pass

def emit(self, event, *args, **kwargs):
self.received.append({'name': event, 'args': args})
callback = kwargs.pop('callback', None)
if callback:
callback()

def send(self, message, json=False, callback=None):
if not json:
self.received.append({'name': 'message', 'args': message})
def _mock_send_packet(sid, pkt):
global _queue
if sid not in _queue:
_queue[sid] = []
if pkt.packet_type == packet.EVENT or \
pkt.packet_type == packet.BINARY_EVENT:
if pkt.data[0] == 'message' or pkt.data[0] == 'json':
_queue[sid].append({'name': pkt.data[0], 'args': pkt.data[1],
'namespace': pkt.namespace or '/'})
else:
self.received.append({'name': 'json', 'args': message})
if callback:
callback()
_queue[sid].append({'name': pkt.data[0], 'args': pkt.data[1:],
'namespace': pkt.namespace or '/'})


class SocketIOTestClient(object):
server = TestServer()

def __init__(self, app, socketio, namespace=''):
def __init__(self, app, socketio, namespace=None):
self.sid = uuid.uuid4().hex
self.socketio = socketio
self.socketio.server = self.server
self.socket = self.server.new_socket()
self.connect(app, namespace)

def __del__(self):
self.server.remove_socket(self.socket)
socketio.server._send_packet = _mock_send_packet
socketio.server.environ[self.sid] = {}
self.connect(namespace)

def connect(self, app, namespace=None):
if self.socket.active_ns.get(namespace):
self.disconnect(namespace)
if namespace is None or namespace == '/':
namespace = ''
self.socket.active_ns[namespace] = \
self.socketio._get_namespaces(
TestBaseNamespace)[namespace](namespace, self.socket, app)
self.socket[namespace].recv_connect()
def connect(self, namespace=None):
environ = EnvironBuilder('/socket.io').get_environ()
self.socketio.server._handle_eio_connect(self.sid, environ)
if namespace is not None and namespace != '/':
pkt = packet.Packet(packet.CONNECT, namespace=namespace)
self.socketio.server._handle_eio_message(self.sid, pkt.encode())

def disconnect(self, namespace=None):
if namespace is None or namespace == '/':
namespace = ''
if self.socket[namespace]:
self.socket[namespace].recv_disconnect()
del self.socket.active_ns[namespace]
pkt = packet.Packet(packet.DISCONNECT, namespace=namespace)
self.socketio.server._handle_eio_message(self.sid, pkt.encode())

def emit(self, event, *args, **kwargs):
namespace = kwargs.pop('namespace', None)
if namespace is None or namespace == '/':
namespace = ''
return self.socket[namespace].process_event({'name': event, 'args': args})
pkt = packet.Packet(packet.EVENT, data=[event] + list(args),
namespace=namespace, binary=False)
self.socketio.server._handle_eio_message(self.sid, pkt.encode())

def send(self, message, json=False, namespace=None):
if namespace is None or namespace == '/':
namespace = ''
if not json:
return self.socket[namespace].recv_message(message)
def send(self, data, json=False, namespace=None):
if json:
msg = 'json'
else:
return self.socket[namespace].recv_json(message)
msg = 'message'
return self.emit(msg, data, namespace=namespace)

def get_received(self, namespace=None):
if namespace is None or namespace == '/':
namespace = ''
received = self.socket[namespace].received
self.socket[namespace].received = []
return received
if self.sid not in _queue:
return []
namespace = namespace or '/'
r = [pkt for pkt in _queue[self.sid] if pkt['namespace'] == namespace]
_queue[self.sid] = [pkt for pkt in _queue[self.sid] if pkt not in r]
return r
Loading

0 comments on commit 8591ef4

Please sign in to comment.