From e7f59976eecf0fbecc20f2478863323ad50cac0b Mon Sep 17 00:00:00 2001 From: Lasse Blaauwbroek Date: Thu, 20 Apr 2023 23:58:53 +0200 Subject: [PATCH] Allow reading and writing messages from sockets in `async` mode --- capnp/helpers/capabilityHelper.cpp | 13 +++++ capnp/helpers/capabilityHelper.h | 3 ++ capnp/includes/capnp_cpp.pxd | 4 ++ capnp/lib/capnp.pyx | 58 ++++++++++++++++++++++- examples/async_socket_message_client.py | 54 +++++++++++++++++++++ examples/async_socket_message_server.py | 63 +++++++++++++++++++++++++ test/test_examples.py | 7 +++ 7 files changed, 201 insertions(+), 1 deletion(-) create mode 100644 examples/async_socket_message_client.py create mode 100644 examples/async_socket_message_server.py diff --git a/capnp/helpers/capabilityHelper.cpp b/capnp/helpers/capabilityHelper.cpp index 92c0444a..d41bf1ce 100644 --- a/capnp/helpers/capabilityHelper.cpp +++ b/capnp/helpers/capabilityHelper.cpp @@ -203,6 +203,19 @@ void PyAsyncIoStream::shutdownWrite() { _asyncio_stream_shutdown_write(protocol->obj); } +::kj::Promise> tryReadMessage(kj::AsyncIoStream& stream, capnp::ReaderOptions opts) { + return capnp::tryReadMessage(stream, opts) + .then([](kj::Maybe> maybeReader) -> kj::Promise> { + KJ_IF_MAYBE(reader, maybeReader) { + PyObject* pyreader = make_async_message_reader(kj::mv(*reader)); + check_py_error(); + return kj::heap(pyreader); + } else { + return kj::heap(Py_None); + } + }); +} + void init_capnp_api() { import_capnp__lib__capnp(); } diff --git a/capnp/helpers/capabilityHelper.h b/capnp/helpers/capabilityHelper.h index d41bad9d..d8eca95e 100644 --- a/capnp/helpers/capabilityHelper.h +++ b/capnp/helpers/capabilityHelper.h @@ -2,6 +2,7 @@ #include "capnp/dynamic.h" #include +#include #include #include "Python.h" @@ -141,4 +142,6 @@ inline void rejectVoidDisconnected(kj::PromiseFulfiller& fulfiller, kj::St fulfiller.reject(KJ_EXCEPTION(DISCONNECTED, message)); } +::kj::Promise> tryReadMessage(kj::AsyncIoStream& stream, capnp::ReaderOptions opts); + void init_capnp_api(); diff --git a/capnp/includes/capnp_cpp.pxd b/capnp/includes/capnp_cpp.pxd index 346533c0..3e51e9f3 100644 --- a/capnp/includes/capnp_cpp.pxd +++ b/capnp/includes/capnp_cpp.pxd @@ -564,3 +564,7 @@ cdef extern from "capnp/helpers/capabilityHelper.h": PyAsyncIoStream(PyObject* thisptr) void rejectDisconnected[T](PromiseFulfiller[T]& fulfiller, StringPtr message) void rejectVoidDisconnected(VoidPromiseFulfiller& fulfiller, StringPtr message) + PyPromise tryReadMessage(AsyncIoStream& stream, ReaderOptions opts) + +cdef extern from "capnp/serialize-async.h" namespace " ::capnp": + VoidPromise writeMessage(AsyncIoStream& output, MessageBuilder& builder) diff --git a/capnp/lib/capnp.pyx b/capnp/lib/capnp.pyx index c8c758df..a8ea6f55 100644 --- a/capnp/lib/capnp.pyx +++ b/capnp/lib/capnp.pyx @@ -10,7 +10,8 @@ cimport cython # noqa: E402 from capnp.helpers.helpers cimport init_capnp_api -from capnp.includes.capnp_cpp cimport AsyncIoStream, WaitScope, PyPromise, VoidPromise, EventPort, EventLoop, WaitScope, LowLevelAsyncIoProvider, AsyncIoProvider, newAsyncIoProvider, MonotonicClock, Timer, TimerImpl, systemPreciseMonotonicClock, MILLISECONDS, Canceler, PyAsyncIoStream, PromiseFulfiller, VoidPromiseFulfiller +from capnp.includes.capnp_cpp cimport AsyncIoStream, WaitScope, PyPromise, VoidPromise, EventPort, EventLoop, WaitScope, LowLevelAsyncIoProvider, AsyncIoProvider, newAsyncIoProvider, MonotonicClock, Timer, TimerImpl, systemPreciseMonotonicClock, MILLISECONDS, Canceler, PyAsyncIoStream, PromiseFulfiller, VoidPromiseFulfiller, tryReadMessage, writeMessage +from capnp.includes.schema_cpp cimport (MessageReader,) from cpython cimport array, Py_buffer, PyObject_CheckBuffer, memoryview, buffer from cpython.buffer cimport PyBUF_SIMPLE, PyBUF_WRITABLE @@ -1272,6 +1273,23 @@ cdef class _DynamicStructBuilder: _write_message_to_fd(file.fileno(), self._parent) self._is_written = True + async def write_async(self, _AsyncIoStream stream): + """Async version of of write(). + + This is a shortcut for calling capnp._write_message_to_fd(). This can only be called on the + message's root struct. + + :type file: AsyncIoStream + :param file: The AsyncIoStream to write the message to + + :rtype: void + + :Raises: :exc:`KjException` if this isn't the message's root struct. + """ + self._check_write() + await _VoidPromise()._init(writeMessage(deref(stream.thisptr.get()), deref((<_MessageBuilder>self._parent).thisptr))) + self._is_written = True + def write_packed(self, file): """Writes the struct's containing message to the given file object in packed binary format. @@ -3442,6 +3460,26 @@ class _StructModule(object): reader = _StreamFdMessageReader(file, traversal_limit_in_words, nesting_limit) return reader.get_root(self.schema) + async def read_async(self, _AsyncIoStream stream, traversal_limit_in_words=None, nesting_limit=None): + """Async version of read(). Returns either a message, or None in case of EOF. + + :type file: AsyncIoStream + :param file: A AsyncIoStream + + :type traversal_limit_in_words: int + :param traversal_limit_in_words: Limits how many total words of data are allowed to be traversed. + Is actually a uint64_t, and values can be up to 2^64-1. Default is 8*1024*1024. + + :type nesting_limit: int + :param nesting_limit: Limits how many total words of data are allowed to be traversed. Default is 64. + + :rtype: :class:`_DynamicStructReader`""" + cdef schema_cpp.ReaderOptions opts = make_reader_opts(traversal_limit_in_words, nesting_limit) + reader = await _Promise()._init(tryReadMessage(deref(stream.thisptr.get()), opts)) + if reader is None: + return + return reader.get_root(self.schema) + def read_multiple(self, file, traversal_limit_in_words=None, nesting_limit=None, skip_copy=False): """Returns an iterable, that when traversed will return Readers for messages. @@ -4118,6 +4156,24 @@ cdef class _PackedFdMessageReader(_MessageReader): def __dealloc__(self): del self.thisptr +cdef class _AsyncMessageReader(_MessageReader): + """Read a Cap'n Proto message from a AsyncIoStream class. + + Do not use directly + """ + + def __init__(self): + pass + + cdef Own[MessageReader] reader + cdef _init(self, Own[MessageReader] reader): + self.reader = move(reader) + self.thisptr = self.reader.get() + return self + +cdef api object make_async_message_reader(Own[MessageReader] reader): + return _AsyncMessageReader()._init(move(reader)) + cdef class _MultipleMessageReader: cdef schema_cpp.FdInputStream * stream diff --git a/examples/async_socket_message_client.py b/examples/async_socket_message_client.py new file mode 100644 index 00000000..ef08f338 --- /dev/null +++ b/examples/async_socket_message_client.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 + +import asyncio +import argparse +import capnp + +import addressbook_capnp + + +def parse_args(): + parser = argparse.ArgumentParser( + usage="Connects to the Example thread server \ +at the given address and does some RPCs" + ) + parser.add_argument("host", help="HOST:PORT") + + return parser.parse_args() + + +async def writeAddressBook(stream, bob_id): + addresses = addressbook_capnp.AddressBook.new_message() + people = addresses.init("people", 1) + + bob = people[0] + bob.id = bob_id + bob.name = "Bob" + bob.email = "bob@example.com" + bobPhones = bob.init("phones", 2) + bobPhones[0].number = "555-4567" + bobPhones[0].type = "home" + bobPhones[1].number = "555-7654" + bobPhones[1].type = "work" + bob.employment.unemployed = None + + await addresses.write_async(stream) + + +async def main(host): + host, port = host.split(":") + stream = await capnp.AsyncIoStream.create_connection(host=host, port=port) + + await writeAddressBook(stream, 0) + + message = await addressbook_capnp.AddressBook.read_async(stream) + print(message) + assert message.people[0].name == "Alice" + assert message.people[0].id == 0 + + await writeAddressBook(stream, 1) + + +if __name__ == "__main__": + args = parse_args() + asyncio.run(main(args.host)) diff --git a/examples/async_socket_message_server.py b/examples/async_socket_message_server.py new file mode 100644 index 00000000..e261d7e8 --- /dev/null +++ b/examples/async_socket_message_server.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 + +import argparse +import asyncio + +import capnp +import addressbook_capnp + + +async def writeAddressBook(stream, alice_id): + addresses = addressbook_capnp.AddressBook.new_message() + people = addresses.init("people", 1) + + alice = people[0] + alice.id = alice_id + alice.name = "Alice" + alice.email = "alice@example.com" + alicePhones = alice.init("phones", 1) + alicePhones[0].number = "555-1212" + alicePhones[0].type = "mobile" + alice.employment.school = "MIT" + + await addresses.write_async(stream) + + +async def new_connection(stream): + message = await addressbook_capnp.AddressBook.read_async(stream) + print(message) + assert message.people[0].name == "Bob" + assert message.people[0].id == 0 + + await writeAddressBook(stream, 0) + + message = await addressbook_capnp.AddressBook.read_async(stream) + print(message) + assert message.people[0].name == "Bob" + assert message.people[0].id == 1 + + message = await addressbook_capnp.AddressBook.read_async(stream) + print(message) + assert message is None + + +def parse_args(): + parser = argparse.ArgumentParser( + usage="""Runs the server bound to the\ +given address/port ADDRESS. """ + ) + + parser.add_argument("address", help="ADDRESS:PORT") + + return parser.parse_args() + + +async def main(): + host, port = parse_args().address.split(":") + server = await capnp.AsyncIoStream.create_server(new_connection, host, port) + async with server: + await server.serve_forever() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/test/test_examples.py b/test/test_examples.py index b585563a..40d496b3 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -189,3 +189,10 @@ def test_async_ssl_calculator_example(cleanup): server = "async_ssl_calculator_server.py" client = "async_ssl_calculator_client.py" run_subprocesses(address, server, client, ipv4_force=False) + + +def test_async_socket_message_example(cleanup): + address = "{}:36438".format(hostname) + server = "async_socket_message_server.py" + client = "async_socket_message_client.py" + run_subprocesses(address, server, client)