Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow reading and writing messages from sockets in async mode #313

Merged
merged 1 commit into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion capnp/helpers/capabilityHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ void PyAsyncIoStream::shutdownWrite() {
_asyncio_stream_shutdown_write(protocol->obj);
}


class TaskToPromiseAdapter {
public:
TaskToPromiseAdapter(kj::PromiseFulfiller<void>& fulfiller,
Expand All @@ -224,6 +223,19 @@ kj::Promise<void> taskToPromise(kj::Own<PyRefCounter> task, PyObject* callback)
return kj::newAdaptedPromise<void, TaskToPromiseAdapter>(kj::mv(task), callback);
}

::kj::Promise<kj::Own<PyRefCounter>> tryReadMessage(kj::AsyncIoStream& stream, capnp::ReaderOptions opts) {
return capnp::tryReadMessage(stream, opts)
.then([](kj::Maybe<kj::Own<capnp::MessageReader>> maybeReader) -> kj::Promise<kj::Own<PyRefCounter>> {
KJ_IF_MAYBE(reader, maybeReader) {
PyObject* pyreader = make_async_message_reader(kj::mv(*reader));
check_py_error();
return kj::heap<PyRefCounter>(pyreader);
} else {
return kj::heap<PyRefCounter>(Py_None);
}
});
}

void init_capnp_api() {
import_capnp__lib__capnp();
}
3 changes: 3 additions & 0 deletions capnp/helpers/capabilityHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "capnp/dynamic.h"
#include <kj/async-io.h>
#include <capnp/serialize-async.h>
#include <stdexcept>
#include "Python.h"

Expand Down Expand Up @@ -147,4 +148,6 @@ inline kj::Exception makeException(kj::StringPtr message) {

kj::Promise<void> taskToPromise(kj::Own<PyRefCounter> coroutine, PyObject* callback);

::kj::Promise<kj::Own<PyRefCounter>> tryReadMessage(kj::AsyncIoStream& stream, capnp::ReaderOptions opts);

void init_capnp_api();
4 changes: 4 additions & 0 deletions capnp/includes/capnp_cpp.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,7 @@ cdef extern from "capnp/helpers/capabilityHelper.h":
void rejectDisconnected[T](PromiseFulfiller[T]& fulfiller, StringPtr message)
void rejectVoidDisconnected(VoidPromiseFulfiller& fulfiller, StringPtr message)
Exception makeException(StringPtr message)
PyPromise tryReadMessage(AsyncIoStream& stream, ReaderOptions opts)

cdef extern from "capnp/serialize-async.h" namespace " ::capnp":
VoidPromise writeMessage(AsyncIoStream& output, MessageBuilder& builder)
57 changes: 57 additions & 0 deletions capnp/lib/capnp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ cimport cython # noqa: E402

from capnp.helpers.helpers cimport init_capnp_api
from capnp.includes.capnp_cpp cimport AsyncIoStream, WaitScope, PyPromise, VoidPromise, EventPort, EventLoop, WaitScope, LowLevelAsyncIoProvider, AsyncIoProvider, newAsyncIoProvider, MonotonicClock, Timer, TimerImpl, systemPreciseMonotonicClock, MILLISECONDS, Canceler, PyAsyncIoStream, PromiseFulfiller, VoidPromiseFulfiller, makeException
from capnp.includes.capnp_cpp cimport AsyncIoStream, WaitScope, PyPromise, VoidPromise, EventPort, EventLoop, WaitScope, LowLevelAsyncIoProvider, AsyncIoProvider, newAsyncIoProvider, MonotonicClock, Timer, TimerImpl, systemPreciseMonotonicClock, MILLISECONDS, Canceler, PyAsyncIoStream, PromiseFulfiller, VoidPromiseFulfiller, tryReadMessage, writeMessage, makeException
from capnp.includes.schema_cpp cimport (MessageReader,)

from cpython cimport array, Py_buffer, PyObject_CheckBuffer, memoryview, buffer
from cpython.buffer cimport PyBUF_SIMPLE, PyBUF_WRITABLE
Expand Down Expand Up @@ -1307,6 +1309,23 @@ cdef class _DynamicStructBuilder:
_write_message_to_fd(file.fileno(), self._parent)
self._is_written = True

async def write_async(self, _AsyncIoStream stream):
"""Async version of of write().

This is a shortcut for calling capnp._write_message_to_fd(). This can only be called on the
message's root struct.

:type file: AsyncIoStream
:param file: The AsyncIoStream to write the message to

:rtype: void

:Raises: :exc:`KjException` if this isn't the message's root struct.
"""
self._check_write()
await _VoidPromise()._init(writeMessage(deref(stream.thisptr.get()), deref((<_MessageBuilder>self._parent).thisptr)))
self._is_written = True

def write_packed(self, file):
"""Writes the struct's containing message to the given file object in packed binary format.

Expand Down Expand Up @@ -3477,6 +3496,26 @@ class _StructModule(object):
reader = _StreamFdMessageReader(file, traversal_limit_in_words, nesting_limit)
return reader.get_root(self.schema)

async def read_async(self, _AsyncIoStream stream, traversal_limit_in_words=None, nesting_limit=None):
"""Async version of read(). Returns either a message, or None in case of EOF.

:type file: AsyncIoStream
:param file: A AsyncIoStream

:type traversal_limit_in_words: int
:param traversal_limit_in_words: Limits how many total words of data are allowed to be traversed.
Is actually a uint64_t, and values can be up to 2^64-1. Default is 8*1024*1024.

:type nesting_limit: int
:param nesting_limit: Limits how many total words of data are allowed to be traversed. Default is 64.

:rtype: :class:`_DynamicStructReader`"""
cdef schema_cpp.ReaderOptions opts = make_reader_opts(traversal_limit_in_words, nesting_limit)
reader = await _Promise()._init(tryReadMessage(deref(stream.thisptr.get()), opts))
if reader is None:
return
return reader.get_root(self.schema)

def read_multiple(self, file, traversal_limit_in_words=None, nesting_limit=None, skip_copy=False):
"""Returns an iterable, that when traversed will return Readers for messages.

Expand Down Expand Up @@ -4153,6 +4192,24 @@ cdef class _PackedFdMessageReader(_MessageReader):
def __dealloc__(self):
del self.thisptr

cdef class _AsyncMessageReader(_MessageReader):
"""Read a Cap'n Proto message from a AsyncIoStream class.

Do not use directly
"""

def __init__(self):
pass

cdef Own[MessageReader] reader
cdef _init(self, Own[MessageReader] reader):
self.reader = move(reader)
self.thisptr = self.reader.get()
return self

cdef api object make_async_message_reader(Own[MessageReader] reader):
return _AsyncMessageReader()._init(move(reader))


cdef class _MultipleMessageReader:
cdef schema_cpp.FdInputStream * stream
Expand Down
54 changes: 54 additions & 0 deletions examples/async_socket_message_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/usr/bin/env python3

import asyncio
import argparse
import capnp

import addressbook_capnp


def parse_args():
parser = argparse.ArgumentParser(
usage="Connects to the Example thread server \
at the given address and does some RPCs"
)
parser.add_argument("host", help="HOST:PORT")

return parser.parse_args()


async def writeAddressBook(stream, bob_id):
addresses = addressbook_capnp.AddressBook.new_message()
people = addresses.init("people", 1)

bob = people[0]
bob.id = bob_id
bob.name = "Bob"
bob.email = "[email protected]"
bobPhones = bob.init("phones", 2)
bobPhones[0].number = "555-4567"
bobPhones[0].type = "home"
bobPhones[1].number = "555-7654"
bobPhones[1].type = "work"
bob.employment.unemployed = None

await addresses.write_async(stream)


async def main(host):
host, port = host.split(":")
stream = await capnp.AsyncIoStream.create_connection(host=host, port=port)

await writeAddressBook(stream, 0)

message = await addressbook_capnp.AddressBook.read_async(stream)
print(message)
assert message.people[0].name == "Alice"
assert message.people[0].id == 0

await writeAddressBook(stream, 1)


if __name__ == "__main__":
args = parse_args()
asyncio.run(main(args.host))
63 changes: 63 additions & 0 deletions examples/async_socket_message_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/usr/bin/env python3

import argparse
import asyncio

import capnp
import addressbook_capnp


async def writeAddressBook(stream, alice_id):
addresses = addressbook_capnp.AddressBook.new_message()
people = addresses.init("people", 1)

alice = people[0]
alice.id = alice_id
alice.name = "Alice"
alice.email = "[email protected]"
alicePhones = alice.init("phones", 1)
alicePhones[0].number = "555-1212"
alicePhones[0].type = "mobile"
alice.employment.school = "MIT"

await addresses.write_async(stream)


async def new_connection(stream):
message = await addressbook_capnp.AddressBook.read_async(stream)
print(message)
assert message.people[0].name == "Bob"
assert message.people[0].id == 0

await writeAddressBook(stream, 0)

message = await addressbook_capnp.AddressBook.read_async(stream)
print(message)
assert message.people[0].name == "Bob"
assert message.people[0].id == 1

message = await addressbook_capnp.AddressBook.read_async(stream)
print(message)
assert message is None


def parse_args():
parser = argparse.ArgumentParser(
usage="""Runs the server bound to the\
given address/port ADDRESS. """
)

parser.add_argument("address", help="ADDRESS:PORT")

return parser.parse_args()


async def main():
host, port = parse_args().address.split(":")
server = await capnp.AsyncIoStream.create_server(new_connection, host, port)
async with server:
await server.serve_forever()


if __name__ == "__main__":
asyncio.run(main())
7 changes: 7 additions & 0 deletions test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,10 @@ def test_async_ssl_calculator_example(cleanup):
server = "async_ssl_calculator_server.py"
client = "async_ssl_calculator_client.py"
run_subprocesses(address, server, client, ipv4_force=False)


def test_async_socket_message_example(cleanup):
address = "{}:36438".format(hostname)
server = "async_socket_message_server.py"
client = "async_socket_message_client.py"
run_subprocesses(address, server, client)