diff --git a/.gitignore b/.gitignore index c70592c..cefd2de 100644 --- a/.gitignore +++ b/.gitignore @@ -44,3 +44,4 @@ pip-log.txt .pydevproject *.sublime-workspace *.sw[op] +env/ diff --git a/.travis.yml b/.travis.yml index ed2b442..3d34982 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,11 +9,11 @@ python: - pypy matrix: - # include test for flake8 include: - python: 3.6 script: tox -e flake8 + install: - pip install cython tox diff --git a/examples/asyncio_echo/client.py b/examples/asyncio_echo/client.py new file mode 100644 index 0000000..b5274fe --- /dev/null +++ b/examples/asyncio_echo/client.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- +import thriftpy +import asyncio +from thriftpy.rpc import make_aio_client + + +echo_thrift = thriftpy.load("echo.thrift", module_name="echo_thrift") + + +async def main(): + client = await make_aio_client( + echo_thrift.EchoService, '127.0.0.1', 6000) + print(await client.echo('hello, world')) + client.close() + + +if __name__ == '__main__': + loop = asyncio.get_event_loop() + loop.run_until_complete(main()) diff --git a/examples/asyncio_echo/echo.thrift b/examples/asyncio_echo/echo.thrift new file mode 100644 index 0000000..bad8a3e --- /dev/null +++ b/examples/asyncio_echo/echo.thrift @@ -0,0 +1,7 @@ +# ping service demo +service EchoService { + /* + * Sexy c style comment + */ + string echo(1: string param), +} diff --git a/examples/asyncio_echo/server.py b/examples/asyncio_echo/server.py new file mode 100644 index 0000000..79d3da1 --- /dev/null +++ b/examples/asyncio_echo/server.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +import asyncio +import thriftpy + +from thriftpy.rpc import make_aio_server + +echo_thrift = thriftpy.load("echo.thrift", module_name="echo_thrift") + + +class Dispatcher(object): + async def echo(self, param): + print(param) + await asyncio.sleep(0.1) + return param + + +def main(): + server = make_aio_server( + echo_thrift.EchoService, Dispatcher(), '127.0.0.1', 6000) + server.serve() + + +if __name__ == '__main__': + main() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..514c604 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,5 @@ +import sys + +collect_ignore = ["setup.py"] +if sys.version_info < (3, 5): + collect_ignore.append("test_aio.py") diff --git a/tests/test_aio.py b/tests/test_aio.py new file mode 100644 index 0000000..57a848b --- /dev/null +++ b/tests/test_aio.py @@ -0,0 +1,280 @@ +# -*- coding: utf-8 -*- +import os +import asyncio +# import uvloop +import threading + +# asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +import time + +import pytest + +import thriftpy + +thriftpy.install_import_hook() + +from thriftpy.rpc import make_aio_server, make_aio_client # noqa +from thriftpy.transport import TTransportException # noqa + +addressbook = thriftpy.load(os.path.join(os.path.dirname(__file__), + "addressbook.thrift")) +unix_sock = "/tmp/aio_thriftpy_test.sock" +SSL_PORT = 50442 + + +class Dispatcher: + def __init__(self): + self.ab = addressbook.AddressBook() + self.ab.people = {} + + @asyncio.coroutine + def ping(self): + return True + + @asyncio.coroutine + def hello(self, name): + return "hello " + name + + @asyncio.coroutine + def add(self, person): + self.ab.people[person.name] = person + return True + + @asyncio.coroutine + def remove(self, name): + try: + self.ab.people.pop(name) + return True + except KeyError: + raise addressbook.PersonNotExistsError( + "{0} not exists".format(name)) + + @asyncio.coroutine + def get(self, name): + try: + return self.ab.people[name] + except KeyError: + raise addressbook.PersonNotExistsError( + "{0} not exists".format(name)) + + @asyncio.coroutine + def book(self): + return self.ab + + @asyncio.coroutine + def get_phonenumbers(self, name, count): + p = [self.ab.people[name].phones[0]] if name in self.ab.people else [] + return p * count + + @asyncio.coroutine + def get_phones(self, name): + phone_numbers = self.ab.people[name].phones + return dict((p.type, p.number) for p in phone_numbers) + + @asyncio.coroutine + def sleep(self, ms): + yield from asyncio.sleep(ms / 1000.0) + return True + + +@pytest.fixture(scope="module") +def aio_server(request): + loop = asyncio.new_event_loop() + server = make_aio_server( + addressbook.AddressBookService, + Dispatcher(), + unix_socket=unix_sock, + loop=loop + ) + st = threading.Thread(target=server.serve) + st.daemon = True + st.start() + time.sleep(0.1) + + +@pytest.fixture(scope="module") +def aio_ssl_server(request): + loop = asyncio.new_event_loop() + ssl_server = make_aio_server( + addressbook.AddressBookService, Dispatcher(), + host='localhost', port=SSL_PORT, + certfile="ssl/server.pem", keyfile="ssl/server.key", loop=loop + ) + st = threading.Thread(target=ssl_server.serve) + st.daemon = True + st.start() + time.sleep(0.1) + + +@pytest.fixture(scope="module") +def person(): + phone1 = addressbook.PhoneNumber() + phone1.type = addressbook.PhoneType.MOBILE + phone1.number = '555-1212' + phone2 = addressbook.PhoneNumber() + phone2.type = addressbook.PhoneType.HOME + phone2.number = '555-1234' + + # empty struct + phone3 = addressbook.PhoneNumber() + + alice = addressbook.Person() + alice.name = "Alice" + alice.phones = [phone1, phone2, phone3] + alice.created_at = int(time.time()) + + return alice + + +async def client(timeout=3000): + return await make_aio_client( + addressbook.AddressBookService, + unix_socket=unix_sock, socket_timeout=timeout + ) + + +async def ssl_client(timeout=3000): + return await make_aio_client( + addressbook.AddressBookService, + host='localhost', port=SSL_PORT, + socket_timeout=timeout, + cafile="ssl/CA.pem", certfile="ssl/client.crt", + keyfile="ssl/client.key") + + +@pytest.mark.asyncio +async def test_void_api(aio_server): + c = await client() + assert await c.ping() is None + c.close() + + +@pytest.mark.asyncio +async def test_void_api_with_ssl(aio_ssl_server): + c = await ssl_client() + assert await c.ping() is None + c.close() + + +@pytest.mark.asyncio +async def test_string_api(aio_server): + c = await client() + assert await c.hello("world") == "hello world" + c.close() + + +@pytest.mark.asyncio +async def test_string_api_with_ssl(aio_ssl_server): + c = await client() + assert await c.hello("world") == "hello world" + c.close() + + +@pytest.mark.asyncio +async def test_huge_res(aio_server): + c = await client() + big_str = "world" * 100000 + assert await c.hello(big_str) == "hello " + big_str + c.close() + + +@pytest.mark.asyncio +async def test_huge_res_with_ssl(aio_ssl_server): + c = await ssl_client() + big_str = "world" * 100000 + assert await c.hello(big_str) == "hello " + big_str + c.close() + + +@pytest.mark.asyncio +async def test_tstruct_req(person): + c = await client() + assert await c.add(person) is True + c.close() + + +@pytest.mark.asyncio +async def test_tstruct_req_with_ssl(person): + c = await ssl_client() + assert await c.add(person) is True + c.close() + + +@pytest.mark.asyncio +async def test_tstruct_res(person): + c = await client() + assert person == await c.get("Alice") + c.close() + + +@pytest.mark.asyncio +async def test_tstruct_res_with_ssl(person): + c = await ssl_client() + assert person == await c.get("Alice") + c.close() + + +@pytest.mark.asyncio +async def test_complex_tstruct(): + c = await client() + assert len(await c.get_phonenumbers("Alice", 0)) == 0 + assert len(await c.get_phonenumbers("Alice", 1000)) == 1000 + c.close() + + +@pytest.mark.asyncio +async def test_complex_tstruct_with_ssl(): + c = await ssl_client() + assert len(await c.get_phonenumbers("Alice", 0)) == 0 + assert len(await c.get_phonenumbers("Alice", 1000)) == 1000 + c.close() + + +@pytest.mark.asyncio +async def test_exception(): + with pytest.raises(addressbook.PersonNotExistsError): + c = await client() + await c.remove("Bob") + + +@pytest.mark.asyncio +async def test_exception_iwth_ssl(): + with pytest.raises(addressbook.PersonNotExistsError): + c = await ssl_client() + await c.remove("Bob") + + +@pytest.mark.asyncio +async def test_client_socket_timeout(): + with pytest.raises(asyncio.TimeoutError): + try: + c = await ssl_client(timeout=500) + await c.sleep(1000) + except: + c.close() + raise + + +@pytest.mark.asyncio +async def test_ssl_socket_timeout(): + # SSL socket timeout raises socket.timeout since Python 3.2. + # http://bugs.python.org/issue10272 + with pytest.raises(asyncio.TimeoutError): + try: + c = await ssl_client(timeout=500) + await c.sleep(1000) + except: + c.close() + raise + + +@pytest.mark.asyncio +async def test_client_connect_timeout(): + with pytest.raises(TTransportException): + c = await make_aio_client( + addressbook.AddressBookService, + unix_socket='/tmp/test.sock', + connect_timeout=1000 + ) + await c.hello('test') diff --git a/tests/test_framed_transport.py b/tests/test_framed_transport.py index 68e37dc..38ee2d9 100644 --- a/tests/test_framed_transport.py +++ b/tests/test_framed_transport.py @@ -2,6 +2,7 @@ from __future__ import absolute_import +import sys import logging import socket import threading @@ -10,6 +11,7 @@ from os import path from unittest import TestCase +import pytest from tornado import ioloop import thriftpy @@ -83,6 +85,7 @@ def setUp(self): time.sleep(0.1) self.client = self.mk_client() + @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") def test_able_to_communicate(self): dennis = addressbook.Person(name='Dennis Ritchie') success = self.client.add(dennis) @@ -90,6 +93,7 @@ def test_able_to_communicate(self): success = self.client.add(dennis) assert not success + @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") def test_zero_length_string(self): dennis = addressbook.Person(name='') success = self.client.add(dennis) diff --git a/tests/test_tornado.py b/tests/test_tornado.py index cbe4c9c..debeec7 100644 --- a/tests/test_tornado.py +++ b/tests/test_tornado.py @@ -2,10 +2,12 @@ from __future__ import absolute_import +import sys from os import path import logging import socket +import pytest from tornado import gen, testing import thriftpy @@ -86,6 +88,7 @@ def tearDown(self): super(TornadoRPCTestCase, self).tearDown() @testing.gen_test + @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") def test_synchronous_result(self): dennis = addressbook.Person(name='Dennis Ritchie') success = yield self.client.add(dennis) @@ -96,6 +99,7 @@ def test_synchronous_result(self): assert person.name == dennis.name @testing.gen_test + @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") def test_synchronous_exception(self): exc = None try: @@ -106,6 +110,7 @@ def test_synchronous_exception(self): assert isinstance(exc, addressbook.PersonNotExistsError) @testing.gen_test + @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") def test_asynchronous_result(self): dennis = addressbook.Person(name='Dennis Ritchie') yield self.client.add(dennis) @@ -113,6 +118,7 @@ def test_asynchronous_result(self): assert success @testing.gen_test + @pytest.mark.skipif(sys.version_info[:2] == (2, 6), reason="not support") def test_asynchronous_exception(self): exc = None try: diff --git a/thriftpy/__init__.py b/thriftpy/__init__.py index 8d04fd0..5a0402e 100644 --- a/thriftpy/__init__.py +++ b/thriftpy/__init__.py @@ -7,5 +7,5 @@ __version__ = '0.3.9' __python__ = sys.version_info -__all__ = ["install_import_hook", "remove_import_hook", "load", "load_module", +__all__ = ["install_import_hook", "remove_import_hook", "load", "load_module", "load_fp"] diff --git a/thriftpy/_compat.py b/thriftpy/_compat.py index e0c7f0b..cd3b7e0 100644 --- a/thriftpy/_compat.py +++ b/thriftpy/_compat.py @@ -13,6 +13,7 @@ import sys PY3 = sys.version_info[0] == 3 +PY35 = sys.version_info >= (3, 5) PYPY = "__pypy__" in sys.modules UNIX = platform.system() in ("Linux", "Darwin") diff --git a/thriftpy/contrib/aio/__init__.py b/thriftpy/contrib/aio/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/thriftpy/contrib/aio/client.py b/thriftpy/contrib/aio/client.py new file mode 100644 index 0000000..c68256a --- /dev/null +++ b/thriftpy/contrib/aio/client.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +import asyncio +import functools +from thriftpy.thrift import args2kwargs +from thriftpy.thrift import TApplicationException, TMessageType + + +class TAsyncClient: + + def __init__(self, service, iprot, oprot=None): + self._service = service + self._iprot = self._oprot = iprot + if oprot is not None: + self._oprot = oprot + self._seqid = 0 + + def __getattr__(self, _api): + if _api in self._service.thrift_services: + return functools.partial(self._req, _api) + + raise AttributeError("{} instance has no attribute '{}'".format( + self.__class__.__name__, _api)) + + def __dir__(self): + return self._service.thrift_services + + @asyncio.coroutine + def _req(self, _api, *args, **kwargs): + _kw = args2kwargs(getattr(self._service, _api + "_args").thrift_spec, + *args) + kwargs.update(_kw) + result_cls = getattr(self._service, _api + "_result") + + yield from self._send(_api, **kwargs) + # wait result only if non-oneway + if not getattr(result_cls, "oneway"): + return (yield from self._recv(_api)) + + @asyncio.coroutine + def _send(self, _api, **kwargs): + self._oprot.write_message_begin(_api, TMessageType.CALL, self._seqid) + args = getattr(self._service, _api + "_args")() + for k, v in kwargs.items(): + setattr(args, k, v) + self._oprot.write_struct(args) + self._oprot.write_message_end() + yield from self._oprot.trans.flush() + + @asyncio.coroutine + def _recv(self, _api): + fname, mtype, rseqid = yield from self._iprot.read_message_begin() + if mtype == TMessageType.EXCEPTION: + x = TApplicationException() + yield from self._iprot.read_struct(x) + yield from self._iprot.read_message_end() + raise x + result = getattr(self._service, _api + "_result")() + yield from self._iprot.read_struct(result) + yield from self._iprot.read_message_end() + + if hasattr(result, "success") and result.success is not None: + return result.success + + # void api without throws + if len(result.thrift_spec) == 0: + return + + # check throws + for k, v in result.__dict__.items(): + if k != "success" and v: + raise v + + # no throws & not void api + if hasattr(result, "success"): + raise TApplicationException(TApplicationException.MISSING_RESULT) + + def close(self): + self._iprot.trans.close() + if self._iprot != self._oprot: + self._oprot.trans.close() diff --git a/thriftpy/contrib/aio/processor.py b/thriftpy/contrib/aio/processor.py new file mode 100644 index 0000000..6e1b56d --- /dev/null +++ b/thriftpy/contrib/aio/processor.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +import asyncio +from thriftpy.thrift import TApplicationException, TType, TMessageType + + +class TAsyncProcessor(object): + + def __init__(self, service, handler): + self._service = service + self._handler = handler + + @asyncio.coroutine + def process_in(self, iprot): + api, type, seqid = yield from iprot.read_message_begin() + if api not in self._service.thrift_services: + yield from iprot.skip(TType.STRUCT) + yield from iprot.read_message_end() + return api, seqid, TApplicationException(TApplicationException.UNKNOWN_METHOD), None # noqa + + args = getattr(self._service, api + "_args")() + yield from iprot.read_struct(args) + yield from iprot.read_message_end() + result = getattr(self._service, api + "_result")() + + # convert kwargs to args + api_args = [args.thrift_spec[k][1] for k in sorted(args.thrift_spec)] + + @asyncio.coroutine + def call(): + f = getattr(self._handler, api) + return (yield from f(*(args.__dict__[k] for k in api_args))) + + return api, seqid, result, call + + @asyncio.coroutine + def send_exception(self, oprot, api, exc, seqid): + oprot.write_message_begin(api, TMessageType.EXCEPTION, seqid) + exc.write(oprot) + oprot.write_message_end() + yield from oprot.trans.flush() + + @asyncio.coroutine + def send_result(self, oprot, api, result, seqid): + oprot.write_message_begin(api, TMessageType.REPLY, seqid) + oprot.write_struct(result) + oprot.write_message_end() + yield from oprot.trans.flush() + + def handle_exception(self, e, result): + for k in sorted(result.thrift_spec): + if result.thrift_spec[k][1] == "success": + continue + + _, exc_name, exc_cls, _ = result.thrift_spec[k] + if isinstance(e, exc_cls): + setattr(result, exc_name, e) + break + else: + raise + + def process(self, iprot, oprot): + api, seqid, result, call = yield from self.process_in(iprot) + + if isinstance(result, TApplicationException): + return self.send_exception(oprot, api, result, seqid) + + try: + result.success = yield from call() + except Exception as e: + # raise if api don't have throws + self.handle_exception(e, result) + + if not result.oneway: + yield from self.send_result(oprot, api, result, seqid) diff --git a/thriftpy/contrib/aio/protocol/__init__.py b/thriftpy/contrib/aio/protocol/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/thriftpy/contrib/aio/protocol/binary.py b/thriftpy/contrib/aio/protocol/binary.py new file mode 100644 index 0000000..758db79 --- /dev/null +++ b/thriftpy/contrib/aio/protocol/binary.py @@ -0,0 +1,287 @@ +# -*- coding: utf-8 -*- + +from __future__ import absolute_import + +import asyncio + +from thriftpy.thrift import TType + +from thriftpy.protocol.exc import TProtocolException +from thriftpy.protocol.binary import ( + VERSION_MASK, + VERSION_1, + TYPE_MASK, + unpack_i8, + unpack_i16, + unpack_i32, + unpack_i64, + unpack_double, + write_message_begin, + write_val +) + + +@asyncio.coroutine +def read_message_begin(inbuf, strict=True): + sz = unpack_i32((yield from inbuf.read(4))) + if sz < 0: + version = sz & VERSION_MASK + if version != VERSION_1: + raise TProtocolException( + type=TProtocolException.BAD_VERSION, + message='Bad version in read_message_begin: %d' % (sz)) + name_sz = unpack_i32((yield from inbuf.read(4))) + name = yield from inbuf.read(name_sz) + name = name.decode('utf-8') + + type_ = sz & TYPE_MASK + else: + if strict: + raise TProtocolException(type=TProtocolException.BAD_VERSION, + message='No protocol version header') + + name = yield from inbuf.read(sz) + type_ = unpack_i8((yield from inbuf.read(1))) + + seqid = unpack_i32((yield from inbuf.read(4))) + + return name, type_, seqid + + +@asyncio.coroutine +def read_field_begin(inbuf): + f_type = unpack_i8((yield from inbuf.read(1))) + if f_type == TType.STOP: + return f_type, 0 + + return f_type, unpack_i16((yield from inbuf.read(2))) + + +@asyncio.coroutine +def read_list_begin(inbuf): + e_type = unpack_i8((yield from inbuf.read(1))) + sz = unpack_i32((yield from inbuf.read(4))) + return e_type, sz + + +@asyncio.coroutine +def read_map_begin(inbuf): + k_type = unpack_i8((yield from inbuf.read(1))) + v_type = unpack_i8((yield from inbuf.read(1))) + sz = unpack_i32((yield from inbuf.read(4))) + return k_type, v_type, sz + + +@asyncio.coroutine +def read_val(inbuf, ttype, spec=None, decode_response=True): + if ttype == TType.BOOL: + return bool(unpack_i8((yield from inbuf.read(1)))) + + elif ttype == TType.BYTE: + return unpack_i8((yield from inbuf.read(1))) + + elif ttype == TType.I16: + return unpack_i16((yield from inbuf.read(2))) + + elif ttype == TType.I32: + return unpack_i32((yield from inbuf.read(4))) + + elif ttype == TType.I64: + return unpack_i64((yield from inbuf.read(8))) + + elif ttype == TType.DOUBLE: + return unpack_double((yield from inbuf.read(8))) + + elif ttype == TType.STRING: + sz = unpack_i32((yield from inbuf.read(4))) + byte_payload = yield from inbuf.read(sz) + + # Since we cannot tell if we're getting STRING or BINARY + # if not asked not to decode, try both + if decode_response: + try: + return byte_payload.decode('utf-8') + except UnicodeDecodeError: + pass + return byte_payload + + elif ttype == TType.SET or ttype == TType.LIST: + if isinstance(spec, tuple): + v_type, v_spec = spec[0], spec[1] + else: + v_type, v_spec = spec, None + + result = [] + r_type, sz = yield from read_list_begin(inbuf) + # the v_type is useless here since we already get it from spec + if r_type != v_type: + for _ in range(sz): + yield from skip(inbuf, r_type) + return [] + + for i in range(sz): + result.append( + (yield from read_val( + inbuf, v_type, v_spec, decode_response + )) + ) + return result + + elif ttype == TType.MAP: + if isinstance(spec[0], int): + k_type = spec[0] + k_spec = None + else: + k_type, k_spec = spec[0] + + if isinstance(spec[1], int): + v_type = spec[1] + v_spec = None + else: + v_type, v_spec = spec[1] + + result = {} + sk_type, sv_type, sz = yield from read_map_begin(inbuf) + if sk_type != k_type or sv_type != v_type: + for _ in range(sz): + yield from skip(inbuf, sk_type) + yield from skip(inbuf, sv_type) + return {} + + for i in range(sz): + k_val = yield from read_val(inbuf, k_type, k_spec, decode_response) + v_val = yield from read_val(inbuf, v_type, v_spec, decode_response) + result[k_val] = v_val + + return result + + elif ttype == TType.STRUCT: + obj = spec() + yield from read_struct(inbuf, obj, decode_response) + return obj + + +@asyncio.coroutine +def read_struct(inbuf, obj, decode_response=True): + while True: + f_type, fid = yield from read_field_begin(inbuf) + if f_type == TType.STOP: + break + + if fid not in obj.thrift_spec: + yield from skip(inbuf, f_type) + continue + + if len(obj.thrift_spec[fid]) == 3: + sf_type, f_name, f_req = obj.thrift_spec[fid] + f_container_spec = None + else: + sf_type, f_name, f_container_spec, f_req = obj.thrift_spec[fid] + + # it really should equal here. but since we already wasted + # space storing the duplicate info, let's check it. + if f_type != sf_type: + yield from skip(inbuf, f_type) + continue + + _buf = yield from read_val( + inbuf, f_type, f_container_spec, decode_response) + setattr(obj, f_name, _buf) + + +@asyncio.coroutine +def skip(inbuf, ftype): + if ftype == TType.BOOL or ftype == TType.BYTE: + yield from inbuf.read(1) + + elif ftype == TType.I16: + yield from inbuf.read(2) + + elif ftype == TType.I32: + yield from inbuf.read(4) + + elif ftype == TType.I64: + yield from inbuf.read(8) + + elif ftype == TType.DOUBLE: + yield from inbuf.read(8) + + elif ftype == TType.STRING: + _size = yield from inbuf.read(4) + yield from inbuf.read(unpack_i32(_size)) + + elif ftype == TType.SET or ftype == TType.LIST: + v_type, sz = yield from read_list_begin(inbuf) + for i in range(sz): + yield from skip(inbuf, v_type) + + elif ftype == TType.MAP: + k_type, v_type, sz = yield from read_map_begin(inbuf) + for i in range(sz): + yield from skip(inbuf, k_type) + yield from skip(inbuf, v_type) + + elif ftype == TType.STRUCT: + while True: + f_type, fid = yield from read_field_begin(inbuf) + if f_type == TType.STOP: + break + yield from skip(inbuf, f_type) + + +class TAsyncBinaryProtocol(object): + """Binary implementation of the Thrift protocol driver.""" + + def __init__(self, trans, + strict_read=True, strict_write=True, + decode_response=True): + self.trans = trans + self.strict_read = strict_read + self.strict_write = strict_write + self.decode_response = decode_response + + @asyncio.coroutine + def skip(self, ttype): + yield from skip(self.trans, ttype) + + @asyncio.coroutine + def read_message_begin(self): + api, ttype, seqid = yield from read_message_begin( + self.trans, strict=self.strict_read) + return api, ttype, seqid + + @asyncio.coroutine + def read_message_end(self): + pass + + def write_message_begin(self, name, ttype, seqid): + write_message_begin( + self.trans, name, ttype, + seqid, strict=self.strict_write + ) + + def write_message_end(self): + pass + + @asyncio.coroutine + def read_struct(self, obj): + return (yield from read_struct(self.trans, obj, self.decode_response)) + + def write_struct(self, obj): + write_val(self.trans, TType.STRUCT, obj) + + +class TAsyncBinaryProtocolFactory(object): + def __init__(self, strict_read=True, strict_write=True, + decode_response=True): + self.strict_read = strict_read + self.strict_write = strict_write + self.decode_response = decode_response + + def get_protocol(self, trans): + return TAsyncBinaryProtocol( + trans, + self.strict_read, + self.strict_write, + self.decode_response + ) diff --git a/thriftpy/contrib/aio/rpc.py b/thriftpy/contrib/aio/rpc.py new file mode 100644 index 0000000..fc240ae --- /dev/null +++ b/thriftpy/contrib/aio/rpc.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*-import warnings +import asyncio +import warnings +from .processor import TAsyncProcessor +from .client import TAsyncClient +from .protocol.binary import TAsyncBinaryProtocolFactory +from .transport.buffered import TAsyncBufferedTransportFactory +from .socket import TAsyncSocket, TAsyncServerSocket +from .server import TAsyncServer + + +@asyncio.coroutine +def make_client(service, host="localhost", port=9090, unix_socket=None, + proto_factory=TAsyncBinaryProtocolFactory(), + trans_factory=TAsyncBufferedTransportFactory(), + socket_timeout=3000, connect_timeout=None, + cafile=None, ssl_context=None, + certfile=None, keyfile=None, + validate=True): + if unix_socket: + socket = TAsyncSocket(unix_socket=unix_socket) + if certfile: + warnings.warn("SSL only works with host:port, not unix_socket.") + elif host and port: + socket = TAsyncSocket( + host, port, + socket_timeout=socket_timeout, connect_timeout=connect_timeout, + cafile=cafile, ssl_context=ssl_context, + certfile=certfile, keyfile=keyfile, validate=validate) + else: + raise ValueError("Either host/port or unix_socket must be provided.") + + transport = trans_factory.get_transport(socket) + protocol = proto_factory.get_protocol(transport) + yield from transport.open() + return TAsyncClient(service, protocol) + + +def make_server(service, handler, + host="localhost", port=9090, unix_socket=None, + proto_factory=TAsyncBinaryProtocolFactory(), + trans_factory=TAsyncBufferedTransportFactory(), + client_timeout=3000, certfile=None, + keyfile=None, ssl_context=None, loop=None): + processor = TAsyncProcessor(service, handler) + + if unix_socket: + server_socket = TAsyncServerSocket(unix_socket=unix_socket) + if certfile: + warnings.warn("SSL only works with host:port, not unix_socket.") + elif host and port: + server_socket = TAsyncServerSocket( + host=host, port=port, + client_timeout=client_timeout, + certfile=certfile, keyfile=keyfile, ssl_context=ssl_context) + else: + raise ValueError("Either host/port or unix_socket must be provided.") + + server = TAsyncServer(processor, server_socket, + iprot_factory=proto_factory, + itrans_factory=trans_factory, loop=loop) + return server diff --git a/thriftpy/contrib/aio/server.py b/thriftpy/contrib/aio/server.py new file mode 100644 index 0000000..060aa7d --- /dev/null +++ b/thriftpy/contrib/aio/server.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +import asyncio +from thriftpy.server import TServer, logger +from thriftpy.transport import TTransportException + + +class TAsyncServer(TServer): + + def __init__(self, *args, **kwargs): + self.loop = kwargs['loop'] + kwargs.pop('loop') + + TServer.__init__( + self, + *args, + **kwargs + ) + self.closed = False + + def serve(self): + self.init_server() + try: + self.loop.run_forever() + except: + self.close() + raise + + def init_server(self): + self.trans.listen() + if not self.loop: + self.loop = asyncio.get_event_loop() + self.server = self.loop.run_until_complete(self.trans.accept(self.handle)) + + @asyncio.coroutine + def handle(self, client): + itrans = self.itrans_factory.get_transport(client) + otrans = self.otrans_factory.get_transport(client) + iprot = self.iprot_factory.get_protocol(itrans) + oprot = self.oprot_factory.get_protocol(otrans) + try: + while not client.reader.at_eof(): + yield from self.processor.process(iprot, oprot) + except TTransportException: + pass + except Exception as x: + logger.exception(x) + + itrans.close() + + @asyncio.coroutine + def close(self): + self.server.close() + yield from self.server.wait_closed() + self.closed = True diff --git a/thriftpy/contrib/aio/socket.py b/thriftpy/contrib/aio/socket.py new file mode 100644 index 0000000..352aa2b --- /dev/null +++ b/thriftpy/contrib/aio/socket.py @@ -0,0 +1,361 @@ +# -*- coding: utf-8 -*- + +from __future__ import absolute_import, division + +import ssl +import asyncio +import errno +import os +import socket +import struct +import sys + +from thriftpy.transport import TTransportException +from thriftpy.transport._ssl import ( + create_thriftpy_context, + RESTRICTED_SERVER_CIPHERS, + DEFAULT_CIPHERS +) + + +class TAsyncSocket(object): + """Socket implementation for client side.""" + + def __init__(self, host=None, port=None, unix_socket=None, + sock=None, socket_family=socket.AF_INET, + socket_timeout=3000, connect_timeout=None, + ssl_context=None, validate=True, + cafile=None, capath=None, certfile=None, keyfile=None, + ciphers=DEFAULT_CIPHERS): + """Initialize a TSocket + + TSocket can be initialized in 3 ways: + * host + port. can configure to use AF_INET/AF_INET6 + * unix_socket + * socket. should pass already opened socket here. + + @param host(str) The host to connect to. + @param port(int) The (TCP) port to connect to. + @param unix_socket(str) The filename of a unix socket to connect to. + @param sock(socket) Initialize with opened socket directly. + If this param used, the host, port and unix_socket params will + be ignored. + @param socket_family(str) socket.AF_INET or socket.AF_INET6. only + take effect when using host/port + @param socket_timeout socket timeout in ms + @param connect_timeout connect timeout in ms, only used in + connection, will be set to socket_timeout if not set. + @param validate(bool) Set to False to disable SSL certificate + validation and hostname validation. Default enabled. + @param cafile(str) Path to a file of concatenated CA + certificates in PEM format. + @param capath(str) path to a directory containing several CA + certificates in PEM format, following an OpenSSL specific layout. + @param certfile(str) The certfile string must be the path to a + single file in PEM format containing the certificate as well as + any number of CA certificates needed to establish the + certificate’s authenticity. + @param keyfile(str) The keyfile string, if not present, + the private key will be taken from certfile as well. + @param ciphers(list) The cipher suites to allow + @param ssl_context(SSLContext) Customize the SSLContext, can be used + to persist SSLContext object. Caution it's easy to get wrong, only + use if you know what you're doing. + """ + if sock: + self.raw_sock = sock + elif unix_socket: + self.unix_socket = unix_socket + self.host = None + self.port = None + self.raw_sock = None + self.sock_factory = asyncio.open_unix_connection + else: + self.unix_socket = None + self.host = host + self.port = port + self.raw_sock = None + self.sock_factory = asyncio.open_connection + + self.socket_family = socket_family + self.socket_timeout = socket_timeout / 1000 if socket_timeout else None + self.connect_timeout = connect_timeout / 1000 if connect_timeout \ + else self.socket_timeout + + if ssl_context: + self.ssl_context = ssl_context + self.server_hostname = host + elif certfile or keyfile: + self.server_hostname = host + self.ssl_context = create_thriftpy_context(server_side=False, + ciphers=ciphers) + + if cafile or capath: + self.ssl_context.load_verify_locations(cafile=cafile, + capath=capath) + + if certfile: + self.ssl_context.load_cert_chain(certfile, keyfile=keyfile) + + if not validate: + self.ssl_context.check_hostname = False + self.ssl_context.verify_mode = ssl.CERT_NONE + else: + self.ssl_context = None + self.server_hostname = None + + def _init_sock(self): + if self.unix_socket: + _sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + else: + _sock = socket.socket(self.socket_family, socket.SOCK_STREAM) + _sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + # socket options + linger = struct.pack('ii', 0, 0) + _sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger) + _sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + + self.raw_sock = _sock + + def set_handle(self, sock): + self.raw_sock = sock + + def set_timeout(self, ms): + """Backward compat api, will bind the timeout to both connect_timeout + and socket_timeout. + """ + self.socket_timeout = ms / 1000 if (ms and ms > 0) else None + self.connect_timeout = self.socket_timeout + + if self.raw_sock is not None: + self.raw_sock.settimeout(self.socket_timeout) + + def is_open(self): + return bool(self.raw_sock) + + @asyncio.coroutine + def open(self): + self._init_sock() + + addr = self.unix_socket or (self.host, self.port) + + try: + if self.connect_timeout: + self.raw_sock.settimeout(self.connect_timeout) + + self.raw_sock.connect(addr) + + if self.socket_timeout: + self.raw_sock.settimeout(self.socket_timeout) + + kwargs = {'sock': self.raw_sock, 'ssl': self.ssl_context} + if self.server_hostname: + kwargs['server_hostname'] = self.server_hostname + + self.reader, self.writer = yield from asyncio.wait_for( + self.sock_factory(**kwargs), + self.socket_timeout + ) + + except (socket.error, OSError): + raise TTransportException( + type=TTransportException.NOT_OPEN, + message="Could not connect to %s" % str(addr)) + + @asyncio.coroutine + def read(self, sz): + try: + buff = yield from asyncio.wait_for( + self.reader.read(sz), + self.connect_timeout + ) + except socket.error as e: + if (e.args[0] == errno.ECONNRESET and + (sys.platform == 'darwin' or + sys.platform.startswith('freebsd'))): + # freebsd and Mach don't follow POSIX semantic of recv + # and fail with ECONNRESET if peer performed shutdown. + # See corresponding comment and code in TSocket::read() + # in lib/cpp/src/transport/TSocket.cpp. + self.close() + # Trigger the check to raise the END_OF_FILE exception below. + buff = '' + else: + raise + + if len(buff) == 0: + raise TTransportException(type=TTransportException.END_OF_FILE, + message='TSocket read 0 bytes') + return buff + + def write(self, buff): + self.writer.write(buff) + + @asyncio.coroutine + def flush(self): + yield from asyncio.wait_for(self.writer.drain(), self.connect_timeout) + + def close(self): + if not self.raw_sock: + return + + try: + self.writer.close() + self.raw_sock.close() + self.raw_sock = None + except (socket.error, OSError): + pass + + +class TAsyncServerSocket(object): + """Socket implementation for server side.""" + + def __init__(self, host=None, port=None, unix_socket=None, + socket_family=socket.AF_INET, client_timeout=3000, + backlog=128, ssl_context=None, certfile=None, keyfile=None, + ciphers=RESTRICTED_SERVER_CIPHERS): + """Initialize a TServerSocket + + TSocket can be initialized in 2 ways: + * host + port. can configure to use AF_INET/AF_INET6 + * unix_socket + + @param host(str) The host to connect to + @param port(int) The (TCP) port to connect to + @param unix_socket(str) The filename of a unix socket to connect to + @param socket_family(str) socket.AF_INET or socket.AF_INET6. only + take effect when using host/port + @param client_timeout client socket timeout + @param backlog backlog for server socket + @param certfile(str) The server cert pem filename + @param keyfile(str) The server cert key filename + @param ciphers(list) The cipher suites to allow + @param ssl_context(SSLContext) Customize the SSLContext, can be used + to persist SSLContext object. Caution it's easy to get wrong, only + use if you know what you're doing. + """ + if unix_socket: + self.unix_socket = unix_socket + self.host = None + self.port = None + self.sock_factory = asyncio.start_unix_server + else: + self.unix_socket = None + self.host = host + self.port = port + self.sock_factory = asyncio.start_server + + self.socket_family = socket_family + self.client_timeout = client_timeout / 1000 if client_timeout else None + self.backlog = backlog + + if ssl_context: + self.ssl_context = ssl_context + elif certfile: + if not os.access(certfile, os.R_OK): + raise IOError('No such certfile found: %s' % certfile) + + self.ssl_context = create_thriftpy_context(server_side=True, + ciphers=ciphers) + self.ssl_context.load_cert_chain(certfile, keyfile=keyfile) + else: + self.ssl_context = None + + def _init_sock(self): + if self.unix_socket: + # try remove the sock file it already exists + _sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + _sock.connect(self.unix_socket) + except (socket.error, OSError) as err: + if err.args[0] == errno.ECONNREFUSED: + os.unlink(self.unix_socket) + else: + _sock = socket.socket(self.socket_family, socket.SOCK_STREAM) + + _sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if hasattr(socket, "SO_REUSEPORT"): + try: + _sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + except socket.error as err: + if err[0] in (errno.ENOPROTOOPT, errno.EINVAL): + pass + else: + raise + _sock.settimeout(None) + self.raw_sock = _sock + + def listen(self): + self._init_sock() + + addr = self.unix_socket or (self.host, self.port) + self.raw_sock.bind(addr) + self.raw_sock.listen(self.backlog) + + @asyncio.coroutine + def accept(self, callback): + server = yield from self.sock_factory( + lambda reader, writer: asyncio.wait_for( + callback(StreamHandler(reader, writer)), + self.client_timeout + ), + sock=self.raw_sock, + ssl=self.ssl_context + ) + return server + + def close(self): + if not self.raw_sock: + return + + try: + self.raw_sock.shutdown(socket.SHUT_RDWR) + self.raw_sock.close() + except (socket.error, OSError): + pass + + +class StreamHandler(object): + def __init__(self, reader, writer): + self.reader, self.writer = reader, writer + + @asyncio.coroutine + def read(self, sz): + try: + buff = yield from self.reader.read(sz) + except socket.error as e: + if (e.args[0] == errno.ECONNRESET and + (sys.platform == 'darwin' or + sys.platform.startswith('freebsd'))): + # freebsd and Mach don't follow POSIX semantic of recv + # and fail with ECONNRESET if peer performed shutdown. + # See corresponding comment and code in TSocket::read() + # in lib/cpp/src/transport/TSocket.cpp. + self.close() + # Trigger the check to raise the END_OF_FILE exception below. + buff = '' + else: + raise + + if len(buff) == 0: + raise TTransportException(type=TTransportException.END_OF_FILE, + message='TSocket read 0 bytes') + return buff + + def write(self, buff): + self.writer.write(buff) + + @asyncio.coroutine + def flush(self): + yield from self.writer.drain() + + def close(self): + try: + self.writer.close() + except (socket.error, OSError): + pass + + @asyncio.coroutine + def open(self): + pass diff --git a/thriftpy/contrib/aio/transport/__init__.py b/thriftpy/contrib/aio/transport/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/thriftpy/contrib/aio/transport/buffered.py b/thriftpy/contrib/aio/transport/buffered.py new file mode 100644 index 0000000..7c8538b --- /dev/null +++ b/thriftpy/contrib/aio/transport/buffered.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +import asyncio +from io import BytesIO + +from thriftpy.transport import TTransportBase, TTransportException + + +@asyncio.coroutine +def readall(read_fn, sz): + buff = b'' + have = 0 + while have < sz: + chunk = yield from read_fn(sz - have) + have += len(chunk) + buff += chunk + + if len(chunk) == 0: + raise TTransportException(TTransportException.END_OF_FILE, + "End of file reading from transport") + + return buff + + +class TAsyncBufferedTransport(TTransportBase): + """Class that wraps another transport and buffers its I/O. + + The implementation uses a (configurable) fixed-size read buffer + but buffers all writes until a flush is performed. + """ + DEFAULT_BUFFER = 4096 + + def __init__(self, trans, buf_size=DEFAULT_BUFFER): + self._trans = trans + self._wbuf = BytesIO() + self._rbuf = BytesIO(b"") + self._buf_size = buf_size + + def is_open(self): + return self._trans.is_open() + + @asyncio.coroutine + def open(self): + return (yield from self._trans.open()) + + def close(self): + return self._trans.close() + + @asyncio.coroutine + def _read(self, sz): + ret = self._rbuf.read(sz) + if len(ret) != 0: + return ret + + buf = yield from self._trans.read(max(sz, self._buf_size)) + self._rbuf = BytesIO(buf) + return self._rbuf.read(sz) + + @asyncio.coroutine + def read(self, sz): + return (yield from readall(self._read, sz)) + + def write(self, buf): + self._wbuf.write(buf) + + @asyncio.coroutine + def flush(self): + out = self._wbuf.getvalue() + # reset wbuf before write/flush to preserve state on underlying failure + self._wbuf = BytesIO() + self._trans.write(out) + yield from self._trans.flush() + + def getvalue(self): + return self._trans.getvalue() + + +class TAsyncBufferedTransportFactory(object): + def get_transport(self, trans): + return TAsyncBufferedTransport(trans) diff --git a/thriftpy/parser/parser.py b/thriftpy/parser/parser.py index 26c6073..882da53 100644 --- a/thriftpy/parser/parser.py +++ b/thriftpy/parser/parser.py @@ -625,7 +625,7 @@ def _add_thrift_meta(key, val): if not hasattr(thrift, '__thrift_meta__'): meta = collections.defaultdict(list) - setattr(thrift, '__thrift_meta__', meta) + setattr(thrift, '__thrift_meta__', meta) else: meta = getattr(thrift, '__thrift_meta__') diff --git a/thriftpy/rpc.py b/thriftpy/rpc.py index 389ad76..710e200 100644 --- a/thriftpy/rpc.py +++ b/thriftpy/rpc.py @@ -5,6 +5,8 @@ import contextlib import warnings +from thriftpy._compat import PY35 + from thriftpy.protocol import TBinaryProtocolFactory from thriftpy.server import TThreadedServer from thriftpy.thrift import TProcessor, TClient @@ -111,3 +113,10 @@ def client_context(service, host="localhost", port=9090, unix_socket=None, finally: transport.close() + + +if PY35: + from thriftpy.contrib.aio.rpc import ( + make_server as make_aio_server, + make_client as make_aio_client + ) diff --git a/tox.ini b/tox.ini index fefb882..c83ff27 100644 --- a/tox.ini +++ b/tox.ini @@ -14,7 +14,13 @@ deps = tornado toro cython + py26: ordereddict + py35,py36: pytest_asyncio [testenv:flake8] deps = flake8 commands = flake8 . + +[testenv:py26] +setenv = + IS_PY26 = 1