From da4018fb01c91c267f01149ce6012e3f11439aef Mon Sep 17 00:00:00 2001 From: Dominik Andreas Date: Mon, 22 Jan 2024 14:33:00 +0000 Subject: [PATCH] added binary support in dictionaries via base64 encoding --- capnp/lib/capnp.pyx | 43 ++++++++++++++++++++------------ test/test_blob_to_dict_base64.py | 21 ++++++++++++++++ 2 files changed, 48 insertions(+), 16 deletions(-) create mode 100644 test/test_blob_to_dict_base64.py diff --git a/capnp/lib/capnp.pyx b/capnp/lib/capnp.pyx index fdc19d8..f3d7c13 100644 --- a/capnp/lib/capnp.pyx +++ b/capnp/lib/capnp.pyx @@ -26,6 +26,7 @@ import array import asyncio import collections as _collections import contextlib +import base64 import enum as _enum import inspect as _inspect import os as _os @@ -957,17 +958,17 @@ cdef _DynamicStructBuilder temp_msg_b cdef _DynamicStructReader temp_msg_r -cdef _to_dict(msg, bint verbose, bint ordered): +cdef _to_dict(msg, bint verbose, bint ordered, bint encode_bytes_as_base64=False): msg_type = type(msg) if msg_type is _DynamicListBuilder: temp_list_b = msg - return [_to_dict(temp_list_b._get(i), verbose, ordered) for i in range(len(msg))] + return [_to_dict(temp_list_b._get(i), verbose, ordered, encode_bytes_as_base64) for i in range(len(msg))] elif msg_type is _DynamicListReader: temp_list_r = msg - return [_to_dict(temp_list_r._get(i), verbose, ordered) for i in range(len(msg))] + return [_to_dict(temp_list_r._get(i), verbose, ordered, encode_bytes_as_base64) for i in range(len(msg))] elif msg_type is _DynamicResizableListBuilder: temp_list_rb = msg - return [_to_dict(temp_list_rb._get(i), verbose, ordered) for i in range(len(msg))] + return [_to_dict(temp_list_rb._get(i), verbose, ordered, encode_bytes_as_base64) for i in range(len(msg))] if msg_type is _DynamicStructBuilder or isinstance(msg, _Request): temp_msg_b = msg @@ -977,13 +978,13 @@ cdef _to_dict(msg, bint verbose, bint ordered): ret = {} try: which = temp_msg_b.which() - ret[which] = _to_dict(temp_msg_b._get(which), verbose, ordered) + ret[which] = _to_dict(temp_msg_b._get(which), verbose, ordered, encode_bytes_as_base64) except KjException: pass for field in temp_msg_b.schema.non_union_fields: if verbose or temp_msg_b._has(field): - ret[field] = _to_dict(temp_msg_b._get(field), verbose, ordered) + ret[field] = _to_dict(temp_msg_b._get(field), verbose, ordered, encode_bytes_as_base64) return ret elif msg_type is _DynamicStructReader or isinstance(msg, _Response): @@ -994,13 +995,13 @@ cdef _to_dict(msg, bint verbose, bint ordered): ret = {} try: which = temp_msg_r.which() - ret[which] = _to_dict(temp_msg_r._get(which), verbose, ordered) + ret[which] = _to_dict(temp_msg_r._get(which), verbose, ordered, encode_bytes_as_base64) except KjException: pass for field in temp_msg_r.schema.non_union_fields: if verbose or temp_msg_r._has(field): - ret[field] = _to_dict(temp_msg_r._get(field), verbose, ordered) + ret[field] = _to_dict(temp_msg_r._get(field), verbose, ordered, encode_bytes_as_base64) return ret @@ -1010,6 +1011,10 @@ cdef _to_dict(msg, bint verbose, bint ordered): if msg_type is _DynamicEnum: return str(msg) + if encode_bytes_as_base64 and msg_type is bytes: + # encode the message as base64 and return utf-8 string + return base64.b64encode(msg).decode('utf-8') + return msg @@ -1220,8 +1225,8 @@ cdef class _DynamicStructReader: def __repr__(self): return '<%s reader %s>' % (self.schema.node.displayName, strStructReader(self.thisptr).cStr()) - def to_dict(self, verbose=False, ordered=False): - return _to_dict(self, verbose, ordered) + def to_dict(self, verbose=False, ordered=False, encode_bytes_as_base64=False): + return _to_dict(self, verbose, ordered, encode_bytes_as_base64) cpdef as_builder(self, num_first_segment_words=None): """A method for casting this Reader to a Builder @@ -1598,12 +1603,18 @@ cdef class _DynamicStructBuilder: def __repr__(self): return '<%s builder %s>' % (self.schema.node.displayName, strStructBuilder(self.thisptr).cStr()) - def to_dict(self, verbose=False, ordered=False): - return _to_dict(self, verbose, ordered) + def to_dict(self, verbose=False, ordered=False, encode_bytes_as_base64=False): + return _to_dict(self, verbose, ordered, encode_bytes_as_base64) def from_dict(self, dict d): for key, val in d.iteritems(): if key != 'which': + field = self.schema.fields.get(key) + if isinstance(val, str): + dtype = field.proto.slot.type.which() + if dtype == "data": + # decode bytes from utf-8 base64 encoding + val = base64.b64decode(val) try: self._set(key, val) except Exception as e: @@ -1683,8 +1694,8 @@ cdef class _DynamicStructPipeline: # def __repr__(self): # return '<%s reader %s>' % (self.schema.node.displayName, strStructReader(self.thisptr).cStr()) - def to_dict(self, verbose=False, ordered=False): - return _to_dict(self, verbose, ordered) + def to_dict(self, verbose=False, ordered=False, encode_bytes_as_base64=False): + return _to_dict(self, verbose, ordered, encode_bytes_as_base64) cdef class _DynamicOrphan: @@ -2065,8 +2076,8 @@ cdef class _RemotePromise: def __dir__(self): return list(set(self.schema.fieldnames + tuple(dir(self.__class__)))) - def to_dict(self, verbose=False, ordered=False): - return _to_dict(self, verbose, ordered) + def to_dict(self, verbose=False, ordered=False, encode_bytes_as_base64=False): + return _to_dict(self, verbose, ordered, encode_bytes_as_base64) cpdef cancel(self) except +reraise_kj_exception: self.thisptr = Own[RemotePromise]() diff --git a/test/test_blob_to_dict_base64.py b/test/test_blob_to_dict_base64.py new file mode 100644 index 0000000..5e70f7e --- /dev/null +++ b/test/test_blob_to_dict_base64.py @@ -0,0 +1,21 @@ +import os +import capnp +import base64 +import pytest + +this_dir = os.path.dirname(__file__) + + +@pytest.fixture +def blob_schema(): + return capnp.load(os.path.join(this_dir, "blob_test.capnp")) + + +def test_blob_to_dict(blob_schema): + blob_value = b"hello world" + blob = blob_schema.BlobTest(blob=blob_value) + blob_dict = blob.to_dict(encode_bytes_as_base64=True) + assert base64.b64decode(blob_dict["blob"]) == blob_value + msg = blob_schema.BlobTest.new_message() + msg.from_dict(blob_dict) + assert blob.blob == blob_value