Skip to content

Commit

Permalink
add 'strict_decode' to cybin protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
aisk committed Nov 14, 2023
1 parent 93556bd commit 5e64ee3
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 15 deletions.
8 changes: 8 additions & 0 deletions tests/test_protocol_cybinary.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,14 @@ def test_read_binary():
b, TType.STRING, decode_response=False)


def test_strict_decode():
bs = TCyMemoryBuffer(b"\x00\x00\x00\x0c\x00" # there is a redundant '\x00'
b"\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c")
with pytest.raises(UnicodeDecodeError):
proto.read_val(bs, TType.STRING, decode_response=True,
strict_decode=True)


def test_write_message_begin():
trans = TCyMemoryBuffer()
b = proto.TCyBinaryProtocol(trans)
Expand Down
41 changes: 26 additions & 15 deletions thriftpy2/protocol/cybin/cybin.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ cdef inline write_dict(CyTransportBase buf, object val, spec):
c_write_val(buf, v_type, v, v_spec)


cdef inline read_struct(CyTransportBase buf, obj, decode_response=True):
cdef inline read_struct(CyTransportBase buf, obj, decode_response=True,
strict_decode=False):
cdef dict field_specs = obj.thrift_spec
cdef int fid
cdef TType field_type, ttype
Expand Down Expand Up @@ -199,7 +200,8 @@ cdef inline read_struct(CyTransportBase buf, obj, decode_response=True):
else:
spec = field_spec[2]

setattr(obj, name, c_read_val(buf, ttype, spec, decode_response))
setattr(obj, name, c_read_val(buf, ttype, spec, decode_response,
strict_decode))

return obj

Expand Down Expand Up @@ -251,16 +253,19 @@ cdef inline c_read_binary(CyTransportBase buf, int32_t size):
return py_data


cdef inline c_read_string(CyTransportBase buf, int32_t size):
cdef inline c_read_string(CyTransportBase buf, int32_t size,
strict_decode=False):
py_data = c_read_binary(buf, size)
try:
return (<char *>py_data)[:size].decode("utf-8")
except: # noqa
if strict_decode:
raise
return py_data


cdef c_read_val(CyTransportBase buf, TType ttype, spec=None,
decode_response=True):
decode_response=True, strict_decode=False):
cdef int size
cdef int64_t n
cdef TType v_type, k_type, orig_type, orig_key_type
Expand Down Expand Up @@ -291,7 +296,7 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None,
elif ttype == T_STRING:
size = read_i32(buf)
if decode_response:
return c_read_string(buf, size)
return c_read_string(buf, size, strict_decode)
else:
return c_read_binary(buf, size)

Expand All @@ -311,7 +316,7 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None,
skip(buf, orig_type)
return []

return [c_read_val(buf, v_type, v_spec, decode_response)
return [c_read_val(buf, v_type, v_spec, decode_response, strict_decode)
for _ in range(size)]

elif ttype == T_MAP:
Expand Down Expand Up @@ -345,13 +350,13 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None,
return {}

return {
c_read_val(buf, k_type, k_spec, decode_response):
c_read_val(buf, v_type, v_spec, decode_response)
c_read_val(buf, k_type, k_spec, decode_response, strict_decode):
c_read_val(buf, v_type, v_spec, decode_response, strict_decode)
for _ in range(size)
}

elif ttype == T_STRUCT:
return read_struct(buf, spec(), decode_response)
return read_struct(buf, spec(), decode_response, strict_decode)


cdef c_write_val(CyTransportBase buf, TType ttype, val, spec=None):
Expand Down Expand Up @@ -432,8 +437,9 @@ cpdef skip(CyTransportBase buf, TType ttype):
skip(buf, f_type)


def read_val(CyTransportBase buf, TType ttype, decode_response=True):
return c_read_val(buf, ttype, None, decode_response)
def read_val(CyTransportBase buf, TType ttype, decode_response=True,
strict_decode=False):
return c_read_val(buf, ttype, None, decode_response, strict_decode)


def write_val(CyTransportBase buf, TType ttype, val, spec=None):
Expand All @@ -445,13 +451,15 @@ cdef class TCyBinaryProtocol(object):
cdef public bool strict_read
cdef public bool strict_write
cdef public bool decode_response
cdef public bool strict_decode

def __init__(self, trans, strict_read=True, strict_write=True,
decode_response=True):
decode_response=True, strict_decode=False):
self.trans = trans
self.strict_read = strict_read
self.strict_write = strict_write
self.decode_response = decode_response
self.strict_decode = strict_decode

def skip(self, ttype):
skip(self.trans, <TType>(ttype))
Expand Down Expand Up @@ -498,7 +506,8 @@ cdef class TCyBinaryProtocol(object):

def read_struct(self, obj):
try:
return read_struct(self.trans, obj, self.decode_response)
return read_struct(self.trans, obj, self.decode_response,
self.strict_decode)
except Exception:
self.trans.clean()
raise
Expand All @@ -513,11 +522,13 @@ cdef class TCyBinaryProtocol(object):

class TCyBinaryProtocolFactory(object):
def __init__(self, strict_read=True, strict_write=True,
decode_response=True):
decode_response=True, strict_decode=False):
self.strict_read = strict_read
self.strict_write = strict_write
self.decode_response = decode_response
self.strict_decode = strict_decode

def get_protocol(self, trans):
return TCyBinaryProtocol(
trans, self.strict_read, self.strict_write, self.decode_response)
trans, self.strict_read, self.strict_write, self.decode_response,
self.strict_decode)

0 comments on commit 5e64ee3

Please sign in to comment.