diff --git a/docs/reference.rst b/docs/reference.rst index 06d9034..f7b1847 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -218,6 +218,34 @@ not a key exists depends on the version of libmemcached and memcached used. + .. method:: serialize(value) -> bytestring, flag + + Serialize a Python value to bytes *bytestring* and an integer *flag* field + for storage in memcached. The default implementation has special cases + for bytes, ints/longs, and bools, and falls back to pickle for all other + objects. Override this method to use a custom serialization format, or + otherwise modify the behavior. + + *flag* is exposed by libmemcached. In this context, it adds flexibility + in terms of encoding schemes: for example, objects *a* and *b* of + different types may coincidentally encode to the same *bytestring*, + just so long as they encode with different values of *flag*. If distinct + values always encode to different byte strings (for example, when + serializing all values with pickle), *flag* can simply be set to a + constant. + + .. method:: deserialize(bytestring, flag) -> value + + Deserialize *bytestring*, stored with *flag*, back to a Python object. + Override this method (in concert with ``serialize``) to use a custom + serialization format, or otherwise modify the behavior. + + Raise ``CacheMiss`` in order to simulate a cache miss for the relevant + key, i.e., ``get`` will return None and ``get_multi`` will omit the key + from the returned mapping. This can be used to recover gracefully from + version skew (e.g., retrieving a value that was pickled by a different, + incompatible code version). + .. data:: behaviors The behaviors used by the underlying libmemcached object. See diff --git a/src/_pylibmcmodule.c b/src/_pylibmcmodule.c index a7b3291..53dac11 100644 --- a/src/_pylibmcmodule.c +++ b/src/_pylibmcmodule.c @@ -84,6 +84,9 @@ #define PyObject_Bytes PyObject_Str #endif +/* Cache the values of {cP,p}ickle.{load,dump}s */ +static PyObject *_PylibMC_pickle_loads = NULL; +static PyObject *_PylibMC_pickle_dumps = NULL; /* {{{ Type methods */ static PylibMC_Client *PylibMC_ClientType_new(PyTypeObject *type, @@ -511,11 +514,9 @@ static void _PylibMC_cleanup_str_key_mapping(PyObject *key_str_map) { } /* }}} */ -static PyObject *_PylibMC_parse_memcached_value(char *value, size_t size, - uint32_t flags) { +static PyObject *_PylibMC_parse_memcached_value(PylibMC_Client *self, + char *value, size_t size, uint32_t flags) { PyObject *retval = NULL; - PyObject *tmp = NULL; - uint32_t dtype = flags & PYLIBMC_FLAG_TYPES; #if USE_ZLIB PyObject *inflated = NULL; @@ -571,52 +572,77 @@ static PyObject *_PylibMC_parse_memcached_value(char *value, size_t size, } #endif +#if PY_MAJOR_VERSION >= 3 + retval = PyObject_CallMethod((PyObject *)self, "deserialize", "y#I", value, size, (unsigned int) flags); +#else + retval = PyObject_CallMethod((PyObject *)self, "deserialize", "s#I", value, size, (unsigned int) flags); +#endif + +#if USE_ZLIB + Py_XDECREF(inflated); +#endif + + + return retval; +} + +PyObject *PylibMC_Client_deserialize(PylibMC_Client *self, PyObject *args) { + PyObject *retval = NULL; + + PyObject *value; + unsigned int flags; + if (!PyArg_ParseTuple(args, "OI", &value, &flags)) { + return NULL; + } + uint32_t dtype = ((uint32_t) flags) & PYLIBMC_FLAG_TYPES; + switch (dtype) { case PYLIBMC_FLAG_PICKLE: - retval = _PylibMC_Unpickle(value, size); + retval = _PylibMC_Unpickle(value); break; case PYLIBMC_FLAG_INTEGER: case PYLIBMC_FLAG_LONG: case PYLIBMC_FLAG_BOOL: - /* PyInt_FromString doesn't take a length param and we're - not NULL-terminated, so we'll have to make an - intermediate Python string out of it */ - tmp = PyBytes_FromStringAndSize(value, size); - if(tmp == NULL) { - goto cleanup; - } - retval = PyLong_FromString(PyBytes_AS_STRING(tmp), NULL, 10); - if(retval != NULL && dtype == PYLIBMC_FLAG_BOOL) { - Py_DECREF(tmp); - tmp = retval; - retval = PyBool_FromLong(PyLong_AS_LONG(tmp)); + retval = PyLong_FromString(PyBytes_AS_STRING(value), NULL, 10); + if (retval != NULL && dtype == PYLIBMC_FLAG_BOOL) { + PyObject *bool_retval = PyBool_FromLong(PyLong_AS_LONG(retval)); + Py_DECREF(retval); + retval = bool_retval; } break; case PYLIBMC_FLAG_NONE: - retval = PyBytes_FromStringAndSize(value, (Py_ssize_t)size); + /* acquire an additional reference for parity */ + retval = value; + Py_INCREF(retval); break; default: PyErr_Format(PylibMCExc_Error, - "unknown memcached key flags %u", flags); + "unknown memcached key flags %u", dtype); } -cleanup: - -#if USE_ZLIB - Py_XDECREF(inflated); -#endif - - Py_XDECREF(tmp); - return retval; } -static PyObject *_PylibMC_parse_memcached_result(memcached_result_st *res) { - return _PylibMC_parse_memcached_value((char *)memcached_result_value(res), +static PyObject *_PylibMC_parse_memcached_result(PylibMC_Client *self, memcached_result_st *res) { + return _PylibMC_parse_memcached_value( + self, + (char *)memcached_result_value(res), memcached_result_length(res), memcached_result_flags(res)); } +/* Helper to call after _PylibMC_parse_memcached_value; + determines whether the deserialized value should be ignored + and treated as a miss. +*/ +static int _PylibMC_cache_miss_simulated(PyObject *r) { + if (r == NULL && PyErr_Occurred() && PyErr_ExceptionMatches(PylibMCExc_CacheMiss)) { + PyErr_Clear(); + return 1; + } + return 0; +} + static PyObject *PylibMC_Client_get(PylibMC_Client *self, PyObject *arg) { char *mc_val; size_t val_size; @@ -638,8 +664,13 @@ static PyObject *PylibMC_Client_get(PylibMC_Client *self, PyObject *arg) { Py_DECREF(arg); if (mc_val != NULL) { - PyObject *r = _PylibMC_parse_memcached_value(mc_val, val_size, flags); + PyObject *r = _PylibMC_parse_memcached_value(self, mc_val, val_size, flags); free(mc_val); + if (_PylibMC_cache_miss_simulated(r)) { + /* Since python-memcache returns None when the key doesn't exist, + * so shall we. */ + Py_RETURN_NONE; + } return r; } else if (error == MEMCACHED_SUCCESS) { /* This happens for empty values, and so we fake an empty string. */ @@ -687,25 +718,37 @@ static PyObject *PylibMC_Client_gets(PylibMC_Client *self, PyObject *arg) { Py_END_ALLOW_THREADS; + int miss = 0; + int fail = 0; if (rc == MEMCACHED_SUCCESS && res != NULL) { - ret = Py_BuildValue("(NL)", - _PylibMC_parse_memcached_result(res), - memcached_result_cas(res)); + PyObject *val = _PylibMC_parse_memcached_result(self, res); + if (_PylibMC_cache_miss_simulated(val)) { + miss = 1; + } else { + ret = Py_BuildValue("(NL)", + val, + memcached_result_cas(res)); + } /* we have to fetch the last result from the mget cursor */ if (NULL != memcached_fetch_result(self->mc, NULL, &rc)) { memcached_quit(self->mc); Py_DECREF(ret); ret = NULL; + fail = 1; PyErr_SetString(PyExc_RuntimeError, "fetch not done"); } } else if (rc == MEMCACHED_END || rc == MEMCACHED_NOTFOUND) { - /* Key not found => (None, None) */ - ret = Py_BuildValue("(OO)", Py_None, Py_None); + miss = 1; } else { ret = PylibMC_ErrFromMemcached(self, "memcached_gets", rc); } + if (miss && !fail) { + /* Key not found => (None, None) */ + ret = Py_BuildValue("(OO)", Py_None, Py_None); + } + if (res != NULL) { memcached_result_free(res); } @@ -781,7 +824,7 @@ static PyObject *_PylibMC_RunSetCommandSingle(PylibMC_Client *self, */ key = PyBytes_FromStringAndSize(key_raw, keylen); - success = _PylibMC_SerializeValue(key, NULL, value, time, &serialized); + success = _PylibMC_SerializeValue(self, key, NULL, value, time, &serialized); if (!success) goto cleanup; @@ -865,7 +908,7 @@ static PyObject *_PylibMC_RunSetCommandMulti(PylibMC_Client *self, } for (i = 0, idx = 0; PyDict_Next(keys, &i, &curr_key, &curr_value); idx++) { - int success = _PylibMC_SerializeValue(curr_key, key_prefix, + int success = _PylibMC_SerializeValue(self, curr_key, key_prefix, curr_value, time, &serialized[idx]); @@ -946,7 +989,7 @@ static PyObject *_PylibMC_RunCasCommand(PylibMC_Client *self, /* TODO: because it's RunSetCommand that does the zlib compression, cas can't currently use compressed values. */ - success = _PylibMC_SerializeValue(key, NULL, value, time, &mset); + success = _PylibMC_SerializeValue(self, key, NULL, value, time, &mset); if (!success || PyErr_Occurred() != NULL) { goto cleanup; @@ -992,19 +1035,19 @@ static void _PylibMC_FreeMset(pylibmc_mset *mset) { mset->value_obj = NULL; } -static int _PylibMC_SerializeValue(PyObject* key_obj, +static int _PylibMC_SerializeValue(PylibMC_Client *self, + PyObject* key_obj, PyObject* key_prefix, PyObject* value_obj, time_t time, pylibmc_mset* serialized) { - PyObject* store_val = NULL; /* first zero the whole structure out */ memset((void *)serialized, 0x0, sizeof(pylibmc_mset)); serialized->time = time; serialized->success = false; - serialized->flags = PYLIBMC_FLAG_NONE; + serialized->value_obj = NULL; if (!_key_normalized_obj(&key_obj)) { return false; @@ -1058,7 +1101,56 @@ static int _PylibMC_SerializeValue(PyObject* key_obj, serialized->prefixed_key_obj = prefixed_key_obj; } - /* Build store_val, a Python str/bytes object */ + /* Build serialized->value_obj, a Python str/bytes object. */ + PyObject *serval_and_flags = PyObject_CallMethod((PyObject *)self, "serialize", "(O)", value_obj); + if (serval_and_flags == NULL) { + return false; + } + + if (PyTuple_Check(serval_and_flags)) { + PyObject *flags_obj = PyTuple_GetItem(serval_and_flags, 1); + if (flags_obj != NULL) { +#if PY_MAJOR_VERSION >= 3 + if (PyLong_Check(flags_obj)) { + serialized->flags = (uint32_t) PyLong_AsLong(flags_obj); + serialized->value_obj = PyTuple_GetItem(serval_and_flags, 0); + } +#else + if (PyInt_Check(flags_obj)) { + serialized->flags = (uint32_t) PyInt_AsLong(flags_obj); + serialized->value_obj = PyTuple_GetItem(serval_and_flags, 0); + } +#endif + } + } + + if (serialized->value_obj == NULL) { + /* PyErr_SetObject(PyExc_ValueError, serval_and_flags); */ + PyErr_SetString(PyExc_ValueError, "serialize() must return (bytes, flags)"); + Py_DECREF(serval_and_flags); + return false; + } else { + /* We're getting rid of serval_and_flags, which owns the only new + reference to value_obj. However, we can't deallocate value_obj + until we're done with value and value_len (after the set + operation). Therefore, take possession of a new reference to it + before cleaning up the tuple: */ + Py_INCREF(serialized->value_obj); + Py_DECREF(serval_and_flags); + } + + + if (PyBytes_AsStringAndSize(serialized->value_obj, &serialized->value, + &serialized->value_len) == -1) { + return false; + } + + return true; +} + +static PyObject *PylibMC_Client_serialize(PylibMC_Client *self, PyObject *value_obj) { + uint32_t flags = PYLIBMC_FLAG_NONE; + PyObject *store_val = NULL; if (PyBytes_Check(value_obj)) { /* Make store_val an owned reference */ @@ -1066,24 +1158,24 @@ static int _PylibMC_SerializeValue(PyObject* key_obj, Py_INCREF(store_val); #if PY_MAJOR_VERSION >= 3 } else if (PyBool_Check(value_obj)) { - serialized->flags |= PYLIBMC_FLAG_BOOL; + flags |= PYLIBMC_FLAG_BOOL; store_val = PyBytes_FromFormat("%ld", PyLong_AsLong(value_obj)); } else if (PyLong_Check(value_obj)) { - serialized->flags |= PYLIBMC_FLAG_LONG; + flags |= PYLIBMC_FLAG_LONG; store_val = PyBytes_FromFormat("%ld", PyLong_AsLong(value_obj)); #else } else if (PyBool_Check(value_obj)) { - serialized->flags |= PYLIBMC_FLAG_BOOL; + flags |= PYLIBMC_FLAG_BOOL; PyObject* tmp = PyNumber_Long(value_obj); store_val = PyObject_Bytes(tmp); Py_DECREF(tmp); } else if (PyInt_Check(value_obj)) { - serialized->flags |= PYLIBMC_FLAG_INTEGER; + flags |= PYLIBMC_FLAG_INTEGER; PyObject* tmp = PyNumber_Int(value_obj); store_val = PyObject_Bytes(tmp); Py_DECREF(tmp); } else if (PyLong_Check(value_obj)) { - serialized->flags |= PYLIBMC_FLAG_LONG; + flags |= PYLIBMC_FLAG_LONG; PyObject* tmp = PyNumber_Long(value_obj); store_val = PyObject_Bytes(tmp); Py_DECREF(tmp); @@ -1091,25 +1183,17 @@ static int _PylibMC_SerializeValue(PyObject* key_obj, } else if (value_obj != NULL) { /* we have no idea what it is, so we'll store it pickled */ Py_INCREF(value_obj); - serialized->flags |= PYLIBMC_FLAG_PICKLE; + flags |= PYLIBMC_FLAG_PICKLE; store_val = _PylibMC_Pickle(value_obj); Py_DECREF(value_obj); } if (store_val == NULL) { - return false; - } - - /* store_val is an owned reference; released when we're done with value and - * value_len (i.e. not here.) */ - serialized->value_obj = store_val; - - if (PyBytes_AsStringAndSize(store_val, &serialized->value, - &serialized->value_len) == -1) { - return false; + return NULL; } - return true; + /* we own a reference to store_val. "give" it to the tuple return value: */ + return Py_BuildValue("(NI)", store_val, flags); } /* {{{ Set commands (set, replace, add, prepend, append) */ @@ -1577,7 +1661,7 @@ static PyObject *PylibMC_Client_get_multi( PyObject *key_str_map = NULL; PyObject *temp_key_obj; size_t *key_lens; - Py_ssize_t nkeys, orig_nkeys; + Py_ssize_t nkeys = 0, orig_nkeys = 0; size_t nresults = 0; memcached_return rc; pylibmc_mget_req req; @@ -1588,34 +1672,31 @@ static PyObject *PylibMC_Client_get_multi( &key_seq, &prefix, &prefix_len)) return NULL; - if ((orig_nkeys = nkeys = PySequence_Length(key_seq)) == -1) + if ((orig_nkeys = PySequence_Length(key_seq)) == -1) return NULL; /* Populate keys and key_lens. */ - keys = PyMem_New(char *, nkeys); - key_lens = PyMem_New(size_t, (size_t) nkeys); - key_objs = PyMem_New(PyObject *, (size_t) nkeys); - orig_key_objs = PyMem_New(PyObject *, (size_t) nkeys); + keys = PyMem_New(char *, orig_nkeys); + key_lens = PyMem_New(size_t, (size_t) orig_nkeys); + key_objs = PyMem_New(PyObject *, (size_t) orig_nkeys); + orig_key_objs = PyMem_New(PyObject *, (size_t) orig_nkeys); if (!keys || !key_lens || !key_objs || !orig_key_objs) { - PyMem_Free(keys); - PyMem_Free(key_lens); - PyMem_Free(key_objs); - PyMem_Free(orig_key_objs); - return PyErr_NoMemory(); + PyErr_NoMemory(); + goto memory_cleanup; } /* Clear potential previous exception, because we explicitly check for * exceptions as a loop predicate. */ PyErr_Clear(); - key_str_map = _PylibMC_map_str_keys(key_seq, orig_key_objs, &nkeys); + key_str_map = _PylibMC_map_str_keys(key_seq, orig_key_objs, &orig_nkeys); if (key_str_map == NULL) { - return NULL; + goto memory_cleanup; } /* Iterate through all keys and set lengths etc. */ Py_ssize_t key_idx = 0; - for (i = 0; i < nkeys; i++) { + for (i = 0; i < orig_nkeys; i++) { PyObject *ckey = orig_key_objs[i]; char *key; Py_ssize_t key_len; @@ -1716,8 +1797,12 @@ static PyObject *PylibMC_Client_get_multi( } /* Parse out value */ - val = _PylibMC_parse_memcached_result(res); - if (val == NULL) + val = _PylibMC_parse_memcached_result(self, res); + if (_PylibMC_cache_miss_simulated(val)) { + Py_DECREF(key_obj); + continue; + } + else if (val == NULL) goto loopcleanup; rc = PyDict_SetItem(retval, key_obj, val); @@ -1737,15 +1822,16 @@ static PyObject *PylibMC_Client_get_multi( } earlybird: - PyMem_Free(key_lens); - PyMem_Free(keys); - - for (i = 0; i < nkeys; i++) - Py_DECREF(key_objs[i]); for (i = 0; i < orig_nkeys; i++) Py_DECREF(orig_key_objs[i]); - PyMem_Free(key_objs); + for (i = 0; i < nkeys; i++) + Py_DECREF(key_objs[i]); _PylibMC_cleanup_str_key_mapping(key_str_map); + +memory_cleanup: + PyMem_Free(key_lens); + PyMem_Free(keys); + PyMem_Free(key_objs); PyMem_Free(orig_key_objs); if (results != NULL) { @@ -2306,35 +2392,12 @@ static PyObject *_PylibMC_GetPickles(const char *attname) { return pickle_attr; } -static PyObject *_PylibMC_Unpickle(const char *buff, size_t size) { - PyObject *pickle_load; - PyObject *retval = NULL; - - retval = NULL; - pickle_load = _PylibMC_GetPickles("loads"); - if (pickle_load != NULL) { -#if PY_MAJOR_VERSION >= 3 - retval = PyObject_CallFunction(pickle_load, "y#", buff, size); -#else - retval = PyObject_CallFunction(pickle_load, "s#", buff, size); -#endif - Py_DECREF(pickle_load); - } - - return retval; +static PyObject *_PylibMC_Unpickle(PyObject *val) { + return PyObject_CallFunctionObjArgs(_PylibMC_pickle_loads, val, NULL); } static PyObject *_PylibMC_Pickle(PyObject *val) { - PyObject *pickle_dump; - PyObject *retval = NULL; - - pickle_dump = _PylibMC_GetPickles("dumps"); - if (pickle_dump != NULL) { - retval = PyObject_CallFunction(pickle_dump, "Oi", val, -1); - Py_DECREF(pickle_dump); - } - - return retval; + return PyObject_CallFunctionObjArgs(_PylibMC_pickle_dumps, val, NULL); } /* }}} */ @@ -2491,9 +2554,14 @@ static void _make_excs(PyObject *module) { PylibMCExc_Error = PyErr_NewException( "pylibmc.Error", NULL, NULL); + PylibMCExc_CacheMiss = PyErr_NewException( + "_pylibmc.CacheMiss", PylibMCExc_Error, NULL); + exc_objs = PyList_New(0); PyList_Append(exc_objs, Py_BuildValue("sO", "Error", (PyObject *)PylibMCExc_Error)); + PyList_Append(exc_objs, + Py_BuildValue("sO", "CacheMiss", (PyObject *)PylibMCExc_CacheMiss)); for (err = PylibMCExc_mc_errs; err->name != NULL; err++) { char excnam[64]; @@ -2508,6 +2576,9 @@ static void _make_excs(PyObject *module) { PyModule_AddObject(module, "Error", (PyObject *)PylibMCExc_Error); + PyModule_AddObject(module, "CacheMiss", + (PyObject *)PylibMCExc_CacheMiss); + /* Backwards compatible name for <= pylibmc 1.2.3 * * Need to increase the refcount since we're adding another @@ -2595,6 +2666,14 @@ by using comma-separation. Good luck with that.\n", PylibMC_functions); _make_excs(module); + if (!(_PylibMC_pickle_loads = _PylibMC_GetPickles("loads"))) { + return MOD_ERROR_VAL; + } + + if (!(_PylibMC_pickle_dumps = _PylibMC_GetPickles("dumps"))) { + return MOD_ERROR_VAL; + } + PyModule_AddStringConstant(module, "__version__", PYLIBMC_VERSION); PyModule_ADD_REF(module, "client", (PyObject *)&PylibMC_ClientType); PyModule_AddStringConstant(module, diff --git a/src/_pylibmcmodule.h b/src/_pylibmcmodule.h index fe39542..e5da870 100644 --- a/src/_pylibmcmodule.h +++ b/src/_pylibmcmodule.h @@ -131,6 +131,7 @@ typedef struct { /* {{{ Exceptions */ static PyObject *PylibMCExc_Error; +static PyObject *PylibMCExc_CacheMiss; /* Mapping of memcached_return value -> Python exception object. */ typedef struct { @@ -279,6 +280,8 @@ static PylibMC_Client *PylibMC_ClientType_new(PyTypeObject *, PyObject *, PyObject *); static void PylibMC_ClientType_dealloc(PylibMC_Client *); static int PylibMC_Client_init(PylibMC_Client *, PyObject *, PyObject *); +static PyObject *PylibMC_Client_deserialize(PylibMC_Client *, PyObject *arg); +static PyObject *PylibMC_Client_serialize(PylibMC_Client *, PyObject *val); static PyObject *PylibMC_Client_get(PylibMC_Client *, PyObject *arg); static PyObject *PylibMC_Client_gets(PylibMC_Client *, PyObject *arg); static PyObject *PylibMC_Client_set(PylibMC_Client *, PyObject *, PyObject *); @@ -307,11 +310,12 @@ static PyObject *PylibMC_ErrFromMemcachedWithKey(PylibMC_Client *, const char *, memcached_return, const char *, Py_ssize_t); static PyObject *PylibMC_ErrFromMemcached(PylibMC_Client *, const char *, memcached_return); -static PyObject *_PylibMC_Unpickle(const char *, size_t); +static PyObject *_PylibMC_Unpickle(PyObject *); static PyObject *_PylibMC_Pickle(PyObject *); static int _key_normalized_obj(PyObject **); static int _key_normalized_str(char **, Py_ssize_t *); -static int _PylibMC_SerializeValue(PyObject *key_obj, +static int _PylibMC_SerializeValue(PylibMC_Client *self, + PyObject *key_obj, PyObject *key_prefix, PyObject *value_obj, time_t time, @@ -338,6 +342,12 @@ static bool _PylibMC_IncrDecr(PylibMC_Client *, pylibmc_incr *, size_t); /* {{{ Type's method table */ static PyMethodDef PylibMC_ClientType_methods[] = { + {"serialize", (PyCFunction)PylibMC_Client_serialize, METH_O, + "Serialize an object to a byte string and flag field, to be stored " + "in memcached."}, + {"deserialize", (PyCFunction)PylibMC_Client_deserialize, METH_VARARGS, + "Deserialize a bytestring and flag field retrieved from memcached. " + "Raise pylibmc.CacheMiss to simulate a cache miss."}, {"get", (PyCFunction)PylibMC_Client_get, METH_O, "Retrieve a key from a memcached."}, {"gets", (PyCFunction)PylibMC_Client_gets, METH_O, diff --git a/tests/__init__.py b/tests/__init__.py index 4f55c7e..146c015 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,6 +1,8 @@ """Tests. They want YOU!!""" - from __future__ import print_function + +import gc +import sys import unittest import pylibmc from pylibmc.test import make_test_client @@ -29,3 +31,13 @@ def dump_infos(): print("Reported libmemcached version:", _pylibmc.libmemcached_version) print("Reported pylibmc version:", _pylibmc.__version__) print("Support compression:", _pylibmc.support_compression) + +def get_refcounts(refcountables): + """Measure reference counts during testing. + + Measuring reference counts typically changes them (since at least + one new reference is created as the argument to sys.getrefcount). + Therefore, try to do it in a consistent and deterministic fashion. + """ + gc.collect() + return [sys.getrefcount(val) for val in refcountables] diff --git a/tests/test_refcounts.py b/tests/test_refcounts.py index 702918c..92a5ac6 100644 --- a/tests/test_refcounts.py +++ b/tests/test_refcounts.py @@ -10,10 +10,7 @@ import _pylibmc from pylibmc.test import make_test_client from tests import PylibmcTestCase - - -def get_refcounts(refcountables): - return [sys.getrefcount(val) for val in refcountables] +from tests import get_refcounts class RefcountTests(PylibmcTestCase): @@ -47,9 +44,12 @@ def test_get_complex_type(self): def test_get_simple(self): self._test_get(b"refcountest2", 485295) + def test_get_singleton(self): + self._test_get(b"refcountest3", False) + def test_get_multi(self): bc = make_test_client(binary=True) - keys = ["first", "second"] + keys = ["first", "second", "", b""] value = "first_value" refcountables = keys + [value] initial_refcounts = get_refcounts(refcountables) diff --git a/tests/test_serialization.py b/tests/test_serialization.py new file mode 100644 index 0000000..878a4cb --- /dev/null +++ b/tests/test_serialization.py @@ -0,0 +1,178 @@ +from __future__ import unicode_literals +from __future__ import print_function + +import json +import sys + +from nose.tools import eq_, ok_ + +import pylibmc +import _pylibmc +from pylibmc.test import make_test_client +from tests import PylibmcTestCase +from tests import get_refcounts + +def long_(val): + try: + return long(val) + except NameError: + # this happens under Python 3 + return val + +class SerializationTests(PylibmcTestCase): + """Test coverage for overriding serialization behavior in subclasses.""" + + def test_override_deserialize(self): + class MyClient(pylibmc.Client): + ignored = [] + def deserialize(self, bytes_, flags): + try: + return super(MyClient, self).deserialize(bytes_, flags) + except Exception as error: + self.ignored.append(error) + raise pylibmc.CacheMiss + + global MyObject # Needed by the pickling system. + class MyObject(object): + def __getstate__(self): + return dict(a=1) + def __eq__(self, other): + return type(other) is type(self) + def __setstate__(self, d): + assert d['a'] == 1 + + c = make_test_client(MyClient, behaviors={'cas': True}) + eq_(c.get('notathing'), None) + + refcountables = ['foo', 'myobj', 'noneobj', 'myobj2', 'cachemiss'] + initial_refcounts = get_refcounts(refcountables) + + c['foo'] = 'foo' + c['myobj'] = MyObject() + c['noneobj'] = None + c['myobj2'] = MyObject() + + # Show that everything is initially regular. + eq_(c.get('myobj'), MyObject()) + eq_(get_refcounts(refcountables), initial_refcounts) + eq_(c.get_multi(['foo', 'myobj', 'noneobj', 'cachemiss']), + dict(foo='foo', myobj=MyObject(), noneobj=None)) + eq_(get_refcounts(refcountables), initial_refcounts) + eq_(c.gets('myobj2')[0], MyObject()) + eq_(get_refcounts(refcountables), initial_refcounts) + + # Show that the subclass can transform unpickling issues into a cache miss. + del MyObject # Break unpickling + + eq_(c.get('myobj'), None) + eq_(get_refcounts(refcountables), initial_refcounts) + eq_(c.get_multi(['foo', 'myobj', 'noneobj', 'cachemiss']), + dict(foo='foo', noneobj=None)) + eq_(get_refcounts(refcountables), initial_refcounts) + eq_(c.gets('myobj2'), (None, None)) + eq_(get_refcounts(refcountables), initial_refcounts) + + # The ignored errors are "AttributeError: test.test_client has no MyObject" + eq_(len(MyClient.ignored), 3) + assert all(isinstance(error, AttributeError) for error in MyClient.ignored) + + def test_refcounts(self): + SENTINEL = object() + DUMMY = b"dummy" + KEY = b"fwLiDZKV7IlVByM5bVDNkg" + VALUE = "PVILgNVNkCfMkQup5vkGSQ" + + class MyClient(_pylibmc.client): + """Always serialize and deserialize to the same constants.""" + + def serialize(self, value): + return DUMMY, 1 + + def deserialize(self, bytes_, flags): + return SENTINEL + + refcountables = [1, long_(1), SENTINEL, DUMMY, KEY, VALUE] + c = make_test_client(MyClient) + initial_refcounts = get_refcounts(refcountables) + + c.set(KEY, VALUE) + eq_(get_refcounts(refcountables), initial_refcounts) + assert c.get(KEY) is SENTINEL + eq_(get_refcounts(refcountables), initial_refcounts) + eq_(c.get_multi([KEY]), {KEY: SENTINEL}) + eq_(get_refcounts(refcountables), initial_refcounts) + c.set_multi({KEY: True}) + eq_(get_refcounts(refcountables), initial_refcounts) + + def test_override_serialize(self): + class MyClient(pylibmc.Client): + def serialize(self, value): + return json.dumps(value).encode('utf-8'), 0 + + def deserialize(self, bytes_, flags): + return json.loads(bytes_.decode('utf-8')) + + c = make_test_client(MyClient) + c['foo'] = (1, 2, 3, 4) + # json turns tuples into lists: + eq_(c['foo'], [1, 2, 3, 4]) + + raised = False + try: + c['bar'] = object() + except TypeError: + raised = True + assert raised + + def _assert_set_raises(self, client, key, value): + """Assert that set operations raise a ValueError when appropriate. + + This is in a separate method to avoid confusing the reference counts. + """ + raised = False + try: + client[key] = value + except ValueError: + raised = True + assert raised + + def test_invalid_flags_returned(self): + # test that nothing bad (memory leaks, segfaults) happens + # when subclasses implement `deserialize` incorrectly + DUMMY = b"dummy" + BAD_FLAGS = object() + KEY = 'foo' + VALUE = object() + refcountables = [KEY, DUMMY, VALUE, BAD_FLAGS] + + class MyClient(pylibmc.Client): + def serialize(self, value): + return DUMMY, BAD_FLAGS + + c = make_test_client(MyClient) + initial_refcounts = get_refcounts(refcountables) + self._assert_set_raises(c, KEY, VALUE) + eq_(get_refcounts(refcountables), initial_refcounts) + + def test_invalid_flags_returned_2(self): + DUMMY = "ab" + KEY = "key" + VALUE = 123456 + refcountables = [DUMMY, KEY, VALUE] + + class MyClient(pylibmc.Client): + def serialize(self, value): + return DUMMY + + c = make_test_client(MyClient) + initial_refcounts = get_refcounts(refcountables) + + self._assert_set_raises(c, KEY, VALUE) + eq_(get_refcounts(refcountables), initial_refcounts) + + try: + c.set_multi({KEY: DUMMY}) + except ValueError: + raised = True + assert raised + eq_(get_refcounts(refcountables), initial_refcounts)