Skip to content

Commit

Permalink
Allow reading and writing messages from sockets in async mode
Browse files Browse the repository at this point in the history
  • Loading branch information
LasseBlaauwbroek committed Jun 7, 2023
1 parent d53aa24 commit 1ce01e4
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 1 deletion.
14 changes: 13 additions & 1 deletion capnp/helpers/capabilityHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ void PyAsyncIoStream::shutdownWrite() {
_asyncio_stream_shutdown_write(protocol->obj);
}


class TaskToPromiseAdapter {
public:
TaskToPromiseAdapter(kj::PromiseFulfiller<void>& fulfiller,
Expand All @@ -224,6 +223,19 @@ kj::Promise<void> taskToPromise(kj::Own<PyRefCounter> task, PyObject* callback)
return kj::newAdaptedPromise<void, TaskToPromiseAdapter>(kj::mv(task), callback);
}

::kj::Promise<kj::Own<PyRefCounter>> tryReadMessage(kj::AsyncIoStream& stream, capnp::ReaderOptions opts) {
return capnp::tryReadMessage(stream, opts)
.then([](kj::Maybe<kj::Own<capnp::MessageReader>> maybeReader) -> kj::Promise<kj::Own<PyRefCounter>> {
KJ_IF_MAYBE(reader, maybeReader) {
PyObject* pyreader = make_async_message_reader(kj::mv(*reader));
check_py_error();
return kj::heap<PyRefCounter>(pyreader);
} else {
return kj::heap<PyRefCounter>(Py_None);
}
});
}

void init_capnp_api() {
import_capnp__lib__capnp();
}
3 changes: 3 additions & 0 deletions capnp/helpers/capabilityHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "capnp/dynamic.h"
#include <kj/async-io.h>
#include <capnp/serialize-async.h>
#include <stdexcept>
#include "Python.h"

Expand Down Expand Up @@ -147,4 +148,6 @@ inline kj::Exception makeException(kj::StringPtr message) {

kj::Promise<void> taskToPromise(kj::Own<PyRefCounter> coroutine, PyObject* callback);

::kj::Promise<kj::Own<PyRefCounter>> tryReadMessage(kj::AsyncIoStream& stream, capnp::ReaderOptions opts);

void init_capnp_api();
4 changes: 4 additions & 0 deletions capnp/includes/capnp_cpp.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,7 @@ cdef extern from "capnp/helpers/capabilityHelper.h":
void rejectDisconnected[T](PromiseFulfiller[T]& fulfiller, StringPtr message)
void rejectVoidDisconnected(VoidPromiseFulfiller& fulfiller, StringPtr message)
Exception makeException(StringPtr message)
PyPromise tryReadMessage(AsyncIoStream& stream, ReaderOptions opts)

cdef extern from "capnp/serialize-async.h" namespace " ::capnp":
VoidPromise writeMessage(AsyncIoStream& output, MessageBuilder& builder)
57 changes: 57 additions & 0 deletions capnp/lib/capnp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ cimport cython # noqa: E402

from capnp.helpers.helpers cimport init_capnp_api
from capnp.includes.capnp_cpp cimport AsyncIoStream, WaitScope, PyPromise, VoidPromise, EventPort, EventLoop, WaitScope, LowLevelAsyncIoProvider, AsyncIoProvider, newAsyncIoProvider, MonotonicClock, Timer, TimerImpl, systemPreciseMonotonicClock, MILLISECONDS, Canceler, PyAsyncIoStream, PromiseFulfiller, VoidPromiseFulfiller, makeException
from capnp.includes.capnp_cpp cimport AsyncIoStream, WaitScope, PyPromise, VoidPromise, EventPort, EventLoop, WaitScope, LowLevelAsyncIoProvider, AsyncIoProvider, newAsyncIoProvider, MonotonicClock, Timer, TimerImpl, systemPreciseMonotonicClock, MILLISECONDS, Canceler, PyAsyncIoStream, PromiseFulfiller, VoidPromiseFulfiller, tryReadMessage, writeMessage, makeException
from capnp.includes.schema_cpp cimport (MessageReader,)

from cpython cimport array, Py_buffer, PyObject_CheckBuffer, memoryview, buffer
from cpython.buffer cimport PyBUF_SIMPLE, PyBUF_WRITABLE
Expand Down Expand Up @@ -1307,6 +1309,23 @@ cdef class _DynamicStructBuilder:
_write_message_to_fd(file.fileno(), self._parent)
self._is_written = True

async def write_async(self, _AsyncIoStream stream):
"""Async version of of write().
This is a shortcut for calling capnp._write_message_to_fd(). This can only be called on the
message's root struct.
:type file: AsyncIoStream
:param file: The AsyncIoStream to write the message to
:rtype: void
:Raises: :exc:`KjException` if this isn't the message's root struct.
"""
self._check_write()
await _VoidPromise()._init(writeMessage(deref(stream.thisptr.get()), deref((<_MessageBuilder>self._parent).thisptr)))
self._is_written = True

def write_packed(self, file):
"""Writes the struct's containing message to the given file object in packed binary format.
Expand Down Expand Up @@ -3477,6 +3496,26 @@ class _StructModule(object):
reader = _StreamFdMessageReader(file, traversal_limit_in_words, nesting_limit)
return reader.get_root(self.schema)

async def read_async(self, _AsyncIoStream stream, traversal_limit_in_words=None, nesting_limit=None):
"""Async version of read(). Returns either a message, or None in case of EOF.
:type file: AsyncIoStream
:param file: A AsyncIoStream
:type traversal_limit_in_words: int
:param traversal_limit_in_words: Limits how many total words of data are allowed to be traversed.
Is actually a uint64_t, and values can be up to 2^64-1. Default is 8*1024*1024.
:type nesting_limit: int
:param nesting_limit: Limits how many total words of data are allowed to be traversed. Default is 64.
:rtype: :class:`_DynamicStructReader`"""
cdef schema_cpp.ReaderOptions opts = make_reader_opts(traversal_limit_in_words, nesting_limit)
reader = await _Promise()._init(tryReadMessage(deref(stream.thisptr.get()), opts))
if reader is None:
return
return reader.get_root(self.schema)

def read_multiple(self, file, traversal_limit_in_words=None, nesting_limit=None, skip_copy=False):
"""Returns an iterable, that when traversed will return Readers for messages.
Expand Down Expand Up @@ -4153,6 +4192,24 @@ cdef class _PackedFdMessageReader(_MessageReader):
def __dealloc__(self):
del self.thisptr

cdef class _AsyncMessageReader(_MessageReader):
"""Read a Cap'n Proto message from a AsyncIoStream class.
Do not use directly
"""

def __init__(self):
pass

cdef Own[MessageReader] reader
cdef _init(self, Own[MessageReader] reader):
self.reader = move(reader)
self.thisptr = self.reader.get()
return self

cdef api object make_async_message_reader(Own[MessageReader] reader):
return _AsyncMessageReader()._init(move(reader))


cdef class _MultipleMessageReader:
cdef schema_cpp.FdInputStream * stream
Expand Down
54 changes: 54 additions & 0 deletions examples/async_socket_message_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/usr/bin/env python3

import asyncio
import argparse
import capnp

import addressbook_capnp


def parse_args():
parser = argparse.ArgumentParser(
usage="Connects to the Example thread server \
at the given address and does some RPCs"
)
parser.add_argument("host", help="HOST:PORT")

return parser.parse_args()


async def writeAddressBook(stream, bob_id):
addresses = addressbook_capnp.AddressBook.new_message()
people = addresses.init("people", 1)

bob = people[0]
bob.id = bob_id
bob.name = "Bob"
bob.email = "[email protected]"
bobPhones = bob.init("phones", 2)
bobPhones[0].number = "555-4567"
bobPhones[0].type = "home"
bobPhones[1].number = "555-7654"
bobPhones[1].type = "work"
bob.employment.unemployed = None

await addresses.write_async(stream)


async def main(host):
host, port = host.split(":")
stream = await capnp.AsyncIoStream.create_connection(host=host, port=port)

await writeAddressBook(stream, 0)

message = await addressbook_capnp.AddressBook.read_async(stream)
print(message)
assert message.people[0].name == "Alice"
assert message.people[0].id == 0

await writeAddressBook(stream, 1)


if __name__ == "__main__":
args = parse_args()
asyncio.run(main(args.host))
63 changes: 63 additions & 0 deletions examples/async_socket_message_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/usr/bin/env python3

import argparse
import asyncio

import capnp
import addressbook_capnp


async def writeAddressBook(stream, alice_id):
addresses = addressbook_capnp.AddressBook.new_message()
people = addresses.init("people", 1)

alice = people[0]
alice.id = alice_id
alice.name = "Alice"
alice.email = "[email protected]"
alicePhones = alice.init("phones", 1)
alicePhones[0].number = "555-1212"
alicePhones[0].type = "mobile"
alice.employment.school = "MIT"

await addresses.write_async(stream)


async def new_connection(stream):
message = await addressbook_capnp.AddressBook.read_async(stream)
print(message)
assert message.people[0].name == "Bob"
assert message.people[0].id == 0

await writeAddressBook(stream, 0)

message = await addressbook_capnp.AddressBook.read_async(stream)
print(message)
assert message.people[0].name == "Bob"
assert message.people[0].id == 1

message = await addressbook_capnp.AddressBook.read_async(stream)
print(message)
assert message is None


def parse_args():
parser = argparse.ArgumentParser(
usage="""Runs the server bound to the\
given address/port ADDRESS. """
)

parser.add_argument("address", help="ADDRESS:PORT")

return parser.parse_args()


async def main():
host, port = parse_args().address.split(":")
server = await capnp.AsyncIoStream.create_server(new_connection, host, port)
async with server:
await server.serve_forever()


if __name__ == "__main__":
asyncio.run(main())
7 changes: 7 additions & 0 deletions test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,10 @@ def test_async_ssl_calculator_example(cleanup):
server = "async_ssl_calculator_server.py"
client = "async_ssl_calculator_client.py"
run_subprocesses(address, server, client, ipv4_force=False)


def test_async_socket_message_example(cleanup):
address = "{}:36438".format(hostname)
server = "async_socket_message_server.py"
client = "async_socket_message_client.py"
run_subprocesses(address, server, client)

0 comments on commit 1ce01e4

Please sign in to comment.