diff --git a/lib/rotary_embedding_torch/__init__.py b/lib/rotary_embedding_torch/__init__.py new file mode 100644 index 0000000..3a2cfef --- /dev/null +++ b/lib/rotary_embedding_torch/__init__.py @@ -0,0 +1,6 @@ +from ..rotary_embedding_torch.rotary_embedding_torch import ( + apply_rotary_emb, + RotaryEmbedding, + apply_learned_rotations, + broadcat +) diff --git a/lib/rotary_embedding_torch/__pycache__/__init__.cpython-310.pyc b/lib/rotary_embedding_torch/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..e532687 Binary files /dev/null and b/lib/rotary_embedding_torch/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/rotary_embedding_torch/__pycache__/__init__.cpython-39.pyc b/lib/rotary_embedding_torch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..bbd17f7 Binary files /dev/null and b/lib/rotary_embedding_torch/__pycache__/__init__.cpython-39.pyc differ diff --git a/lib/rotary_embedding_torch/__pycache__/rotary_embedding_torch.cpython-310.pyc b/lib/rotary_embedding_torch/__pycache__/rotary_embedding_torch.cpython-310.pyc new file mode 100644 index 0000000..73761f6 Binary files /dev/null and b/lib/rotary_embedding_torch/__pycache__/rotary_embedding_torch.cpython-310.pyc differ diff --git a/lib/rotary_embedding_torch/__pycache__/rotary_embedding_torch.cpython-39.pyc b/lib/rotary_embedding_torch/__pycache__/rotary_embedding_torch.cpython-39.pyc new file mode 100644 index 0000000..37902f1 Binary files /dev/null and b/lib/rotary_embedding_torch/__pycache__/rotary_embedding_torch.cpython-39.pyc differ diff --git a/lib/rotary_embedding_torch/rotary_embedding_torch.py b/lib/rotary_embedding_torch/rotary_embedding_torch.py new file mode 100644 index 0000000..8937875 --- /dev/null +++ b/lib/rotary_embedding_torch/rotary_embedding_torch.py @@ -0,0 +1,291 @@ +from __future__ import annotations +from math import pi, log + +import torch +from torch.nn import Module, ModuleList +from torch.cuda.amp import autocast +from torch import nn, einsum, broadcast_tensors, Tensor + +from einops import rearrange, repeat + +from typing import Literal + +# helper functions + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +# broadcat, as tortoise-tts was using it + +def broadcat(tensors, dim = -1): + broadcasted_tensors = broadcast_tensors(*tensors) + return torch.cat(broadcasted_tensors, dim = dim) + +# rotary embedding helper functions + +def rotate_half(x): + x = rearrange(x, '... (d r) -> ... d r', r = 2) + x1, x2 = x.unbind(dim = -1) + x = torch.stack((-x2, x1), dim = -1) + return rearrange(x, '... d r -> ... (d r)') + +@autocast(enabled = False) +def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2): + dtype = t.dtype + + if t.ndim == 3: + seq_len = t.shape[seq_dim] + freqs = freqs[-seq_len:] + + rot_dim = freqs.shape[-1] + end_index = start_index + rot_dim + + assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' + + t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + out = torch.cat((t_left, t, t_right), dim = -1) + + return out.type(dtype) + +# learned rotation helpers + +def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None): + if exists(freq_ranges): + rotations = einsum('..., f -> ... f', rotations, freq_ranges) + rotations = rearrange(rotations, '... r f -> ... (r f)') + + rotations = repeat(rotations, '... n -> ... (n r)', r = 2) + return apply_rotary_emb(rotations, t, start_index = start_index) + +# classes + +class RotaryEmbedding(Module): + def __init__( + self, + dim, + custom_freqs: Tensor | None = None, + freqs_for: Literal['lang', 'pixel', 'constant'] = 'lang', + theta = 10000, + max_freq = 10, + num_freqs = 1, + learned_freq = False, + use_xpos = False, + xpos_scale_base = 512, + interpolate_factor = 1., + theta_rescale_factor = 1., + seq_before_head_dim = False, + cache_if_possible = True + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + self.freqs_for = freqs_for + + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + + self.cache_if_possible = cache_if_possible + + self.tmp_store('cached_freqs', None) + self.tmp_store('cached_scales', None) + + self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) + + self.learned_freq = learned_freq + + # dummy for device + + self.tmp_store('dummy', torch.tensor(0)) + + # default sequence dimension + + self.seq_before_head_dim = seq_before_head_dim + self.default_seq_dim = -3 if seq_before_head_dim else -2 + + # interpolation factors + + assert interpolate_factor >= 1. + self.interpolate_factor = interpolate_factor + + # xpos + + self.use_xpos = use_xpos + if not use_xpos: + self.tmp_store('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.scale_base = xpos_scale_base + self.tmp_store('scale', scale) + + # add apply_rotary_emb as static method + + self.apply_rotary_emb = staticmethod(apply_rotary_emb) + + @property + def device(self): + return self.dummy.device + + def tmp_store(self, key, value): + self.register_buffer(key, value, persistent = False) + + def get_seq_pos(self, seq_len, device, dtype, offset = 0): + return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor + + def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, scale = None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert not self.use_xpos or exists(scale), 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings' + + device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] + + seq = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset) + + freqs = self.forward(seq, seq_len = seq_len, offset = offset) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + + return apply_rotary_emb(freqs, t, scale = default(scale, 1.), seq_dim = seq_dim) + + def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0): + dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim) + + q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] + assert q_len <= k_len + + q_scale = k_scale = 1. + + if self.use_xpos: + seq = self.get_seq_pos(k_len, dtype = dtype, device = device) + + q_scale = self.get_scale(seq[-q_len:]).type(dtype) + k_scale = self.get_scale(seq).type(dtype) + + rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, scale = q_scale, offset = k_len - q_len + offset) + rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim, scale = k_scale ** -1) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def rotate_queries_and_keys(self, q, k, seq_dim = None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert self.use_xpos + device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] + + seq = self.get_seq_pos(seq_len, dtype = dtype, device = device) + + freqs = self.forward(seq, seq_len = seq_len) + scale = self.get_scale(seq, seq_len = seq_len).to(dtype) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + scale = rearrange(scale, 'n d -> n 1 d') + + rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim) + rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def get_scale( + self, + t: Tensor, + seq_len: int | None = None, + offset = 0 + ): + assert self.use_xpos + + should_cache = ( + self.cache_if_possible and + exists(seq_len) + ) + + if ( + should_cache and \ + exists(self.cached_scales) and \ + (seq_len + offset) <= self.cached_scales.shape[0] + ): + return self.cached_scales[offset:(offset + seq_len)] + + scale = 1. + if self.use_xpos: + power = (t - len(t) // 2) / self.scale_base + scale = self.scale ** rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim = -1) + + if should_cache: + self.tmp_store('cached_scales', scale) + + return scale + + def get_axial_freqs(self, *dims): + Colon = slice(None) + all_freqs = [] + + for ind, dim in enumerate(dims): + if self.freqs_for == 'pixel': + pos = torch.linspace(-1, 1, steps = dim, device = self.device) + else: + pos = torch.arange(dim, device = self.device) + + freqs = self.forward(pos, seq_len = dim) + + all_axis = [None] * len(dims) + all_axis[ind] = Colon + + new_axis_slice = (Ellipsis, *all_axis, Colon) + all_freqs.append(freqs[new_axis_slice]) + + all_freqs = broadcast_tensors(*all_freqs) + return torch.cat(all_freqs, dim = -1) + + @autocast(enabled = False) + def forward( + self, + t: Tensor, + seq_len = None, + offset = 0 + ): + should_cache = ( + self.cache_if_possible and \ + not self.learned_freq and \ + exists(seq_len) and \ + self.freqs_for != 'pixel' + ) + + if ( + should_cache and \ + exists(self.cached_freqs) and \ + (offset + seq_len) <= self.cached_freqs.shape[0] + ): + return self.cached_freqs[offset:(offset + seq_len)].detach() + + freqs = self.freqs + + freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs) + freqs = repeat(freqs, '... n -> ... (n r)', r = 2) + + if should_cache: + self.tmp_store('cached_freqs', freqs.detach()) + + return freqs diff --git a/lib/simplejson/__init__.py b/lib/simplejson/__init__.py new file mode 100644 index 0000000..2d1900d --- /dev/null +++ b/lib/simplejson/__init__.py @@ -0,0 +1,562 @@ +r"""JSON (JavaScript Object Notation) is a subset of +JavaScript syntax (ECMA-262 3rd edition) used as a lightweight data +interchange format. + +:mod:`simplejson` exposes an API familiar to users of the standard library +:mod:`marshal` and :mod:`pickle` modules. It is the externally maintained +version of the :mod:`json` library contained in Python 2.6, but maintains +compatibility back to Python 2.5 and (currently) has significant performance +advantages, even without using the optional C extension for speedups. + +Encoding basic Python object hierarchies:: + + >>> import simplejson as json + >>> json.dumps(['foo', {'bar': ('baz', None, 1.0, 2)}]) + '["foo", {"bar": ["baz", null, 1.0, 2]}]' + >>> print(json.dumps("\"foo\bar")) + "\"foo\bar" + >>> print(json.dumps(u'\u1234')) + "\u1234" + >>> print(json.dumps('\\')) + "\\" + >>> print(json.dumps({"c": 0, "b": 0, "a": 0}, sort_keys=True)) + {"a": 0, "b": 0, "c": 0} + >>> from simplejson.compat import StringIO + >>> io = StringIO() + >>> json.dump(['streaming API'], io) + >>> io.getvalue() + '["streaming API"]' + +Compact encoding:: + + >>> import simplejson as json + >>> obj = [1,2,3,{'4': 5, '6': 7}] + >>> json.dumps(obj, separators=(',',':'), sort_keys=True) + '[1,2,3,{"4":5,"6":7}]' + +Pretty printing:: + + >>> import simplejson as json + >>> print(json.dumps({'4': 5, '6': 7}, sort_keys=True, indent=' ')) + { + "4": 5, + "6": 7 + } + +Decoding JSON:: + + >>> import simplejson as json + >>> obj = [u'foo', {u'bar': [u'baz', None, 1.0, 2]}] + >>> json.loads('["foo", {"bar":["baz", null, 1.0, 2]}]') == obj + True + >>> json.loads('"\\"foo\\bar"') == u'"foo\x08ar' + True + >>> from simplejson.compat import StringIO + >>> io = StringIO('["streaming API"]') + >>> json.load(io)[0] == 'streaming API' + True + +Specializing JSON object decoding:: + + >>> import simplejson as json + >>> def as_complex(dct): + ... if '__complex__' in dct: + ... return complex(dct['real'], dct['imag']) + ... return dct + ... + >>> json.loads('{"__complex__": true, "real": 1, "imag": 2}', + ... object_hook=as_complex) + (1+2j) + >>> from decimal import Decimal + >>> json.loads('1.1', parse_float=Decimal) == Decimal('1.1') + True + +Specializing JSON object encoding:: + + >>> import simplejson as json + >>> def encode_complex(obj): + ... if isinstance(obj, complex): + ... return [obj.real, obj.imag] + ... raise TypeError('Object of type %s is not JSON serializable' % + ... obj.__class__.__name__) + ... + >>> json.dumps(2 + 1j, default=encode_complex) + '[2.0, 1.0]' + >>> json.JSONEncoder(default=encode_complex).encode(2 + 1j) + '[2.0, 1.0]' + >>> ''.join(json.JSONEncoder(default=encode_complex).iterencode(2 + 1j)) + '[2.0, 1.0]' + +Using simplejson.tool from the shell to validate and pretty-print:: + + $ echo '{"json":"obj"}' | python -m simplejson.tool + { + "json": "obj" + } + $ echo '{ 1.2:3.4}' | python -m simplejson.tool + Expecting property name: line 1 column 3 (char 2) + +Parsing multiple documents serialized as JSON lines (newline-delimited JSON):: + + >>> import simplejson as json + >>> def loads_lines(docs): + ... for doc in docs.splitlines(): + ... yield json.loads(doc) + ... + >>> sum(doc["count"] for doc in loads_lines('{"count":1}\n{"count":2}\n{"count":3}\n')) + 6 + +Serializing multiple objects to JSON lines (newline-delimited JSON):: + + >>> import simplejson as json + >>> def dumps_lines(objs): + ... for obj in objs: + ... yield json.dumps(obj, separators=(',',':')) + '\n' + ... + >>> ''.join(dumps_lines([{'count': 1}, {'count': 2}, {'count': 3}])) + '{"count":1}\n{"count":2}\n{"count":3}\n' + +""" +from __future__ import absolute_import +__version__ = '3.19.2' +__all__ = [ + 'dump', 'dumps', 'load', 'loads', + 'JSONDecoder', 'JSONDecodeError', 'JSONEncoder', + 'OrderedDict', 'simple_first', 'RawJSON' +] + +__author__ = 'Bob Ippolito ' + +from decimal import Decimal + +from .errors import JSONDecodeError +from .raw_json import RawJSON +from .decoder import JSONDecoder +from .encoder import JSONEncoder, JSONEncoderForHTML +def _import_OrderedDict(): + import collections + try: + return collections.OrderedDict + except AttributeError: + from . import ordered_dict + return ordered_dict.OrderedDict +OrderedDict = _import_OrderedDict() + +def _import_c_make_encoder(): + try: + from ._speedups import make_encoder + return make_encoder + except ImportError: + return None + +_default_encoder = JSONEncoder() + +def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True, + allow_nan=False, cls=None, indent=None, separators=None, + encoding='utf-8', default=None, use_decimal=True, + namedtuple_as_object=True, tuple_as_array=True, + bigint_as_string=False, sort_keys=False, item_sort_key=None, + for_json=False, ignore_nan=False, int_as_string_bitcount=None, + iterable_as_array=False, **kw): + """Serialize ``obj`` as a JSON formatted stream to ``fp`` (a + ``.write()``-supporting file-like object). + + If *skipkeys* is true then ``dict`` keys that are not basic types + (``str``, ``int``, ``long``, ``float``, ``bool``, ``None``) + will be skipped instead of raising a ``TypeError``. + + If *ensure_ascii* is false (default: ``True``), then the output may + contain non-ASCII characters, so long as they do not need to be escaped + by JSON. When it is true, all non-ASCII characters are escaped. + + If *allow_nan* is true (default: ``False``), then out of range ``float`` + values (``nan``, ``inf``, ``-inf``) will be serialized to + their JavaScript equivalents (``NaN``, ``Infinity``, ``-Infinity``) + instead of raising a ValueError. See + *ignore_nan* for ECMA-262 compliant behavior. + + If *indent* is a string, then JSON array elements and object members + will be pretty-printed with a newline followed by that string repeated + for each level of nesting. ``None`` (the default) selects the most compact + representation without any newlines. + + If specified, *separators* should be an + ``(item_separator, key_separator)`` tuple. The default is ``(', ', ': ')`` + if *indent* is ``None`` and ``(',', ': ')`` otherwise. To get the most + compact JSON representation, you should specify ``(',', ':')`` to eliminate + whitespace. + + *encoding* is the character encoding for str instances, default is UTF-8. + + *default(obj)* is a function that should return a serializable version + of obj or raise ``TypeError``. The default simply raises ``TypeError``. + + If *use_decimal* is true (default: ``True``) then decimal.Decimal + will be natively serialized to JSON with full precision. + + If *namedtuple_as_object* is true (default: ``True``), + :class:`tuple` subclasses with ``_asdict()`` methods will be encoded + as JSON objects. + + If *tuple_as_array* is true (default: ``True``), + :class:`tuple` (and subclasses) will be encoded as JSON arrays. + + If *iterable_as_array* is true (default: ``False``), + any object not in the above table that implements ``__iter__()`` + will be encoded as a JSON array. + + If *bigint_as_string* is true (default: ``False``), ints 2**53 and higher + or lower than -2**53 will be encoded as strings. This is to avoid the + rounding that happens in Javascript otherwise. Note that this is still a + lossy operation that will not round-trip correctly and should be used + sparingly. + + If *int_as_string_bitcount* is a positive number (n), then int of size + greater than or equal to 2**n or lower than or equal to -2**n will be + encoded as strings. + + If specified, *item_sort_key* is a callable used to sort the items in + each dictionary. This is useful if you want to sort items other than + in alphabetical order by key. This option takes precedence over + *sort_keys*. + + If *sort_keys* is true (default: ``False``), the output of dictionaries + will be sorted by item. + + If *for_json* is true (default: ``False``), objects with a ``for_json()`` + method will use the return value of that method for encoding as JSON + instead of the object. + + If *ignore_nan* is true (default: ``False``), then out of range + :class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized as + ``null`` in compliance with the ECMA-262 specification. If true, this will + override *allow_nan*. + + To use a custom ``JSONEncoder`` subclass (e.g. one that overrides the + ``.default()`` method to serialize additional types), specify it with + the ``cls`` kwarg. NOTE: You should use *default* or *for_json* instead + of subclassing whenever possible. + + """ + # cached encoder + if (not skipkeys and ensure_ascii and + check_circular and not allow_nan and + cls is None and indent is None and separators is None and + encoding == 'utf-8' and default is None and use_decimal + and namedtuple_as_object and tuple_as_array and not iterable_as_array + and not bigint_as_string and not sort_keys + and not item_sort_key and not for_json + and not ignore_nan and int_as_string_bitcount is None + and not kw + ): + iterable = _default_encoder.iterencode(obj) + else: + if cls is None: + cls = JSONEncoder + iterable = cls(skipkeys=skipkeys, ensure_ascii=ensure_ascii, + check_circular=check_circular, allow_nan=allow_nan, indent=indent, + separators=separators, encoding=encoding, + default=default, use_decimal=use_decimal, + namedtuple_as_object=namedtuple_as_object, + tuple_as_array=tuple_as_array, + iterable_as_array=iterable_as_array, + bigint_as_string=bigint_as_string, + sort_keys=sort_keys, + item_sort_key=item_sort_key, + for_json=for_json, + ignore_nan=ignore_nan, + int_as_string_bitcount=int_as_string_bitcount, + **kw).iterencode(obj) + # could accelerate with writelines in some versions of Python, at + # a debuggability cost + for chunk in iterable: + fp.write(chunk) + + +def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True, + allow_nan=False, cls=None, indent=None, separators=None, + encoding='utf-8', default=None, use_decimal=True, + namedtuple_as_object=True, tuple_as_array=True, + bigint_as_string=False, sort_keys=False, item_sort_key=None, + for_json=False, ignore_nan=False, int_as_string_bitcount=None, + iterable_as_array=False, **kw): + """Serialize ``obj`` to a JSON formatted ``str``. + + If ``skipkeys`` is true then ``dict`` keys that are not basic types + (``str``, ``int``, ``long``, ``float``, ``bool``, ``None``) + will be skipped instead of raising a ``TypeError``. + + If *ensure_ascii* is false (default: ``True``), then the output may + contain non-ASCII characters, so long as they do not need to be escaped + by JSON. When it is true, all non-ASCII characters are escaped. + + If ``check_circular`` is false, then the circular reference check + for container types will be skipped and a circular reference will + result in an ``OverflowError`` (or worse). + + If *allow_nan* is true (default: ``False``), then out of range ``float`` + values (``nan``, ``inf``, ``-inf``) will be serialized to + their JavaScript equivalents (``NaN``, ``Infinity``, ``-Infinity``) + instead of raising a ValueError. See + *ignore_nan* for ECMA-262 compliant behavior. + + If ``indent`` is a string, then JSON array elements and object members + will be pretty-printed with a newline followed by that string repeated + for each level of nesting. ``None`` (the default) selects the most compact + representation without any newlines. For backwards compatibility with + versions of simplejson earlier than 2.1.0, an integer is also accepted + and is converted to a string with that many spaces. + + If specified, ``separators`` should be an + ``(item_separator, key_separator)`` tuple. The default is ``(', ', ': ')`` + if *indent* is ``None`` and ``(',', ': ')`` otherwise. To get the most + compact JSON representation, you should specify ``(',', ':')`` to eliminate + whitespace. + + ``encoding`` is the character encoding for bytes instances, default is + UTF-8. + + ``default(obj)`` is a function that should return a serializable version + of obj or raise TypeError. The default simply raises TypeError. + + If *use_decimal* is true (default: ``True``) then decimal.Decimal + will be natively serialized to JSON with full precision. + + If *namedtuple_as_object* is true (default: ``True``), + :class:`tuple` subclasses with ``_asdict()`` methods will be encoded + as JSON objects. + + If *tuple_as_array* is true (default: ``True``), + :class:`tuple` (and subclasses) will be encoded as JSON arrays. + + If *iterable_as_array* is true (default: ``False``), + any object not in the above table that implements ``__iter__()`` + will be encoded as a JSON array. + + If *bigint_as_string* is true (not the default), ints 2**53 and higher + or lower than -2**53 will be encoded as strings. This is to avoid the + rounding that happens in Javascript otherwise. + + If *int_as_string_bitcount* is a positive number (n), then int of size + greater than or equal to 2**n or lower than or equal to -2**n will be + encoded as strings. + + If specified, *item_sort_key* is a callable used to sort the items in + each dictionary. This is useful if you want to sort items other than + in alphabetical order by key. This option takes precedence over + *sort_keys*. + + If *sort_keys* is true (default: ``False``), the output of dictionaries + will be sorted by item. + + If *for_json* is true (default: ``False``), objects with a ``for_json()`` + method will use the return value of that method for encoding as JSON + instead of the object. + + If *ignore_nan* is true (default: ``False``), then out of range + :class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized as + ``null`` in compliance with the ECMA-262 specification. If true, this will + override *allow_nan*. + + To use a custom ``JSONEncoder`` subclass (e.g. one that overrides the + ``.default()`` method to serialize additional types), specify it with + the ``cls`` kwarg. NOTE: You should use *default* instead of subclassing + whenever possible. + + """ + # cached encoder + if (not skipkeys and ensure_ascii and + check_circular and not allow_nan and + cls is None and indent is None and separators is None and + encoding == 'utf-8' and default is None and use_decimal + and namedtuple_as_object and tuple_as_array and not iterable_as_array + and not bigint_as_string and not sort_keys + and not item_sort_key and not for_json + and not ignore_nan and int_as_string_bitcount is None + and not kw + ): + return _default_encoder.encode(obj) + if cls is None: + cls = JSONEncoder + return cls( + skipkeys=skipkeys, ensure_ascii=ensure_ascii, + check_circular=check_circular, allow_nan=allow_nan, indent=indent, + separators=separators, encoding=encoding, default=default, + use_decimal=use_decimal, + namedtuple_as_object=namedtuple_as_object, + tuple_as_array=tuple_as_array, + iterable_as_array=iterable_as_array, + bigint_as_string=bigint_as_string, + sort_keys=sort_keys, + item_sort_key=item_sort_key, + for_json=for_json, + ignore_nan=ignore_nan, + int_as_string_bitcount=int_as_string_bitcount, + **kw).encode(obj) + + +_default_decoder = JSONDecoder() + + +def load(fp, encoding=None, cls=None, object_hook=None, parse_float=None, + parse_int=None, parse_constant=None, object_pairs_hook=None, + use_decimal=False, allow_nan=False, **kw): + """Deserialize ``fp`` (a ``.read()``-supporting file-like object containing + a JSON document as `str` or `bytes`) to a Python object. + + *encoding* determines the encoding used to interpret any + `bytes` objects decoded by this instance (``'utf-8'`` by + default). It has no effect when decoding `str` objects. + + *object_hook*, if specified, will be called with the result of every + JSON object decoded and its return value will be used in place of the + given :class:`dict`. This can be used to provide custom + deserializations (e.g. to support JSON-RPC class hinting). + + *object_pairs_hook* is an optional function that will be called with + the result of any object literal decode with an ordered list of pairs. + The return value of *object_pairs_hook* will be used instead of the + :class:`dict`. This feature can be used to implement custom decoders + that rely on the order that the key and value pairs are decoded (for + example, :func:`collections.OrderedDict` will remember the order of + insertion). If *object_hook* is also defined, the *object_pairs_hook* + takes priority. + + *parse_float*, if specified, will be called with the string of every + JSON float to be decoded. By default, this is equivalent to + ``float(num_str)``. This can be used to use another datatype or parser + for JSON floats (e.g. :class:`decimal.Decimal`). + + *parse_int*, if specified, will be called with the string of every + JSON int to be decoded. By default, this is equivalent to + ``int(num_str)``. This can be used to use another datatype or parser + for JSON integers (e.g. :class:`float`). + + *allow_nan*, if True (default false), will allow the parser to + accept the non-standard floats ``NaN``, ``Infinity``, and ``-Infinity`` + and enable the use of the deprecated *parse_constant*. + + If *use_decimal* is true (default: ``False``) then it implies + parse_float=decimal.Decimal for parity with ``dump``. + + *parse_constant*, if specified, will be + called with one of the following strings: ``'-Infinity'``, + ``'Infinity'``, ``'NaN'``. It is not recommended to use this feature, + as it is rare to parse non-compliant JSON containing these values. + + To use a custom ``JSONDecoder`` subclass, specify it with the ``cls`` + kwarg. NOTE: You should use *object_hook* or *object_pairs_hook* instead + of subclassing whenever possible. + + """ + return loads(fp.read(), + encoding=encoding, cls=cls, object_hook=object_hook, + parse_float=parse_float, parse_int=parse_int, + parse_constant=parse_constant, object_pairs_hook=object_pairs_hook, + use_decimal=use_decimal, allow_nan=allow_nan, **kw) + + +def loads(s, encoding=None, cls=None, object_hook=None, parse_float=None, + parse_int=None, parse_constant=None, object_pairs_hook=None, + use_decimal=False, allow_nan=False, **kw): + """Deserialize ``s`` (a ``str`` or ``unicode`` instance containing a JSON + document) to a Python object. + + *encoding* determines the encoding used to interpret any + :class:`bytes` objects decoded by this instance (``'utf-8'`` by + default). It has no effect when decoding :class:`unicode` objects. + + *object_hook*, if specified, will be called with the result of every + JSON object decoded and its return value will be used in place of the + given :class:`dict`. This can be used to provide custom + deserializations (e.g. to support JSON-RPC class hinting). + + *object_pairs_hook* is an optional function that will be called with + the result of any object literal decode with an ordered list of pairs. + The return value of *object_pairs_hook* will be used instead of the + :class:`dict`. This feature can be used to implement custom decoders + that rely on the order that the key and value pairs are decoded (for + example, :func:`collections.OrderedDict` will remember the order of + insertion). If *object_hook* is also defined, the *object_pairs_hook* + takes priority. + + *parse_float*, if specified, will be called with the string of every + JSON float to be decoded. By default, this is equivalent to + ``float(num_str)``. This can be used to use another datatype or parser + for JSON floats (e.g. :class:`decimal.Decimal`). + + *parse_int*, if specified, will be called with the string of every + JSON int to be decoded. By default, this is equivalent to + ``int(num_str)``. This can be used to use another datatype or parser + for JSON integers (e.g. :class:`float`). + + *allow_nan*, if True (default false), will allow the parser to + accept the non-standard floats ``NaN``, ``Infinity``, and ``-Infinity`` + and enable the use of the deprecated *parse_constant*. + + If *use_decimal* is true (default: ``False``) then it implies + parse_float=decimal.Decimal for parity with ``dump``. + + *parse_constant*, if specified, will be + called with one of the following strings: ``'-Infinity'``, + ``'Infinity'``, ``'NaN'``. It is not recommended to use this feature, + as it is rare to parse non-compliant JSON containing these values. + + To use a custom ``JSONDecoder`` subclass, specify it with the ``cls`` + kwarg. NOTE: You should use *object_hook* or *object_pairs_hook* instead + of subclassing whenever possible. + + """ + if (cls is None and encoding is None and object_hook is None and + parse_int is None and parse_float is None and + parse_constant is None and object_pairs_hook is None + and not use_decimal and not allow_nan and not kw): + return _default_decoder.decode(s) + if cls is None: + cls = JSONDecoder + if object_hook is not None: + kw['object_hook'] = object_hook + if object_pairs_hook is not None: + kw['object_pairs_hook'] = object_pairs_hook + if parse_float is not None: + kw['parse_float'] = parse_float + if parse_int is not None: + kw['parse_int'] = parse_int + if parse_constant is not None: + kw['parse_constant'] = parse_constant + if use_decimal: + if parse_float is not None: + raise TypeError("use_decimal=True implies parse_float=Decimal") + kw['parse_float'] = Decimal + if allow_nan: + kw['allow_nan'] = True + return cls(encoding=encoding, **kw).decode(s) + + +def _toggle_speedups(enabled): + from . import decoder as dec + from . import encoder as enc + from . import scanner as scan + c_make_encoder = _import_c_make_encoder() + if enabled: + dec.scanstring = dec.c_scanstring or dec.py_scanstring + enc.c_make_encoder = c_make_encoder + enc.encode_basestring_ascii = (enc.c_encode_basestring_ascii or + enc.py_encode_basestring_ascii) + scan.make_scanner = scan.c_make_scanner or scan.py_make_scanner + else: + dec.scanstring = dec.py_scanstring + enc.c_make_encoder = None + enc.encode_basestring_ascii = enc.py_encode_basestring_ascii + scan.make_scanner = scan.py_make_scanner + dec.make_scanner = scan.make_scanner + global _default_decoder + _default_decoder = JSONDecoder() + global _default_encoder + _default_encoder = JSONEncoder() + +def simple_first(kv): + """Helper function to pass to item_sort_key to sort simple + elements to the top, then container elements. + """ + return (isinstance(kv[1], (list, dict, tuple)), kv[0]) diff --git a/lib/simplejson/__pycache__/__init__.cpython-310.pyc b/lib/simplejson/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..3c25666 Binary files /dev/null and b/lib/simplejson/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/simplejson/__pycache__/__init__.cpython-39.pyc b/lib/simplejson/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..101c31a Binary files /dev/null and b/lib/simplejson/__pycache__/__init__.cpython-39.pyc differ diff --git a/lib/simplejson/__pycache__/compat.cpython-310.pyc b/lib/simplejson/__pycache__/compat.cpython-310.pyc new file mode 100644 index 0000000..36285f9 Binary files /dev/null and b/lib/simplejson/__pycache__/compat.cpython-310.pyc differ diff --git a/lib/simplejson/__pycache__/compat.cpython-39.pyc b/lib/simplejson/__pycache__/compat.cpython-39.pyc new file mode 100644 index 0000000..6163c30 Binary files /dev/null and b/lib/simplejson/__pycache__/compat.cpython-39.pyc differ diff --git a/lib/simplejson/__pycache__/decoder.cpython-310.pyc b/lib/simplejson/__pycache__/decoder.cpython-310.pyc new file mode 100644 index 0000000..28d0460 Binary files /dev/null and b/lib/simplejson/__pycache__/decoder.cpython-310.pyc differ diff --git a/lib/simplejson/__pycache__/decoder.cpython-39.pyc b/lib/simplejson/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000..49dc5fb Binary files /dev/null and b/lib/simplejson/__pycache__/decoder.cpython-39.pyc differ diff --git a/lib/simplejson/__pycache__/encoder.cpython-310.pyc b/lib/simplejson/__pycache__/encoder.cpython-310.pyc new file mode 100644 index 0000000..0ca1a57 Binary files /dev/null and b/lib/simplejson/__pycache__/encoder.cpython-310.pyc differ diff --git a/lib/simplejson/__pycache__/encoder.cpython-39.pyc b/lib/simplejson/__pycache__/encoder.cpython-39.pyc new file mode 100644 index 0000000..bee2c08 Binary files /dev/null and b/lib/simplejson/__pycache__/encoder.cpython-39.pyc differ diff --git a/lib/simplejson/__pycache__/errors.cpython-310.pyc b/lib/simplejson/__pycache__/errors.cpython-310.pyc new file mode 100644 index 0000000..b9256e7 Binary files /dev/null and b/lib/simplejson/__pycache__/errors.cpython-310.pyc differ diff --git a/lib/simplejson/__pycache__/errors.cpython-39.pyc b/lib/simplejson/__pycache__/errors.cpython-39.pyc new file mode 100644 index 0000000..28852b8 Binary files /dev/null and b/lib/simplejson/__pycache__/errors.cpython-39.pyc differ diff --git a/lib/simplejson/__pycache__/ordered_dict.cpython-39.pyc b/lib/simplejson/__pycache__/ordered_dict.cpython-39.pyc new file mode 100644 index 0000000..883171b Binary files /dev/null and b/lib/simplejson/__pycache__/ordered_dict.cpython-39.pyc differ diff --git a/lib/simplejson/__pycache__/raw_json.cpython-310.pyc b/lib/simplejson/__pycache__/raw_json.cpython-310.pyc new file mode 100644 index 0000000..622d440 Binary files /dev/null and b/lib/simplejson/__pycache__/raw_json.cpython-310.pyc differ diff --git a/lib/simplejson/__pycache__/raw_json.cpython-39.pyc b/lib/simplejson/__pycache__/raw_json.cpython-39.pyc new file mode 100644 index 0000000..8b5399f Binary files /dev/null and b/lib/simplejson/__pycache__/raw_json.cpython-39.pyc differ diff --git a/lib/simplejson/__pycache__/scanner.cpython-310.pyc b/lib/simplejson/__pycache__/scanner.cpython-310.pyc new file mode 100644 index 0000000..6798b1d Binary files /dev/null and b/lib/simplejson/__pycache__/scanner.cpython-310.pyc differ diff --git a/lib/simplejson/__pycache__/scanner.cpython-39.pyc b/lib/simplejson/__pycache__/scanner.cpython-39.pyc new file mode 100644 index 0000000..f382a79 Binary files /dev/null and b/lib/simplejson/__pycache__/scanner.cpython-39.pyc differ diff --git a/lib/simplejson/__pycache__/tool.cpython-39.pyc b/lib/simplejson/__pycache__/tool.cpython-39.pyc new file mode 100644 index 0000000..055935f Binary files /dev/null and b/lib/simplejson/__pycache__/tool.cpython-39.pyc differ diff --git a/lib/simplejson/compat.py b/lib/simplejson/compat.py new file mode 100644 index 0000000..5fc1412 --- /dev/null +++ b/lib/simplejson/compat.py @@ -0,0 +1,34 @@ +"""Python 3 compatibility shims +""" +import sys +if sys.version_info[0] < 3: + PY3 = False + def b(s): + return s + try: + from cStringIO import StringIO + except ImportError: + from StringIO import StringIO + BytesIO = StringIO + text_type = unicode + binary_type = str + string_types = (basestring,) + integer_types = (int, long) + unichr = unichr + reload_module = reload +else: + PY3 = True + if sys.version_info[:2] >= (3, 4): + from importlib import reload as reload_module + else: + from imp import reload as reload_module + def b(s): + return bytes(s, 'latin1') + from io import StringIO, BytesIO + text_type = str + binary_type = bytes + string_types = (str,) + integer_types = (int,) + unichr = chr + +long_type = integer_types[-1] diff --git a/lib/simplejson/decoder.py b/lib/simplejson/decoder.py new file mode 100644 index 0000000..c99a976 --- /dev/null +++ b/lib/simplejson/decoder.py @@ -0,0 +1,416 @@ +"""Implementation of JSONDecoder +""" +from __future__ import absolute_import +import re +import sys +import struct +from .compat import PY3, unichr +from .scanner import make_scanner, JSONDecodeError + +def _import_c_scanstring(): + try: + from ._speedups import scanstring + return scanstring + except ImportError: + return None +c_scanstring = _import_c_scanstring() + +# NOTE (3.1.0): JSONDecodeError may still be imported from this module for +# compatibility, but it was never in the __all__ +__all__ = ['JSONDecoder'] + +FLAGS = re.VERBOSE | re.MULTILINE | re.DOTALL + +def _floatconstants(): + if sys.version_info < (2, 6): + _BYTES = '7FF80000000000007FF0000000000000'.decode('hex') + nan, inf = struct.unpack('>dd', _BYTES) + else: + nan = float('nan') + inf = float('inf') + return nan, inf, -inf + +NaN, PosInf, NegInf = _floatconstants() + +_CONSTANTS = { + '-Infinity': NegInf, + 'Infinity': PosInf, + 'NaN': NaN, +} + +STRINGCHUNK = re.compile(r'(.*?)(["\\\x00-\x1f])', FLAGS) +BACKSLASH = { + '"': u'"', '\\': u'\\', '/': u'/', + 'b': u'\b', 'f': u'\f', 'n': u'\n', 'r': u'\r', 't': u'\t', +} + +DEFAULT_ENCODING = "utf-8" + +if hasattr(sys, 'get_int_max_str_digits'): + bounded_int = int +else: + def bounded_int(s, INT_MAX_STR_DIGITS=4300): + """Backport of the integer string length conversion limitation + + https://docs.python.org/3/library/stdtypes.html#int-max-str-digits + """ + if len(s) > INT_MAX_STR_DIGITS: + raise ValueError("Exceeds the limit (%s) for integer string conversion: value has %s digits" % (INT_MAX_STR_DIGITS, len(s))) + return int(s) + + +def scan_four_digit_hex(s, end, _m=re.compile(r'^[0-9a-fA-F]{4}$').match): + """Scan a four digit hex number from s[end:end + 4] + """ + msg = "Invalid \\uXXXX escape sequence" + esc = s[end:end + 4] + if not _m(esc): + raise JSONDecodeError(msg, s, end - 2) + try: + return int(esc, 16), end + 4 + except ValueError: + raise JSONDecodeError(msg, s, end - 2) + +def py_scanstring(s, end, encoding=None, strict=True, + _b=BACKSLASH, _m=STRINGCHUNK.match, _join=u''.join, + _PY3=PY3, _maxunicode=sys.maxunicode, + _scan_four_digit_hex=scan_four_digit_hex): + """Scan the string s for a JSON string. End is the index of the + character in s after the quote that started the JSON string. + Unescapes all valid JSON string escape sequences and raises ValueError + on attempt to decode an invalid string. If strict is False then literal + control characters are allowed in the string. + + Returns a tuple of the decoded string and the index of the character in s + after the end quote.""" + if encoding is None: + encoding = DEFAULT_ENCODING + chunks = [] + _append = chunks.append + begin = end - 1 + while 1: + chunk = _m(s, end) + if chunk is None: + raise JSONDecodeError( + "Unterminated string starting at", s, begin) + prev_end = end + end = chunk.end() + content, terminator = chunk.groups() + # Content is contains zero or more unescaped string characters + if content: + if not _PY3 and not isinstance(content, unicode): + content = unicode(content, encoding) + _append(content) + # Terminator is the end of string, a literal control character, + # or a backslash denoting that an escape sequence follows + if terminator == '"': + break + elif terminator != '\\': + if strict: + msg = "Invalid control character %r at" + raise JSONDecodeError(msg, s, prev_end) + else: + _append(terminator) + continue + try: + esc = s[end] + except IndexError: + raise JSONDecodeError( + "Unterminated string starting at", s, begin) + # If not a unicode escape sequence, must be in the lookup table + if esc != 'u': + try: + char = _b[esc] + except KeyError: + msg = "Invalid \\X escape sequence %r" + raise JSONDecodeError(msg, s, end) + end += 1 + else: + # Unicode escape sequence + uni, end = _scan_four_digit_hex(s, end + 1) + # Check for surrogate pair on UCS-4 systems + # Note that this will join high/low surrogate pairs + # but will also pass unpaired surrogates through + if (_maxunicode > 65535 and + uni & 0xfc00 == 0xd800 and + s[end:end + 2] == '\\u'): + uni2, end2 = _scan_four_digit_hex(s, end + 2) + if uni2 & 0xfc00 == 0xdc00: + uni = 0x10000 + (((uni - 0xd800) << 10) | + (uni2 - 0xdc00)) + end = end2 + char = unichr(uni) + # Append the unescaped character + _append(char) + return _join(chunks), end + + +# Use speedup if available +scanstring = c_scanstring or py_scanstring + +WHITESPACE = re.compile(r'[ \t\n\r]*', FLAGS) +WHITESPACE_STR = ' \t\n\r' + +def JSONObject(state, encoding, strict, scan_once, object_hook, + object_pairs_hook, memo=None, + _w=WHITESPACE.match, _ws=WHITESPACE_STR): + (s, end) = state + # Backwards compatibility + if memo is None: + memo = {} + memo_get = memo.setdefault + pairs = [] + # Use a slice to prevent IndexError from being raised, the following + # check will raise a more specific ValueError if the string is empty + nextchar = s[end:end + 1] + # Normally we expect nextchar == '"' + if nextchar != '"': + if nextchar in _ws: + end = _w(s, end).end() + nextchar = s[end:end + 1] + # Trivial empty object + if nextchar == '}': + if object_pairs_hook is not None: + result = object_pairs_hook(pairs) + return result, end + 1 + pairs = {} + if object_hook is not None: + pairs = object_hook(pairs) + return pairs, end + 1 + elif nextchar != '"': + raise JSONDecodeError( + "Expecting property name enclosed in double quotes or '}'", + s, end) + end += 1 + while True: + key, end = scanstring(s, end, encoding, strict) + key = memo_get(key, key) + + # To skip some function call overhead we optimize the fast paths where + # the JSON key separator is ": " or just ":". + if s[end:end + 1] != ':': + end = _w(s, end).end() + if s[end:end + 1] != ':': + raise JSONDecodeError("Expecting ':' delimiter", s, end) + + end += 1 + + try: + if s[end] in _ws: + end += 1 + if s[end] in _ws: + end = _w(s, end + 1).end() + except IndexError: + pass + + value, end = scan_once(s, end) + pairs.append((key, value)) + + try: + nextchar = s[end] + if nextchar in _ws: + end = _w(s, end + 1).end() + nextchar = s[end] + except IndexError: + nextchar = '' + end += 1 + + if nextchar == '}': + break + elif nextchar != ',': + raise JSONDecodeError("Expecting ',' delimiter or '}'", s, end - 1) + + try: + nextchar = s[end] + if nextchar in _ws: + end += 1 + nextchar = s[end] + if nextchar in _ws: + end = _w(s, end + 1).end() + nextchar = s[end] + except IndexError: + nextchar = '' + + end += 1 + if nextchar != '"': + raise JSONDecodeError( + "Expecting property name enclosed in double quotes", + s, end - 1) + + if object_pairs_hook is not None: + result = object_pairs_hook(pairs) + return result, end + pairs = dict(pairs) + if object_hook is not None: + pairs = object_hook(pairs) + return pairs, end + +def JSONArray(state, scan_once, _w=WHITESPACE.match, _ws=WHITESPACE_STR): + (s, end) = state + values = [] + nextchar = s[end:end + 1] + if nextchar in _ws: + end = _w(s, end + 1).end() + nextchar = s[end:end + 1] + # Look-ahead for trivial empty array + if nextchar == ']': + return values, end + 1 + elif nextchar == '': + raise JSONDecodeError("Expecting value or ']'", s, end) + _append = values.append + while True: + value, end = scan_once(s, end) + _append(value) + nextchar = s[end:end + 1] + if nextchar in _ws: + end = _w(s, end + 1).end() + nextchar = s[end:end + 1] + end += 1 + if nextchar == ']': + break + elif nextchar != ',': + raise JSONDecodeError("Expecting ',' delimiter or ']'", s, end - 1) + + try: + if s[end] in _ws: + end += 1 + if s[end] in _ws: + end = _w(s, end + 1).end() + except IndexError: + pass + + return values, end + +class JSONDecoder(object): + """Simple JSON decoder + + Performs the following translations in decoding by default: + + +---------------+-------------------+ + | JSON | Python | + +===============+===================+ + | object | dict | + +---------------+-------------------+ + | array | list | + +---------------+-------------------+ + | string | str, unicode | + +---------------+-------------------+ + | number (int) | int, long | + +---------------+-------------------+ + | number (real) | float | + +---------------+-------------------+ + | true | True | + +---------------+-------------------+ + | false | False | + +---------------+-------------------+ + | null | None | + +---------------+-------------------+ + + When allow_nan=True, it also understands + ``NaN``, ``Infinity``, and ``-Infinity`` as + their corresponding ``float`` values, which is outside the JSON spec. + + """ + + def __init__(self, encoding=None, object_hook=None, parse_float=None, + parse_int=None, parse_constant=None, strict=True, + object_pairs_hook=None, allow_nan=False): + """ + *encoding* determines the encoding used to interpret any + :class:`str` objects decoded by this instance (``'utf-8'`` by + default). It has no effect when decoding :class:`unicode` objects. + + Note that currently only encodings that are a superset of ASCII work, + strings of other encodings should be passed in as :class:`unicode`. + + *object_hook*, if specified, will be called with the result of every + JSON object decoded and its return value will be used in place of the + given :class:`dict`. This can be used to provide custom + deserializations (e.g. to support JSON-RPC class hinting). + + *object_pairs_hook* is an optional function that will be called with + the result of any object literal decode with an ordered list of pairs. + The return value of *object_pairs_hook* will be used instead of the + :class:`dict`. This feature can be used to implement custom decoders + that rely on the order that the key and value pairs are decoded (for + example, :func:`collections.OrderedDict` will remember the order of + insertion). If *object_hook* is also defined, the *object_pairs_hook* + takes priority. + + *parse_float*, if specified, will be called with the string of every + JSON float to be decoded. By default, this is equivalent to + ``float(num_str)``. This can be used to use another datatype or parser + for JSON floats (e.g. :class:`decimal.Decimal`). + + *parse_int*, if specified, will be called with the string of every + JSON int to be decoded. By default, this is equivalent to + ``int(num_str)``. This can be used to use another datatype or parser + for JSON integers (e.g. :class:`float`). + + *allow_nan*, if True (default false), will allow the parser to + accept the non-standard floats ``NaN``, ``Infinity``, and ``-Infinity``. + + *parse_constant*, if specified, will be + called with one of the following strings: ``'-Infinity'``, + ``'Infinity'``, ``'NaN'``. It is not recommended to use this feature, + as it is rare to parse non-compliant JSON containing these values. + + *strict* controls the parser's behavior when it encounters an + invalid control character in a string. The default setting of + ``True`` means that unescaped control characters are parse errors, if + ``False`` then control characters will be allowed in strings. + + """ + if encoding is None: + encoding = DEFAULT_ENCODING + self.encoding = encoding + self.object_hook = object_hook + self.object_pairs_hook = object_pairs_hook + self.parse_float = parse_float or float + self.parse_int = parse_int or bounded_int + self.parse_constant = parse_constant or (allow_nan and _CONSTANTS.__getitem__ or None) + self.strict = strict + self.parse_object = JSONObject + self.parse_array = JSONArray + self.parse_string = scanstring + self.memo = {} + self.scan_once = make_scanner(self) + + def decode(self, s, _w=WHITESPACE.match, _PY3=PY3): + """Return the Python representation of ``s`` (a ``str`` or ``unicode`` + instance containing a JSON document) + + """ + if _PY3 and isinstance(s, bytes): + s = str(s, self.encoding) + obj, end = self.raw_decode(s) + end = _w(s, end).end() + if end != len(s): + raise JSONDecodeError("Extra data", s, end, len(s)) + return obj + + def raw_decode(self, s, idx=0, _w=WHITESPACE.match, _PY3=PY3): + """Decode a JSON document from ``s`` (a ``str`` or ``unicode`` + beginning with a JSON document) and return a 2-tuple of the Python + representation and the index in ``s`` where the document ended. + Optionally, ``idx`` can be used to specify an offset in ``s`` where + the JSON document begins. + + This can be used to decode a JSON document from a string that may + have extraneous data at the end. + + """ + if idx < 0: + # Ensure that raw_decode bails on negative indexes, the regex + # would otherwise mask this behavior. #98 + raise JSONDecodeError('Expecting value', s, idx) + if _PY3 and not isinstance(s, str): + raise TypeError("Input string must be text, not bytes") + # strip UTF-8 bom + if len(s) > idx: + ord0 = ord(s[idx]) + if ord0 == 0xfeff: + idx += 1 + elif ord0 == 0xef and s[idx:idx + 3] == '\xef\xbb\xbf': + idx += 3 + return self.scan_once(s, idx=_w(s, idx).end()) diff --git a/lib/simplejson/encoder.py b/lib/simplejson/encoder.py new file mode 100644 index 0000000..661ff36 --- /dev/null +++ b/lib/simplejson/encoder.py @@ -0,0 +1,740 @@ +"""Implementation of JSONEncoder +""" +from __future__ import absolute_import +import re +from operator import itemgetter +# Do not import Decimal directly to avoid reload issues +import decimal +from .compat import binary_type, text_type, string_types, integer_types, PY3 +def _import_speedups(): + try: + from . import _speedups + return _speedups.encode_basestring_ascii, _speedups.make_encoder + except ImportError: + return None, None +c_encode_basestring_ascii, c_make_encoder = _import_speedups() + +from .decoder import PosInf +from .raw_json import RawJSON + +ESCAPE = re.compile(r'[\x00-\x1f\\"]') +ESCAPE_ASCII = re.compile(r'([\\"]|[^\ -~])') +HAS_UTF8 = re.compile(r'[\x80-\xff]') +ESCAPE_DCT = { + '\\': '\\\\', + '"': '\\"', + '\b': '\\b', + '\f': '\\f', + '\n': '\\n', + '\r': '\\r', + '\t': '\\t', +} +for i in range(0x20): + #ESCAPE_DCT.setdefault(chr(i), '\\u{0:04x}'.format(i)) + ESCAPE_DCT.setdefault(chr(i), '\\u%04x' % (i,)) +del i + +FLOAT_REPR = repr + +def encode_basestring(s, _PY3=PY3, _q=u'"'): + """Return a JSON representation of a Python string + + """ + if _PY3: + if isinstance(s, bytes): + s = str(s, 'utf-8') + elif type(s) is not str: + # convert an str subclass instance to exact str + # raise a TypeError otherwise + s = str.__str__(s) + else: + if isinstance(s, str) and HAS_UTF8.search(s) is not None: + s = unicode(s, 'utf-8') + elif type(s) not in (str, unicode): + # convert an str subclass instance to exact str + # convert a unicode subclass instance to exact unicode + # raise a TypeError otherwise + if isinstance(s, str): + s = str.__str__(s) + else: + s = unicode.__getnewargs__(s)[0] + def replace(match): + return ESCAPE_DCT[match.group(0)] + return _q + ESCAPE.sub(replace, s) + _q + + +def py_encode_basestring_ascii(s, _PY3=PY3): + """Return an ASCII-only JSON representation of a Python string + + """ + if _PY3: + if isinstance(s, bytes): + s = str(s, 'utf-8') + elif type(s) is not str: + # convert an str subclass instance to exact str + # raise a TypeError otherwise + s = str.__str__(s) + else: + if isinstance(s, str) and HAS_UTF8.search(s) is not None: + s = unicode(s, 'utf-8') + elif type(s) not in (str, unicode): + # convert an str subclass instance to exact str + # convert a unicode subclass instance to exact unicode + # raise a TypeError otherwise + if isinstance(s, str): + s = str.__str__(s) + else: + s = unicode.__getnewargs__(s)[0] + def replace(match): + s = match.group(0) + try: + return ESCAPE_DCT[s] + except KeyError: + n = ord(s) + if n < 0x10000: + #return '\\u{0:04x}'.format(n) + return '\\u%04x' % (n,) + else: + # surrogate pair + n -= 0x10000 + s1 = 0xd800 | ((n >> 10) & 0x3ff) + s2 = 0xdc00 | (n & 0x3ff) + #return '\\u{0:04x}\\u{1:04x}'.format(s1, s2) + return '\\u%04x\\u%04x' % (s1, s2) + return '"' + str(ESCAPE_ASCII.sub(replace, s)) + '"' + + +encode_basestring_ascii = ( + c_encode_basestring_ascii or py_encode_basestring_ascii) + +class JSONEncoder(object): + """Extensible JSON encoder for Python data structures. + + Supports the following objects and types by default: + + +-------------------+---------------+ + | Python | JSON | + +===================+===============+ + | dict, namedtuple | object | + +-------------------+---------------+ + | list, tuple | array | + +-------------------+---------------+ + | str, unicode | string | + +-------------------+---------------+ + | int, long, float | number | + +-------------------+---------------+ + | True | true | + +-------------------+---------------+ + | False | false | + +-------------------+---------------+ + | None | null | + +-------------------+---------------+ + + To extend this to recognize other objects, subclass and implement a + ``.default()`` method with another method that returns a serializable + object for ``o`` if possible, otherwise it should call the superclass + implementation (to raise ``TypeError``). + + """ + item_separator = ', ' + key_separator = ': ' + + def __init__(self, skipkeys=False, ensure_ascii=True, + check_circular=True, allow_nan=False, sort_keys=False, + indent=None, separators=None, encoding='utf-8', default=None, + use_decimal=True, namedtuple_as_object=True, + tuple_as_array=True, bigint_as_string=False, + item_sort_key=None, for_json=False, ignore_nan=False, + int_as_string_bitcount=None, iterable_as_array=False): + """Constructor for JSONEncoder, with sensible defaults. + + If skipkeys is false, then it is a TypeError to attempt + encoding of keys that are not str, int, long, float or None. If + skipkeys is True, such items are simply skipped. + + If ensure_ascii is true, the output is guaranteed to be str + objects with all incoming unicode characters escaped. If + ensure_ascii is false, the output will be unicode object. + + If check_circular is true, then lists, dicts, and custom encoded + objects will be checked for circular references during encoding to + prevent an infinite recursion (which would cause an OverflowError). + Otherwise, no such check takes place. + + If allow_nan is true (default: False), then out of range float + values (nan, inf, -inf) will be serialized to + their JavaScript equivalents (NaN, Infinity, -Infinity) + instead of raising a ValueError. See + ignore_nan for ECMA-262 compliant behavior. + + If sort_keys is true, then the output of dictionaries will be + sorted by key; this is useful for regression tests to ensure + that JSON serializations can be compared on a day-to-day basis. + + If indent is a string, then JSON array elements and object members + will be pretty-printed with a newline followed by that string repeated + for each level of nesting. ``None`` (the default) selects the most compact + representation without any newlines. For backwards compatibility with + versions of simplejson earlier than 2.1.0, an integer is also accepted + and is converted to a string with that many spaces. + + If specified, separators should be an (item_separator, key_separator) + tuple. The default is (', ', ': ') if *indent* is ``None`` and + (',', ': ') otherwise. To get the most compact JSON representation, + you should specify (',', ':') to eliminate whitespace. + + If specified, default is a function that gets called for objects + that can't otherwise be serialized. It should return a JSON encodable + version of the object or raise a ``TypeError``. + + If encoding is not None, then all input strings will be + transformed into unicode using that encoding prior to JSON-encoding. + The default is UTF-8. + + If use_decimal is true (default: ``True``), ``decimal.Decimal`` will + be supported directly by the encoder. For the inverse, decode JSON + with ``parse_float=decimal.Decimal``. + + If namedtuple_as_object is true (the default), objects with + ``_asdict()`` methods will be encoded as JSON objects. + + If tuple_as_array is true (the default), tuple (and subclasses) will + be encoded as JSON arrays. + + If *iterable_as_array* is true (default: ``False``), + any object not in the above table that implements ``__iter__()`` + will be encoded as a JSON array. + + If bigint_as_string is true (not the default), ints 2**53 and higher + or lower than -2**53 will be encoded as strings. This is to avoid the + rounding that happens in Javascript otherwise. + + If int_as_string_bitcount is a positive number (n), then int of size + greater than or equal to 2**n or lower than or equal to -2**n will be + encoded as strings. + + If specified, item_sort_key is a callable used to sort the items in + each dictionary. This is useful if you want to sort items other than + in alphabetical order by key. + + If for_json is true (not the default), objects with a ``for_json()`` + method will use the return value of that method for encoding as JSON + instead of the object. + + If *ignore_nan* is true (default: ``False``), then out of range + :class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized + as ``null`` in compliance with the ECMA-262 specification. If true, + this will override *allow_nan*. + + """ + + self.skipkeys = skipkeys + self.ensure_ascii = ensure_ascii + self.check_circular = check_circular + self.allow_nan = allow_nan + self.sort_keys = sort_keys + self.use_decimal = use_decimal + self.namedtuple_as_object = namedtuple_as_object + self.tuple_as_array = tuple_as_array + self.iterable_as_array = iterable_as_array + self.bigint_as_string = bigint_as_string + self.item_sort_key = item_sort_key + self.for_json = for_json + self.ignore_nan = ignore_nan + self.int_as_string_bitcount = int_as_string_bitcount + if indent is not None and not isinstance(indent, string_types): + indent = indent * ' ' + self.indent = indent + if separators is not None: + self.item_separator, self.key_separator = separators + elif indent is not None: + self.item_separator = ',' + if default is not None: + self.default = default + self.encoding = encoding + + def default(self, o): + """Implement this method in a subclass such that it returns + a serializable object for ``o``, or calls the base implementation + (to raise a ``TypeError``). + + For example, to support arbitrary iterators, you could + implement default like this:: + + def default(self, o): + try: + iterable = iter(o) + except TypeError: + pass + else: + return list(iterable) + return JSONEncoder.default(self, o) + + """ + raise TypeError('Object of type %s is not JSON serializable' % + o.__class__.__name__) + + def encode(self, o): + """Return a JSON string representation of a Python data structure. + + >>> from simplejson import JSONEncoder + >>> JSONEncoder().encode({"foo": ["bar", "baz"]}) + '{"foo": ["bar", "baz"]}' + + """ + # This is for extremely simple cases and benchmarks. + if isinstance(o, binary_type): + _encoding = self.encoding + if (_encoding is not None and not (_encoding == 'utf-8')): + o = text_type(o, _encoding) + if isinstance(o, string_types): + if self.ensure_ascii: + return encode_basestring_ascii(o) + else: + return encode_basestring(o) + # This doesn't pass the iterator directly to ''.join() because the + # exceptions aren't as detailed. The list call should be roughly + # equivalent to the PySequence_Fast that ''.join() would do. + chunks = self.iterencode(o) + if not isinstance(chunks, (list, tuple)): + chunks = list(chunks) + if self.ensure_ascii: + return ''.join(chunks) + else: + return u''.join(chunks) + + def iterencode(self, o): + """Encode the given object and yield each string + representation as available. + + For example:: + + for chunk in JSONEncoder().iterencode(bigobject): + mysocket.write(chunk) + + """ + if self.check_circular: + markers = {} + else: + markers = None + if self.ensure_ascii: + _encoder = encode_basestring_ascii + else: + _encoder = encode_basestring + if self.encoding != 'utf-8' and self.encoding is not None: + def _encoder(o, _orig_encoder=_encoder, _encoding=self.encoding): + if isinstance(o, binary_type): + o = text_type(o, _encoding) + return _orig_encoder(o) + + def floatstr(o, allow_nan=self.allow_nan, ignore_nan=self.ignore_nan, + _repr=FLOAT_REPR, _inf=PosInf, _neginf=-PosInf): + # Check for specials. Note that this type of test is processor + # and/or platform-specific, so do tests which don't depend on + # the internals. + + if o != o: + text = 'NaN' + elif o == _inf: + text = 'Infinity' + elif o == _neginf: + text = '-Infinity' + else: + if type(o) != float: + # See #118, do not trust custom str/repr + o = float(o) + return _repr(o) + + if ignore_nan: + text = 'null' + elif not allow_nan: + raise ValueError( + "Out of range float values are not JSON compliant: " + + repr(o)) + + return text + + key_memo = {} + int_as_string_bitcount = ( + 53 if self.bigint_as_string else self.int_as_string_bitcount) + if (c_make_encoder is not None and self.indent is None): + _iterencode = c_make_encoder( + markers, self.default, _encoder, self.indent, + self.key_separator, self.item_separator, self.sort_keys, + self.skipkeys, self.allow_nan, key_memo, self.use_decimal, + self.namedtuple_as_object, self.tuple_as_array, + int_as_string_bitcount, + self.item_sort_key, self.encoding, self.for_json, + self.ignore_nan, decimal.Decimal, self.iterable_as_array) + else: + _iterencode = _make_iterencode( + markers, self.default, _encoder, self.indent, floatstr, + self.key_separator, self.item_separator, self.sort_keys, + self.skipkeys, self.use_decimal, + self.namedtuple_as_object, self.tuple_as_array, + int_as_string_bitcount, + self.item_sort_key, self.encoding, self.for_json, + self.iterable_as_array, Decimal=decimal.Decimal) + try: + return _iterencode(o, 0) + finally: + key_memo.clear() + + +class JSONEncoderForHTML(JSONEncoder): + """An encoder that produces JSON safe to embed in HTML. + + To embed JSON content in, say, a script tag on a web page, the + characters &, < and > should be escaped. They cannot be escaped + with the usual entities (e.g. &) because they are not expanded + within ' + self.assertEqual( + r'"\u003c/script\u003e\u003cscript\u003e' + r'alert(\"gotcha\")\u003c/script\u003e"', + self.encoder.encode(bad_string)) + self.assertEqual( + bad_string, self.decoder.decode( + self.encoder.encode(bad_string))) diff --git a/lib/simplejson/tests/test_errors.py b/lib/simplejson/tests/test_errors.py new file mode 100644 index 0000000..d573825 --- /dev/null +++ b/lib/simplejson/tests/test_errors.py @@ -0,0 +1,68 @@ +import sys, pickle +from unittest import TestCase + +import simplejson as json +from simplejson.compat import text_type, b + +class TestErrors(TestCase): + def test_string_keys_error(self): + data = [{'a': 'A', 'b': (2, 4), 'c': 3.0, ('d',): 'D tuple'}] + try: + json.dumps(data) + except TypeError: + err = sys.exc_info()[1] + else: + self.fail('Expected TypeError') + self.assertEqual(str(err), + 'keys must be str, int, float, bool or None, not tuple') + + def test_not_serializable(self): + try: + json.dumps(json) + except TypeError: + err = sys.exc_info()[1] + else: + self.fail('Expected TypeError') + self.assertEqual(str(err), + 'Object of type module is not JSON serializable') + + def test_decode_error(self): + err = None + try: + json.loads('{}\na\nb') + except json.JSONDecodeError: + err = sys.exc_info()[1] + else: + self.fail('Expected JSONDecodeError') + self.assertEqual(err.lineno, 2) + self.assertEqual(err.colno, 1) + self.assertEqual(err.endlineno, 3) + self.assertEqual(err.endcolno, 2) + + def test_scan_error(self): + err = None + for t in (text_type, b): + try: + json.loads(t('{"asdf": "')) + except json.JSONDecodeError: + err = sys.exc_info()[1] + else: + self.fail('Expected JSONDecodeError') + self.assertEqual(err.lineno, 1) + self.assertEqual(err.colno, 10) + + def test_error_is_pickable(self): + err = None + try: + json.loads('{}\na\nb') + except json.JSONDecodeError: + err = sys.exc_info()[1] + else: + self.fail('Expected JSONDecodeError') + s = pickle.dumps(err) + e = pickle.loads(s) + + self.assertEqual(err.msg, e.msg) + self.assertEqual(err.doc, e.doc) + self.assertEqual(err.pos, e.pos) + self.assertEqual(err.end, e.end) diff --git a/lib/simplejson/tests/test_fail.py b/lib/simplejson/tests/test_fail.py new file mode 100644 index 0000000..5f9a8f6 --- /dev/null +++ b/lib/simplejson/tests/test_fail.py @@ -0,0 +1,178 @@ +import sys +from unittest import TestCase + +import simplejson as json + +# 2007-10-05 +JSONDOCS = [ + # http://json.org/JSON_checker/test/fail1.json + '"A JSON payload should be an object or array, not a string."', + # http://json.org/JSON_checker/test/fail2.json + '["Unclosed array"', + # http://json.org/JSON_checker/test/fail3.json + '{unquoted_key: "keys must be quoted"}', + # http://json.org/JSON_checker/test/fail4.json + '["extra comma",]', + # http://json.org/JSON_checker/test/fail5.json + '["double extra comma",,]', + # http://json.org/JSON_checker/test/fail6.json + '[ , "<-- missing value"]', + # http://json.org/JSON_checker/test/fail7.json + '["Comma after the close"],', + # http://json.org/JSON_checker/test/fail8.json + '["Extra close"]]', + # http://json.org/JSON_checker/test/fail9.json + '{"Extra comma": true,}', + # http://json.org/JSON_checker/test/fail10.json + '{"Extra value after close": true} "misplaced quoted value"', + # http://json.org/JSON_checker/test/fail11.json + '{"Illegal expression": 1 + 2}', + # http://json.org/JSON_checker/test/fail12.json + '{"Illegal invocation": alert()}', + # http://json.org/JSON_checker/test/fail13.json + '{"Numbers cannot have leading zeroes": 013}', + # http://json.org/JSON_checker/test/fail14.json + '{"Numbers cannot be hex": 0x14}', + # http://json.org/JSON_checker/test/fail15.json + '["Illegal backslash escape: \\x15"]', + # http://json.org/JSON_checker/test/fail16.json + '[\\naked]', + # http://json.org/JSON_checker/test/fail17.json + '["Illegal backslash escape: \\017"]', + # http://json.org/JSON_checker/test/fail18.json + '[[[[[[[[[[[[[[[[[[[["Too deep"]]]]]]]]]]]]]]]]]]]]', + # http://json.org/JSON_checker/test/fail19.json + '{"Missing colon" null}', + # http://json.org/JSON_checker/test/fail20.json + '{"Double colon":: null}', + # http://json.org/JSON_checker/test/fail21.json + '{"Comma instead of colon", null}', + # http://json.org/JSON_checker/test/fail22.json + '["Colon instead of comma": false]', + # http://json.org/JSON_checker/test/fail23.json + '["Bad value", truth]', + # http://json.org/JSON_checker/test/fail24.json + "['single quote']", + # http://json.org/JSON_checker/test/fail25.json + '["\ttab\tcharacter\tin\tstring\t"]', + # http://json.org/JSON_checker/test/fail26.json + '["tab\\ character\\ in\\ string\\ "]', + # http://json.org/JSON_checker/test/fail27.json + '["line\nbreak"]', + # http://json.org/JSON_checker/test/fail28.json + '["line\\\nbreak"]', + # http://json.org/JSON_checker/test/fail29.json + '[0e]', + # http://json.org/JSON_checker/test/fail30.json + '[0e+]', + # http://json.org/JSON_checker/test/fail31.json + '[0e+-1]', + # http://json.org/JSON_checker/test/fail32.json + '{"Comma instead if closing brace": true,', + # http://json.org/JSON_checker/test/fail33.json + '["mismatch"}', + # http://code.google.com/p/simplejson/issues/detail?id=3 + u'["A\u001FZ control characters in string"]', + # misc based on coverage + '{', + '{]', + '{"foo": "bar"]', + '{"foo": "bar"', + 'nul', + 'nulx', + '-', + '-x', + '-e', + '-e0', + '-Infinite', + '-Inf', + 'Infinit', + 'Infinite', + 'NaM', + 'NuN', + 'falsy', + 'fal', + 'trug', + 'tru', + '1e', + '1ex', + '1e-', + '1e-x', +] + +SKIPS = { + 1: "why not have a string payload?", + 18: "spec doesn't specify any nesting limitations", +} + +class TestFail(TestCase): + def test_failures(self): + for idx, doc in enumerate(JSONDOCS): + idx = idx + 1 + if idx in SKIPS: + json.loads(doc) + continue + try: + json.loads(doc) + except json.JSONDecodeError: + pass + else: + self.fail("Expected failure for fail%d.json: %r" % (idx, doc)) + + def test_array_decoder_issue46(self): + # http://code.google.com/p/simplejson/issues/detail?id=46 + for doc in [u'[,]', '[,]']: + try: + json.loads(doc) + except json.JSONDecodeError: + e = sys.exc_info()[1] + self.assertEqual(e.pos, 1) + self.assertEqual(e.lineno, 1) + self.assertEqual(e.colno, 2) + except Exception: + e = sys.exc_info()[1] + self.fail("Unexpected exception raised %r %s" % (e, e)) + else: + self.fail("Unexpected success parsing '[,]'") + + def test_truncated_input(self): + test_cases = [ + ('', 'Expecting value', 0), + ('[', "Expecting value or ']'", 1), + ('[42', "Expecting ',' delimiter", 3), + ('[42,', 'Expecting value', 4), + ('["', 'Unterminated string starting at', 1), + ('["spam', 'Unterminated string starting at', 1), + ('["spam"', "Expecting ',' delimiter", 7), + ('["spam",', 'Expecting value', 8), + ('{', "Expecting property name enclosed in double quotes or '}'", 1), + ('{"', 'Unterminated string starting at', 1), + ('{"spam', 'Unterminated string starting at', 1), + ('{"spam"', "Expecting ':' delimiter", 7), + ('{"spam":', 'Expecting value', 8), + ('{"spam":42', "Expecting ',' delimiter", 10), + ('{"spam":42,', 'Expecting property name enclosed in double quotes', + 11), + ('"', 'Unterminated string starting at', 0), + ('"spam', 'Unterminated string starting at', 0), + ('[,', "Expecting value", 1), + ('--', 'Expecting value', 0), + ('"\x18d', "Invalid control character %r", 1), + ] + for data, msg, idx in test_cases: + try: + json.loads(data) + except json.JSONDecodeError: + e = sys.exc_info()[1] + self.assertEqual( + e.msg[:len(msg)], + msg, + "%r doesn't start with %r for %r" % (e.msg, msg, data)) + self.assertEqual( + e.pos, idx, + "pos %r != %r for %r" % (e.pos, idx, data)) + except Exception: + e = sys.exc_info()[1] + self.fail("Unexpected exception raised %r %s" % (e, e)) + else: + self.fail("Unexpected success parsing '%r'" % (data,)) diff --git a/lib/simplejson/tests/test_float.py b/lib/simplejson/tests/test_float.py new file mode 100644 index 0000000..a977969 --- /dev/null +++ b/lib/simplejson/tests/test_float.py @@ -0,0 +1,38 @@ +import math +from unittest import TestCase +from simplejson.compat import long_type, text_type +import simplejson as json +from simplejson.decoder import NaN, PosInf, NegInf + +class TestFloat(TestCase): + def test_degenerates_allow(self): + for inf in (PosInf, NegInf): + self.assertEqual(json.loads(json.dumps(inf, allow_nan=True), allow_nan=True), inf) + # Python 2.5 doesn't have math.isnan + nan = json.loads(json.dumps(NaN, allow_nan=True), allow_nan=True) + self.assertTrue((0 + nan) != nan) + + def test_degenerates_ignore(self): + for f in (PosInf, NegInf, NaN): + self.assertEqual(json.loads(json.dumps(f, ignore_nan=True)), None) + + def test_degenerates_deny(self): + for f in (PosInf, NegInf, NaN): + self.assertRaises(ValueError, json.dumps, f, allow_nan=False) + for s in ('Infinity', '-Infinity', 'NaN'): + self.assertRaises(ValueError, json.loads, s, allow_nan=False) + self.assertRaises(ValueError, json.loads, s) + + def test_floats(self): + for num in [1617161771.7650001, math.pi, math.pi**100, + math.pi**-100, 3.1]: + self.assertEqual(float(json.dumps(num)), num) + self.assertEqual(json.loads(json.dumps(num)), num) + self.assertEqual(json.loads(text_type(json.dumps(num))), num) + + def test_ints(self): + for num in [1, long_type(1), 1<<32, 1<<64]: + self.assertEqual(json.dumps(num), str(num)) + self.assertEqual(int(json.dumps(num)), num) + self.assertEqual(json.loads(json.dumps(num)), num) + self.assertEqual(json.loads(text_type(json.dumps(num))), num) diff --git a/lib/simplejson/tests/test_for_json.py b/lib/simplejson/tests/test_for_json.py new file mode 100644 index 0000000..b791b88 --- /dev/null +++ b/lib/simplejson/tests/test_for_json.py @@ -0,0 +1,97 @@ +import unittest +import simplejson as json + + +class ForJson(object): + def for_json(self): + return {'for_json': 1} + + +class NestedForJson(object): + def for_json(self): + return {'nested': ForJson()} + + +class ForJsonList(object): + def for_json(self): + return ['list'] + + +class DictForJson(dict): + def for_json(self): + return {'alpha': 1} + + +class ListForJson(list): + def for_json(self): + return ['list'] + + +class TestForJson(unittest.TestCase): + def assertRoundTrip(self, obj, other, for_json=True): + if for_json is None: + # None will use the default + s = json.dumps(obj) + else: + s = json.dumps(obj, for_json=for_json) + self.assertEqual( + json.loads(s), + other) + + def test_for_json_encodes_stand_alone_object(self): + self.assertRoundTrip( + ForJson(), + ForJson().for_json()) + + def test_for_json_encodes_object_nested_in_dict(self): + self.assertRoundTrip( + {'hooray': ForJson()}, + {'hooray': ForJson().for_json()}) + + def test_for_json_encodes_object_nested_in_list_within_dict(self): + self.assertRoundTrip( + {'list': [0, ForJson(), 2, 3]}, + {'list': [0, ForJson().for_json(), 2, 3]}) + + def test_for_json_encodes_object_nested_within_object(self): + self.assertRoundTrip( + NestedForJson(), + {'nested': {'for_json': 1}}) + + def test_for_json_encodes_list(self): + self.assertRoundTrip( + ForJsonList(), + ForJsonList().for_json()) + + def test_for_json_encodes_list_within_object(self): + self.assertRoundTrip( + {'nested': ForJsonList()}, + {'nested': ForJsonList().for_json()}) + + def test_for_json_encodes_dict_subclass(self): + self.assertRoundTrip( + DictForJson(a=1), + DictForJson(a=1).for_json()) + + def test_for_json_encodes_list_subclass(self): + self.assertRoundTrip( + ListForJson(['l']), + ListForJson(['l']).for_json()) + + def test_for_json_ignored_if_not_true_with_dict_subclass(self): + for for_json in (None, False): + self.assertRoundTrip( + DictForJson(a=1), + {'a': 1}, + for_json=for_json) + + def test_for_json_ignored_if_not_true_with_list_subclass(self): + for for_json in (None, False): + self.assertRoundTrip( + ListForJson(['l']), + ['l'], + for_json=for_json) + + def test_raises_typeerror_if_for_json_not_true_with_object(self): + self.assertRaises(TypeError, json.dumps, ForJson()) + self.assertRaises(TypeError, json.dumps, ForJson(), for_json=False) diff --git a/lib/simplejson/tests/test_indent.py b/lib/simplejson/tests/test_indent.py new file mode 100644 index 0000000..cea25a5 --- /dev/null +++ b/lib/simplejson/tests/test_indent.py @@ -0,0 +1,86 @@ +from unittest import TestCase +import textwrap + +import simplejson as json +from simplejson.compat import StringIO + +class TestIndent(TestCase): + def test_indent(self): + h = [['blorpie'], ['whoops'], [], 'd-shtaeou', 'd-nthiouh', + 'i-vhbjkhnth', + {'nifty': 87}, {'field': 'yes', 'morefield': False} ] + + expect = textwrap.dedent("""\ + [ + \t[ + \t\t"blorpie" + \t], + \t[ + \t\t"whoops" + \t], + \t[], + \t"d-shtaeou", + \t"d-nthiouh", + \t"i-vhbjkhnth", + \t{ + \t\t"nifty": 87 + \t}, + \t{ + \t\t"field": "yes", + \t\t"morefield": false + \t} + ]""") + + + d1 = json.dumps(h) + d2 = json.dumps(h, indent='\t', sort_keys=True, separators=(',', ': ')) + d3 = json.dumps(h, indent=' ', sort_keys=True, separators=(',', ': ')) + d4 = json.dumps(h, indent=2, sort_keys=True, separators=(',', ': ')) + + h1 = json.loads(d1) + h2 = json.loads(d2) + h3 = json.loads(d3) + h4 = json.loads(d4) + + self.assertEqual(h1, h) + self.assertEqual(h2, h) + self.assertEqual(h3, h) + self.assertEqual(h4, h) + self.assertEqual(d3, expect.replace('\t', ' ')) + self.assertEqual(d4, expect.replace('\t', ' ')) + # NOTE: Python 2.4 textwrap.dedent converts tabs to spaces, + # so the following is expected to fail. Python 2.4 is not a + # supported platform in simplejson 2.1.0+. + self.assertEqual(d2, expect) + + def test_indent0(self): + h = {3: 1} + def check(indent, expected): + d1 = json.dumps(h, indent=indent) + self.assertEqual(d1, expected) + + sio = StringIO() + json.dump(h, sio, indent=indent) + self.assertEqual(sio.getvalue(), expected) + + # indent=0 should emit newlines + check(0, '{\n"3": 1\n}') + # indent=None is more compact + check(None, '{"3": 1}') + + def test_separators(self): + lst = [1,2,3,4] + expect = '[\n1,\n2,\n3,\n4\n]' + expect_spaces = '[\n1, \n2, \n3, \n4\n]' + # Ensure that separators still works + self.assertEqual( + expect_spaces, + json.dumps(lst, indent=0, separators=(', ', ': '))) + # Force the new defaults + self.assertEqual( + expect, + json.dumps(lst, indent=0, separators=(',', ': '))) + # Added in 2.1.4 + self.assertEqual( + expect, + json.dumps(lst, indent=0)) diff --git a/lib/simplejson/tests/test_item_sort_key.py b/lib/simplejson/tests/test_item_sort_key.py new file mode 100644 index 0000000..98971b8 --- /dev/null +++ b/lib/simplejson/tests/test_item_sort_key.py @@ -0,0 +1,27 @@ +from unittest import TestCase + +import simplejson as json +from operator import itemgetter + +class TestItemSortKey(TestCase): + def test_simple_first(self): + a = {'a': 1, 'c': 5, 'jack': 'jill', 'pick': 'axe', 'array': [1, 5, 6, 9], 'tuple': (83, 12, 3), 'crate': 'dog', 'zeak': 'oh'} + self.assertEqual( + '{"a": 1, "c": 5, "crate": "dog", "jack": "jill", "pick": "axe", "zeak": "oh", "array": [1, 5, 6, 9], "tuple": [83, 12, 3]}', + json.dumps(a, item_sort_key=json.simple_first)) + + def test_case(self): + a = {'a': 1, 'c': 5, 'Jack': 'jill', 'pick': 'axe', 'Array': [1, 5, 6, 9], 'tuple': (83, 12, 3), 'crate': 'dog', 'zeak': 'oh'} + self.assertEqual( + '{"Array": [1, 5, 6, 9], "Jack": "jill", "a": 1, "c": 5, "crate": "dog", "pick": "axe", "tuple": [83, 12, 3], "zeak": "oh"}', + json.dumps(a, item_sort_key=itemgetter(0))) + self.assertEqual( + '{"a": 1, "Array": [1, 5, 6, 9], "c": 5, "crate": "dog", "Jack": "jill", "pick": "axe", "tuple": [83, 12, 3], "zeak": "oh"}', + json.dumps(a, item_sort_key=lambda kv: kv[0].lower())) + + def test_item_sort_key_value(self): + # https://github.com/simplejson/simplejson/issues/173 + a = {'a': 1, 'b': 0} + self.assertEqual( + '{"b": 0, "a": 1}', + json.dumps(a, item_sort_key=lambda kv: kv[1])) diff --git a/lib/simplejson/tests/test_iterable.py b/lib/simplejson/tests/test_iterable.py new file mode 100644 index 0000000..35d3e75 --- /dev/null +++ b/lib/simplejson/tests/test_iterable.py @@ -0,0 +1,31 @@ +import unittest +from simplejson.compat import StringIO + +import simplejson as json + +def iter_dumps(obj, **kw): + return ''.join(json.JSONEncoder(**kw).iterencode(obj)) + +def sio_dump(obj, **kw): + sio = StringIO() + json.dumps(obj, **kw) + return sio.getvalue() + +class TestIterable(unittest.TestCase): + def test_iterable(self): + for l in ([], [1], [1, 2], [1, 2, 3]): + for opts in [{}, {'indent': 2}]: + for dumps in (json.dumps, iter_dumps, sio_dump): + expect = dumps(l, **opts) + default_expect = dumps(sum(l), **opts) + # Default is False + self.assertRaises(TypeError, dumps, iter(l), **opts) + self.assertRaises(TypeError, dumps, iter(l), iterable_as_array=False, **opts) + self.assertEqual(expect, dumps(iter(l), iterable_as_array=True, **opts)) + # Ensure that the "default" gets called + self.assertEqual(default_expect, dumps(iter(l), default=sum, **opts)) + self.assertEqual(default_expect, dumps(iter(l), iterable_as_array=False, default=sum, **opts)) + # Ensure that the "default" does not get called + self.assertEqual( + expect, + dumps(iter(l), iterable_as_array=True, default=sum, **opts)) diff --git a/lib/simplejson/tests/test_namedtuple.py b/lib/simplejson/tests/test_namedtuple.py new file mode 100644 index 0000000..cc0f8aa --- /dev/null +++ b/lib/simplejson/tests/test_namedtuple.py @@ -0,0 +1,174 @@ +from __future__ import absolute_import +import unittest +import simplejson as json +from simplejson.compat import StringIO + +try: + from unittest import mock +except ImportError: + mock = None + +try: + from collections import namedtuple +except ImportError: + class Value(tuple): + def __new__(cls, *args): + return tuple.__new__(cls, args) + + def _asdict(self): + return {'value': self[0]} + class Point(tuple): + def __new__(cls, *args): + return tuple.__new__(cls, args) + + def _asdict(self): + return {'x': self[0], 'y': self[1]} +else: + Value = namedtuple('Value', ['value']) + Point = namedtuple('Point', ['x', 'y']) + +class DuckValue(object): + def __init__(self, *args): + self.value = Value(*args) + + def _asdict(self): + return self.value._asdict() + +class DuckPoint(object): + def __init__(self, *args): + self.point = Point(*args) + + def _asdict(self): + return self.point._asdict() + +class DeadDuck(object): + _asdict = None + +class DeadDict(dict): + _asdict = None + +CONSTRUCTORS = [ + lambda v: v, + lambda v: [v], + lambda v: [{'key': v}], +] + +class TestNamedTuple(unittest.TestCase): + def test_namedtuple_dumps(self): + for v in [Value(1), Point(1, 2), DuckValue(1), DuckPoint(1, 2)]: + d = v._asdict() + self.assertEqual(d, json.loads(json.dumps(v))) + self.assertEqual( + d, + json.loads(json.dumps(v, namedtuple_as_object=True))) + self.assertEqual(d, json.loads(json.dumps(v, tuple_as_array=False))) + self.assertEqual( + d, + json.loads(json.dumps(v, namedtuple_as_object=True, + tuple_as_array=False))) + + def test_namedtuple_dumps_false(self): + for v in [Value(1), Point(1, 2)]: + l = list(v) + self.assertEqual( + l, + json.loads(json.dumps(v, namedtuple_as_object=False))) + self.assertRaises(TypeError, json.dumps, v, + tuple_as_array=False, namedtuple_as_object=False) + + def test_namedtuple_dump(self): + for v in [Value(1), Point(1, 2), DuckValue(1), DuckPoint(1, 2)]: + d = v._asdict() + sio = StringIO() + json.dump(v, sio) + self.assertEqual(d, json.loads(sio.getvalue())) + sio = StringIO() + json.dump(v, sio, namedtuple_as_object=True) + self.assertEqual( + d, + json.loads(sio.getvalue())) + sio = StringIO() + json.dump(v, sio, tuple_as_array=False) + self.assertEqual(d, json.loads(sio.getvalue())) + sio = StringIO() + json.dump(v, sio, namedtuple_as_object=True, + tuple_as_array=False) + self.assertEqual( + d, + json.loads(sio.getvalue())) + + def test_namedtuple_dump_false(self): + for v in [Value(1), Point(1, 2)]: + l = list(v) + sio = StringIO() + json.dump(v, sio, namedtuple_as_object=False) + self.assertEqual( + l, + json.loads(sio.getvalue())) + self.assertRaises(TypeError, json.dump, v, StringIO(), + tuple_as_array=False, namedtuple_as_object=False) + + def test_asdict_not_callable_dump(self): + for f in CONSTRUCTORS: + self.assertRaises( + TypeError, + json.dump, + f(DeadDuck()), + StringIO(), + namedtuple_as_object=True + ) + sio = StringIO() + json.dump(f(DeadDict()), sio, namedtuple_as_object=True) + self.assertEqual( + json.dumps(f({})), + sio.getvalue()) + self.assertRaises( + TypeError, + json.dump, + f(Value), + StringIO(), + namedtuple_as_object=True + ) + + def test_asdict_not_callable_dumps(self): + for f in CONSTRUCTORS: + self.assertRaises(TypeError, + json.dumps, f(DeadDuck()), namedtuple_as_object=True) + self.assertRaises( + TypeError, + json.dumps, + f(Value), + namedtuple_as_object=True + ) + self.assertEqual( + json.dumps(f({})), + json.dumps(f(DeadDict()), namedtuple_as_object=True)) + + def test_asdict_unbound_method_dumps(self): + for f in CONSTRUCTORS: + self.assertEqual( + json.dumps(f(Value), default=lambda v: v.__name__), + json.dumps(f(Value.__name__)) + ) + + def test_asdict_does_not_return_dict(self): + if not mock: + if hasattr(unittest, "SkipTest"): + raise unittest.SkipTest("unittest.mock required") + else: + print("unittest.mock not available") + return + fake = mock.Mock() + self.assertTrue(hasattr(fake, '_asdict')) + self.assertTrue(callable(fake._asdict)) + self.assertFalse(isinstance(fake._asdict(), dict)) + # https://github.com/simplejson/simplejson/pull/284 + # when running under a debug build of CPython (COPTS=-UNDEBUG) + # a C assertion could fire due to an unchecked error of an PyDict + # API call on a non-dict internally in _speedups.c. Without a debug + # build of CPython this test likely passes either way despite the + # potential for internal data corruption. Getting it to crash in + # a debug build is not always easy either as it requires an + # assert(!PyErr_Occurred()) that could fire later on. + with self.assertRaises(TypeError): + json.dumps({23: fake}, namedtuple_as_object=True, for_json=False) diff --git a/lib/simplejson/tests/test_pass1.py b/lib/simplejson/tests/test_pass1.py new file mode 100644 index 0000000..f0b5b10 --- /dev/null +++ b/lib/simplejson/tests/test_pass1.py @@ -0,0 +1,71 @@ +from unittest import TestCase + +import simplejson as json + +# from http://json.org/JSON_checker/test/pass1.json +JSON = r''' +[ + "JSON Test Pattern pass1", + {"object with 1 member":["array with 1 element"]}, + {}, + [], + -42, + true, + false, + null, + { + "integer": 1234567890, + "real": -9876.543210, + "e": 0.123456789e-12, + "E": 1.234567890E+34, + "": 23456789012E66, + "zero": 0, + "one": 1, + "space": " ", + "quote": "\"", + "backslash": "\\", + "controls": "\b\f\n\r\t", + "slash": "/ & \/", + "alpha": "abcdefghijklmnopqrstuvwyz", + "ALPHA": "ABCDEFGHIJKLMNOPQRSTUVWYZ", + "digit": "0123456789", + "special": "`1~!@#$%^&*()_+-={':[,]}|;.?", + "hex": "\u0123\u4567\u89AB\uCDEF\uabcd\uef4A", + "true": true, + "false": false, + "null": null, + "array":[ ], + "object":{ }, + "address": "50 St. James Street", + "url": "http://www.JSON.org/", + "comment": "// /* */": " ", + " s p a c e d " :[1,2 , 3 + +, + +4 , 5 , 6 ,7 ],"compact": [1,2,3,4,5,6,7], + "jsontext": "{\"object with 1 member\":[\"array with 1 element\"]}", + "quotes": "" \u0022 %22 0x22 034 "", + "\/\\\"\uCAFE\uBABE\uAB98\uFCDE\ubcda\uef4A\b\f\n\r\t`1~!@#$%^&*()_+-=[]{}|;:',./<>?" +: "A key can be any string" + }, + 0.5 ,98.6 +, +99.44 +, + +1066, +1e1, +0.1e1, +1e-1, +1e00,2e+00,2e-00 +,"rosebud"] +''' + +class TestPass1(TestCase): + def test_parse(self): + # test in/out equivalence and parsing + res = json.loads(JSON) + out = json.dumps(res) + self.assertEqual(res, json.loads(out)) diff --git a/lib/simplejson/tests/test_pass2.py b/lib/simplejson/tests/test_pass2.py new file mode 100644 index 0000000..5d812b3 --- /dev/null +++ b/lib/simplejson/tests/test_pass2.py @@ -0,0 +1,14 @@ +from unittest import TestCase +import simplejson as json + +# from http://json.org/JSON_checker/test/pass2.json +JSON = r''' +[[[[[[[[[[[[[[[[[[["Not too deep"]]]]]]]]]]]]]]]]]]] +''' + +class TestPass2(TestCase): + def test_parse(self): + # test in/out equivalence and parsing + res = json.loads(JSON) + out = json.dumps(res) + self.assertEqual(res, json.loads(out)) diff --git a/lib/simplejson/tests/test_pass3.py b/lib/simplejson/tests/test_pass3.py new file mode 100644 index 0000000..821d60b --- /dev/null +++ b/lib/simplejson/tests/test_pass3.py @@ -0,0 +1,20 @@ +from unittest import TestCase + +import simplejson as json + +# from http://json.org/JSON_checker/test/pass3.json +JSON = r''' +{ + "JSON Test Pattern pass3": { + "The outermost value": "must be an object or array.", + "In this test": "It is an object." + } +} +''' + +class TestPass3(TestCase): + def test_parse(self): + # test in/out equivalence and parsing + res = json.loads(JSON) + out = json.dumps(res) + self.assertEqual(res, json.loads(out)) diff --git a/lib/simplejson/tests/test_raw_json.py b/lib/simplejson/tests/test_raw_json.py new file mode 100644 index 0000000..1dfcc2c --- /dev/null +++ b/lib/simplejson/tests/test_raw_json.py @@ -0,0 +1,47 @@ +import unittest +import simplejson as json + +dct1 = { + 'key1': 'value1' +} + +dct2 = { + 'key2': 'value2', + 'd1': dct1 +} + +dct3 = { + 'key2': 'value2', + 'd1': json.dumps(dct1) +} + +dct4 = { + 'key2': 'value2', + 'd1': json.RawJSON(json.dumps(dct1)) +} + + +class TestRawJson(unittest.TestCase): + + def test_normal_str(self): + self.assertNotEqual(json.dumps(dct2), json.dumps(dct3)) + + def test_raw_json_str(self): + self.assertEqual(json.dumps(dct2), json.dumps(dct4)) + self.assertEqual(dct2, json.loads(json.dumps(dct4))) + + def test_list(self): + self.assertEqual( + json.dumps([dct2]), + json.dumps([json.RawJSON(json.dumps(dct2))])) + self.assertEqual( + [dct2], + json.loads(json.dumps([json.RawJSON(json.dumps(dct2))]))) + + def test_direct(self): + self.assertEqual( + json.dumps(dct2), + json.dumps(json.RawJSON(json.dumps(dct2)))) + self.assertEqual( + dct2, + json.loads(json.dumps(json.RawJSON(json.dumps(dct2))))) diff --git a/lib/simplejson/tests/test_recursion.py b/lib/simplejson/tests/test_recursion.py new file mode 100644 index 0000000..662eb66 --- /dev/null +++ b/lib/simplejson/tests/test_recursion.py @@ -0,0 +1,67 @@ +from unittest import TestCase + +import simplejson as json + +class JSONTestObject: + pass + + +class RecursiveJSONEncoder(json.JSONEncoder): + recurse = False + def default(self, o): + if o is JSONTestObject: + if self.recurse: + return [JSONTestObject] + else: + return 'JSONTestObject' + return json.JSONEncoder.default(o) + + +class TestRecursion(TestCase): + def test_listrecursion(self): + x = [] + x.append(x) + try: + json.dumps(x) + except ValueError: + pass + else: + self.fail("didn't raise ValueError on list recursion") + x = [] + y = [x] + x.append(y) + try: + json.dumps(x) + except ValueError: + pass + else: + self.fail("didn't raise ValueError on alternating list recursion") + y = [] + x = [y, y] + # ensure that the marker is cleared + json.dumps(x) + + def test_dictrecursion(self): + x = {} + x["test"] = x + try: + json.dumps(x) + except ValueError: + pass + else: + self.fail("didn't raise ValueError on dict recursion") + x = {} + y = {"a": x, "b": x} + # ensure that the marker is cleared + json.dumps(y) + + def test_defaultrecursion(self): + enc = RecursiveJSONEncoder() + self.assertEqual(enc.encode(JSONTestObject), '"JSONTestObject"') + enc.recurse = True + try: + enc.encode(JSONTestObject) + except ValueError: + pass + else: + self.fail("didn't raise ValueError on default recursion") diff --git a/lib/simplejson/tests/test_scanstring.py b/lib/simplejson/tests/test_scanstring.py new file mode 100644 index 0000000..1f54483 --- /dev/null +++ b/lib/simplejson/tests/test_scanstring.py @@ -0,0 +1,200 @@ +import sys +from unittest import TestCase + +import simplejson as json +import simplejson.decoder +from simplejson.compat import b, PY3 + +class TestScanString(TestCase): + # The bytes type is intentionally not used in most of these tests + # under Python 3 because the decoder immediately coerces to str before + # calling scanstring. In Python 2 we are testing the code paths + # for both unicode and str. + # + # The reason this is done is because Python 3 would require + # entirely different code paths for parsing bytes and str. + # + def test_py_scanstring(self): + self._test_scanstring(simplejson.decoder.py_scanstring) + + def test_c_scanstring(self): + if not simplejson.decoder.c_scanstring: + return + self._test_scanstring(simplejson.decoder.c_scanstring) + + self.assertTrue(isinstance(simplejson.decoder.c_scanstring('""', 0)[0], str)) + + def _test_scanstring(self, scanstring): + if sys.maxunicode == 65535: + self.assertEqual( + scanstring(u'"z\U0001d120x"', 1, None, True), + (u'z\U0001d120x', 6)) + else: + self.assertEqual( + scanstring(u'"z\U0001d120x"', 1, None, True), + (u'z\U0001d120x', 5)) + + self.assertEqual( + scanstring('"\\u007b"', 1, None, True), + (u'{', 8)) + + self.assertEqual( + scanstring('"A JSON payload should be an object or array, not a string."', 1, None, True), + (u'A JSON payload should be an object or array, not a string.', 60)) + + self.assertEqual( + scanstring('["Unclosed array"', 2, None, True), + (u'Unclosed array', 17)) + + self.assertEqual( + scanstring('["extra comma",]', 2, None, True), + (u'extra comma', 14)) + + self.assertEqual( + scanstring('["double extra comma",,]', 2, None, True), + (u'double extra comma', 21)) + + self.assertEqual( + scanstring('["Comma after the close"],', 2, None, True), + (u'Comma after the close', 24)) + + self.assertEqual( + scanstring('["Extra close"]]', 2, None, True), + (u'Extra close', 14)) + + self.assertEqual( + scanstring('{"Extra comma": true,}', 2, None, True), + (u'Extra comma', 14)) + + self.assertEqual( + scanstring('{"Extra value after close": true} "misplaced quoted value"', 2, None, True), + (u'Extra value after close', 26)) + + self.assertEqual( + scanstring('{"Illegal expression": 1 + 2}', 2, None, True), + (u'Illegal expression', 21)) + + self.assertEqual( + scanstring('{"Illegal invocation": alert()}', 2, None, True), + (u'Illegal invocation', 21)) + + self.assertEqual( + scanstring('{"Numbers cannot have leading zeroes": 013}', 2, None, True), + (u'Numbers cannot have leading zeroes', 37)) + + self.assertEqual( + scanstring('{"Numbers cannot be hex": 0x14}', 2, None, True), + (u'Numbers cannot be hex', 24)) + + self.assertEqual( + scanstring('[[[[[[[[[[[[[[[[[[[["Too deep"]]]]]]]]]]]]]]]]]]]]', 21, None, True), + (u'Too deep', 30)) + + self.assertEqual( + scanstring('{"Missing colon" null}', 2, None, True), + (u'Missing colon', 16)) + + self.assertEqual( + scanstring('{"Double colon":: null}', 2, None, True), + (u'Double colon', 15)) + + self.assertEqual( + scanstring('{"Comma instead of colon", null}', 2, None, True), + (u'Comma instead of colon', 25)) + + self.assertEqual( + scanstring('["Colon instead of comma": false]', 2, None, True), + (u'Colon instead of comma', 25)) + + self.assertEqual( + scanstring('["Bad value", truth]', 2, None, True), + (u'Bad value', 12)) + + for c in map(chr, range(0x00, 0x1f)): + self.assertEqual( + scanstring(c + '"', 0, None, False), + (c, 2)) + self.assertRaises( + ValueError, + scanstring, c + '"', 0, None, True) + + self.assertRaises(ValueError, scanstring, '', 0, None, True) + self.assertRaises(ValueError, scanstring, 'a', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\u', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\u0', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\u01', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\u012', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\u0123', 0, None, True) + if sys.maxunicode > 65535: + self.assertRaises(ValueError, + scanstring, '\\ud834\\u"', 0, None, True) + self.assertRaises(ValueError, + scanstring, '\\ud834\\x0123"', 0, None, True) + + self.assertRaises(json.JSONDecodeError, scanstring, '\\u-123"', 0, None, True) + # SJ-PT-23-01: Invalid Handling of Broken Unicode Escape Sequences + self.assertRaises(json.JSONDecodeError, scanstring, '\\u EDD"', 0, None, True) + + def test_issue3623(self): + self.assertRaises(ValueError, json.decoder.scanstring, "xxx", 1, + "xxx") + self.assertRaises(UnicodeDecodeError, + json.encoder.encode_basestring_ascii, b("xx\xff")) + + def test_overflow(self): + # Python 2.5 does not have maxsize, Python 3 does not have maxint + maxsize = getattr(sys, 'maxsize', getattr(sys, 'maxint', None)) + assert maxsize is not None + self.assertRaises(OverflowError, json.decoder.scanstring, "xxx", + maxsize + 1) + + def test_surrogates(self): + scanstring = json.decoder.scanstring + + def assertScan(given, expect, test_utf8=True): + givens = [given] + if not PY3 and test_utf8: + givens.append(given.encode('utf8')) + for given in givens: + (res, count) = scanstring(given, 1, None, True) + self.assertEqual(len(given), count) + self.assertEqual(res, expect) + + assertScan( + u'"z\\ud834\\u0079x"', + u'z\ud834yx') + assertScan( + u'"z\\ud834\\udd20x"', + u'z\U0001d120x') + assertScan( + u'"z\\ud834\\ud834\\udd20x"', + u'z\ud834\U0001d120x') + assertScan( + u'"z\\ud834x"', + u'z\ud834x') + assertScan( + u'"z\\udd20x"', + u'z\udd20x') + assertScan( + u'"z\ud834x"', + u'z\ud834x') + # It may look strange to join strings together, but Python is drunk. + # https://gist.github.com/etrepum/5538443 + assertScan( + u'"z\\ud834\udd20x12345"', + u''.join([u'z\ud834', u'\udd20x12345'])) + assertScan( + u'"z\ud834\\udd20x"', + u''.join([u'z\ud834', u'\udd20x'])) + # these have different behavior given UTF8 input, because the surrogate + # pair may be joined (in maxunicode > 65535 builds) + assertScan( + u''.join([u'"z\ud834', u'\udd20x"']), + u''.join([u'z\ud834', u'\udd20x']), + test_utf8=False) + + self.assertRaises(ValueError, + scanstring, u'"z\\ud83x"', 1, None, True) + self.assertRaises(ValueError, + scanstring, u'"z\\ud834\\udd2x"', 1, None, True) diff --git a/lib/simplejson/tests/test_separators.py b/lib/simplejson/tests/test_separators.py new file mode 100644 index 0000000..91b4d4f --- /dev/null +++ b/lib/simplejson/tests/test_separators.py @@ -0,0 +1,42 @@ +import textwrap +from unittest import TestCase + +import simplejson as json + + +class TestSeparators(TestCase): + def test_separators(self): + h = [['blorpie'], ['whoops'], [], 'd-shtaeou', 'd-nthiouh', 'i-vhbjkhnth', + {'nifty': 87}, {'field': 'yes', 'morefield': False} ] + + expect = textwrap.dedent("""\ + [ + [ + "blorpie" + ] , + [ + "whoops" + ] , + [] , + "d-shtaeou" , + "d-nthiouh" , + "i-vhbjkhnth" , + { + "nifty" : 87 + } , + { + "field" : "yes" , + "morefield" : false + } + ]""") + + + d1 = json.dumps(h) + d2 = json.dumps(h, indent=' ', sort_keys=True, separators=(' ,', ' : ')) + + h1 = json.loads(d1) + h2 = json.loads(d2) + + self.assertEqual(h1, h) + self.assertEqual(h2, h) + self.assertEqual(d2, expect) diff --git a/lib/simplejson/tests/test_speedups.py b/lib/simplejson/tests/test_speedups.py new file mode 100644 index 0000000..8b146df --- /dev/null +++ b/lib/simplejson/tests/test_speedups.py @@ -0,0 +1,114 @@ +from __future__ import with_statement + +import sys +import unittest +from unittest import TestCase + +import simplejson +from simplejson import encoder, decoder, scanner +from simplejson.compat import PY3, long_type, b + + +def has_speedups(): + return encoder.c_make_encoder is not None + + +def skip_if_speedups_missing(func): + def wrapper(*args, **kwargs): + if not has_speedups(): + if hasattr(unittest, 'SkipTest'): + raise unittest.SkipTest("C Extension not available") + else: + sys.stdout.write("C Extension not available") + return + return func(*args, **kwargs) + + return wrapper + + +class BadBool: + def __bool__(self): + 1/0 + __nonzero__ = __bool__ + + +class TestDecode(TestCase): + @skip_if_speedups_missing + def test_make_scanner(self): + self.assertRaises(AttributeError, scanner.c_make_scanner, 1) + + @skip_if_speedups_missing + def test_bad_bool_args(self): + def test(value): + decoder.JSONDecoder(strict=BadBool()).decode(value) + self.assertRaises(ZeroDivisionError, test, '""') + self.assertRaises(ZeroDivisionError, test, '{}') + if not PY3: + self.assertRaises(ZeroDivisionError, test, u'""') + self.assertRaises(ZeroDivisionError, test, u'{}') + +class TestEncode(TestCase): + @skip_if_speedups_missing + def test_make_encoder(self): + self.assertRaises( + TypeError, + encoder.c_make_encoder, + None, + ("\xCD\x7D\x3D\x4E\x12\x4C\xF9\x79\xD7" + "\x52\xBA\x82\xF2\x27\x4A\x7D\xA0\xCA\x75"), + None + ) + + @skip_if_speedups_missing + def test_bad_str_encoder(self): + # Issue #31505: There shouldn't be an assertion failure in case + # c_make_encoder() receives a bad encoder() argument. + import decimal + def bad_encoder1(*args): + return None + enc = encoder.c_make_encoder( + None, lambda obj: str(obj), + bad_encoder1, None, ': ', ', ', + False, False, False, {}, False, False, False, + None, None, 'utf-8', False, False, decimal.Decimal, False) + self.assertRaises(TypeError, enc, 'spam', 4) + self.assertRaises(TypeError, enc, {'spam': 42}, 4) + + def bad_encoder2(*args): + 1/0 + enc = encoder.c_make_encoder( + None, lambda obj: str(obj), + bad_encoder2, None, ': ', ', ', + False, False, False, {}, False, False, False, + None, None, 'utf-8', False, False, decimal.Decimal, False) + self.assertRaises(ZeroDivisionError, enc, 'spam', 4) + + @skip_if_speedups_missing + def test_bad_bool_args(self): + def test(name): + encoder.JSONEncoder(**{name: BadBool()}).encode({}) + self.assertRaises(ZeroDivisionError, test, 'skipkeys') + self.assertRaises(ZeroDivisionError, test, 'ensure_ascii') + self.assertRaises(ZeroDivisionError, test, 'check_circular') + self.assertRaises(ZeroDivisionError, test, 'allow_nan') + self.assertRaises(ZeroDivisionError, test, 'sort_keys') + self.assertRaises(ZeroDivisionError, test, 'use_decimal') + self.assertRaises(ZeroDivisionError, test, 'namedtuple_as_object') + self.assertRaises(ZeroDivisionError, test, 'tuple_as_array') + self.assertRaises(ZeroDivisionError, test, 'bigint_as_string') + self.assertRaises(ZeroDivisionError, test, 'for_json') + self.assertRaises(ZeroDivisionError, test, 'ignore_nan') + self.assertRaises(ZeroDivisionError, test, 'iterable_as_array') + + @skip_if_speedups_missing + def test_int_as_string_bitcount_overflow(self): + long_count = long_type(2)**32+31 + def test(): + encoder.JSONEncoder(int_as_string_bitcount=long_count).encode(0) + self.assertRaises((TypeError, OverflowError), test) + + if PY3: + @skip_if_speedups_missing + def test_bad_encoding(self): + with self.assertRaises(UnicodeEncodeError): + encoder.JSONEncoder(encoding='\udcff').encode({b('key'): 123}) diff --git a/lib/simplejson/tests/test_str_subclass.py b/lib/simplejson/tests/test_str_subclass.py new file mode 100644 index 0000000..b6c8351 --- /dev/null +++ b/lib/simplejson/tests/test_str_subclass.py @@ -0,0 +1,21 @@ +from unittest import TestCase + +import simplejson +from simplejson.compat import text_type + +# Tests for issue demonstrated in https://github.com/simplejson/simplejson/issues/144 +class WonkyTextSubclass(text_type): + def __getslice__(self, start, end): + return self.__class__('not what you wanted!') + +class TestStrSubclass(TestCase): + def test_dump_load(self): + for s in ['', '"hello"', 'text', u'\u005c']: + self.assertEqual( + s, + simplejson.loads(simplejson.dumps(WonkyTextSubclass(s)))) + + self.assertEqual( + s, + simplejson.loads(simplejson.dumps(WonkyTextSubclass(s), + ensure_ascii=False))) diff --git a/lib/simplejson/tests/test_subclass.py b/lib/simplejson/tests/test_subclass.py new file mode 100644 index 0000000..2bae3b6 --- /dev/null +++ b/lib/simplejson/tests/test_subclass.py @@ -0,0 +1,37 @@ +from unittest import TestCase +import simplejson as json + +from decimal import Decimal + +class AlternateInt(int): + def __repr__(self): + return 'invalid json' + __str__ = __repr__ + + +class AlternateFloat(float): + def __repr__(self): + return 'invalid json' + __str__ = __repr__ + + +# class AlternateDecimal(Decimal): +# def __repr__(self): +# return 'invalid json' + + +class TestSubclass(TestCase): + def test_int(self): + self.assertEqual(json.dumps(AlternateInt(1)), '1') + self.assertEqual(json.dumps(AlternateInt(-1)), '-1') + self.assertEqual(json.loads(json.dumps({AlternateInt(1): 1})), {'1': 1}) + + def test_float(self): + self.assertEqual(json.dumps(AlternateFloat(1.0)), '1.0') + self.assertEqual(json.dumps(AlternateFloat(-1.0)), '-1.0') + self.assertEqual(json.loads(json.dumps({AlternateFloat(1.0): 1})), {'1.0': 1}) + + # NOTE: Decimal subclasses are not supported as-is + # def test_decimal(self): + # self.assertEqual(json.dumps(AlternateDecimal('1.0')), '1.0') + # self.assertEqual(json.dumps(AlternateDecimal('-1.0')), '-1.0') diff --git a/lib/simplejson/tests/test_tool.py b/lib/simplejson/tests/test_tool.py new file mode 100644 index 0000000..914bff8 --- /dev/null +++ b/lib/simplejson/tests/test_tool.py @@ -0,0 +1,114 @@ +from __future__ import with_statement +import os +import sys +import textwrap +import unittest +import subprocess +import tempfile +try: + # Python 3.x + from test.support import strip_python_stderr +except ImportError: + # Python 2.6+ + try: + from test.test_support import strip_python_stderr + except ImportError: + # Python 2.5 + import re + def strip_python_stderr(stderr): + return re.sub( + r"\[\d+ refs\]\r?\n?$".encode(), + "".encode(), + stderr).strip() + +def open_temp_file(): + if sys.version_info >= (2, 6): + file = tempfile.NamedTemporaryFile(delete=False) + filename = file.name + else: + fd, filename = tempfile.mkstemp() + file = os.fdopen(fd, 'w+b') + return file, filename + +class TestTool(unittest.TestCase): + data = """ + + [["blorpie"],[ "whoops" ] , [ + ],\t"d-shtaeou",\r"d-nthiouh", + "i-vhbjkhnth", {"nifty":87}, {"morefield" :\tfalse,"field" + :"yes"} ] + """ + + expect = textwrap.dedent("""\ + [ + [ + "blorpie" + ], + [ + "whoops" + ], + [], + "d-shtaeou", + "d-nthiouh", + "i-vhbjkhnth", + { + "nifty": 87 + }, + { + "field": "yes", + "morefield": false + } + ] + """) + + def runTool(self, args=None, data=None): + argv = [sys.executable, '-m', 'simplejson.tool'] + if args: + argv.extend(args) + proc = subprocess.Popen(argv, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE) + out, err = proc.communicate(data) + self.assertEqual(strip_python_stderr(err), ''.encode()) + self.assertEqual(proc.returncode, 0) + return out.decode('utf8').splitlines() + + def test_stdin_stdout(self): + self.assertEqual( + self.runTool(data=self.data.encode()), + self.expect.splitlines()) + + def test_infile_stdout(self): + infile, infile_name = open_temp_file() + try: + infile.write(self.data.encode()) + infile.close() + self.assertEqual( + self.runTool(args=[infile_name]), + self.expect.splitlines()) + finally: + os.unlink(infile_name) + + def test_infile_outfile(self): + infile, infile_name = open_temp_file() + try: + infile.write(self.data.encode()) + infile.close() + # outfile will get overwritten by tool, so the delete + # may not work on some platforms. Do it manually. + outfile, outfile_name = open_temp_file() + try: + outfile.close() + self.assertEqual( + self.runTool(args=[infile_name, outfile_name]), + []) + with open(outfile_name, 'rb') as f: + self.assertEqual( + f.read().decode('utf8').splitlines(), + self.expect.splitlines() + ) + finally: + os.unlink(outfile_name) + finally: + os.unlink(infile_name) diff --git a/lib/simplejson/tests/test_tuple.py b/lib/simplejson/tests/test_tuple.py new file mode 100644 index 0000000..4ad7b0e --- /dev/null +++ b/lib/simplejson/tests/test_tuple.py @@ -0,0 +1,47 @@ +import unittest + +from simplejson.compat import StringIO +import simplejson as json + +class TestTuples(unittest.TestCase): + def test_tuple_array_dumps(self): + t = (1, 2, 3) + expect = json.dumps(list(t)) + # Default is True + self.assertEqual(expect, json.dumps(t)) + self.assertEqual(expect, json.dumps(t, tuple_as_array=True)) + self.assertRaises(TypeError, json.dumps, t, tuple_as_array=False) + # Ensure that the "default" does not get called + self.assertEqual(expect, json.dumps(t, default=repr)) + self.assertEqual(expect, json.dumps(t, tuple_as_array=True, + default=repr)) + # Ensure that the "default" gets called + self.assertEqual( + json.dumps(repr(t)), + json.dumps(t, tuple_as_array=False, default=repr)) + + def test_tuple_array_dump(self): + t = (1, 2, 3) + expect = json.dumps(list(t)) + # Default is True + sio = StringIO() + json.dump(t, sio) + self.assertEqual(expect, sio.getvalue()) + sio = StringIO() + json.dump(t, sio, tuple_as_array=True) + self.assertEqual(expect, sio.getvalue()) + self.assertRaises(TypeError, json.dump, t, StringIO(), + tuple_as_array=False) + # Ensure that the "default" does not get called + sio = StringIO() + json.dump(t, sio, default=repr) + self.assertEqual(expect, sio.getvalue()) + sio = StringIO() + json.dump(t, sio, tuple_as_array=True, default=repr) + self.assertEqual(expect, sio.getvalue()) + # Ensure that the "default" gets called + sio = StringIO() + json.dump(t, sio, tuple_as_array=False, default=repr) + self.assertEqual( + json.dumps(repr(t)), + sio.getvalue()) diff --git a/lib/simplejson/tests/test_unicode.py b/lib/simplejson/tests/test_unicode.py new file mode 100644 index 0000000..0c7b1a6 --- /dev/null +++ b/lib/simplejson/tests/test_unicode.py @@ -0,0 +1,154 @@ +import sys +import codecs +from unittest import TestCase + +import simplejson as json +from simplejson.compat import unichr, text_type, b, BytesIO + +class TestUnicode(TestCase): + def test_encoding1(self): + encoder = json.JSONEncoder(encoding='utf-8') + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + s = u.encode('utf-8') + ju = encoder.encode(u) + js = encoder.encode(s) + self.assertEqual(ju, js) + + def test_encoding2(self): + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + s = u.encode('utf-8') + ju = json.dumps(u, encoding='utf-8') + js = json.dumps(s, encoding='utf-8') + self.assertEqual(ju, js) + + def test_encoding3(self): + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + j = json.dumps(u) + self.assertEqual(j, '"\\u03b1\\u03a9"') + + def test_encoding4(self): + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + j = json.dumps([u]) + self.assertEqual(j, '["\\u03b1\\u03a9"]') + + def test_encoding5(self): + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + j = json.dumps(u, ensure_ascii=False) + self.assertEqual(j, u'"' + u + u'"') + + def test_encoding6(self): + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + j = json.dumps([u], ensure_ascii=False) + self.assertEqual(j, u'["' + u + u'"]') + + def test_big_unicode_encode(self): + u = u'\U0001d120' + self.assertEqual(json.dumps(u), '"\\ud834\\udd20"') + self.assertEqual(json.dumps(u, ensure_ascii=False), u'"\U0001d120"') + + def test_big_unicode_decode(self): + u = u'z\U0001d120x' + self.assertEqual(json.loads('"' + u + '"'), u) + self.assertEqual(json.loads('"z\\ud834\\udd20x"'), u) + + def test_unicode_decode(self): + for i in range(0, 0xd7ff): + u = unichr(i) + #s = '"\\u{0:04x}"'.format(i) + s = '"\\u%04x"' % (i,) + self.assertEqual(json.loads(s), u) + + def test_object_pairs_hook_with_unicode(self): + s = u'{"xkd":1, "kcw":2, "art":3, "hxm":4, "qrt":5, "pad":6, "hoy":7}' + p = [(u"xkd", 1), (u"kcw", 2), (u"art", 3), (u"hxm", 4), + (u"qrt", 5), (u"pad", 6), (u"hoy", 7)] + self.assertEqual(json.loads(s), eval(s)) + self.assertEqual(json.loads(s, object_pairs_hook=lambda x: x), p) + od = json.loads(s, object_pairs_hook=json.OrderedDict) + self.assertEqual(od, json.OrderedDict(p)) + self.assertEqual(type(od), json.OrderedDict) + # the object_pairs_hook takes priority over the object_hook + self.assertEqual(json.loads(s, + object_pairs_hook=json.OrderedDict, + object_hook=lambda x: None), + json.OrderedDict(p)) + + + def test_default_encoding(self): + self.assertEqual(json.loads(u'{"a": "\xe9"}'.encode('utf-8')), + {'a': u'\xe9'}) + + def test_unicode_preservation(self): + self.assertEqual(type(json.loads(u'""')), text_type) + self.assertEqual(type(json.loads(u'"a"')), text_type) + self.assertEqual(type(json.loads(u'["a"]')[0]), text_type) + + def test_ensure_ascii_false_returns_unicode(self): + # http://code.google.com/p/simplejson/issues/detail?id=48 + self.assertEqual(type(json.dumps([], ensure_ascii=False)), text_type) + self.assertEqual(type(json.dumps(0, ensure_ascii=False)), text_type) + self.assertEqual(type(json.dumps({}, ensure_ascii=False)), text_type) + self.assertEqual(type(json.dumps("", ensure_ascii=False)), text_type) + + def test_ensure_ascii_false_bytestring_encoding(self): + # http://code.google.com/p/simplejson/issues/detail?id=48 + doc1 = {u'quux': b('Arr\xc3\xaat sur images')} + doc2 = {u'quux': u'Arr\xeat sur images'} + doc_ascii = '{"quux": "Arr\\u00eat sur images"}' + doc_unicode = u'{"quux": "Arr\xeat sur images"}' + self.assertEqual(json.dumps(doc1), doc_ascii) + self.assertEqual(json.dumps(doc2), doc_ascii) + self.assertEqual(json.dumps(doc1, ensure_ascii=False), doc_unicode) + self.assertEqual(json.dumps(doc2, ensure_ascii=False), doc_unicode) + + def test_ensure_ascii_linebreak_encoding(self): + # http://timelessrepo.com/json-isnt-a-javascript-subset + s1 = u'\u2029\u2028' + s2 = s1.encode('utf8') + expect = '"\\u2029\\u2028"' + expect_non_ascii = u'"\u2029\u2028"' + self.assertEqual(json.dumps(s1), expect) + self.assertEqual(json.dumps(s2), expect) + self.assertEqual(json.dumps(s1, ensure_ascii=False), expect_non_ascii) + self.assertEqual(json.dumps(s2, ensure_ascii=False), expect_non_ascii) + + def test_invalid_escape_sequences(self): + # incomplete escape sequence + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u12') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u123') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1234') + # invalid escape sequence + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u123x"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u12x4"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1x34"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ux234"') + if sys.maxunicode > 65535: + # invalid escape sequence for low surrogate + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u0"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u00"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u000"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u000x"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u00x0"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u0x00"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\ux000"') + + def test_ensure_ascii_still_works(self): + # in the ascii range, ensure that everything is the same + for c in map(unichr, range(0, 127)): + self.assertEqual( + json.dumps(c, ensure_ascii=False), + json.dumps(c)) + snowman = u'\N{SNOWMAN}' + self.assertEqual( + json.dumps(c, ensure_ascii=False), + '"' + c + '"') + + def test_strip_bom(self): + content = u"\u3053\u3093\u306b\u3061\u308f" + json_doc = codecs.BOM_UTF8 + b(json.dumps(content)) + self.assertEqual(json.load(BytesIO(json_doc)), content) + for doc in json_doc, json_doc.decode('utf8'): + self.assertEqual(json.loads(doc), content) diff --git a/lib/simplejson/tool.py b/lib/simplejson/tool.py new file mode 100644 index 0000000..062e8e2 --- /dev/null +++ b/lib/simplejson/tool.py @@ -0,0 +1,42 @@ +r"""Command-line tool to validate and pretty-print JSON + +Usage:: + + $ echo '{"json":"obj"}' | python -m simplejson.tool + { + "json": "obj" + } + $ echo '{ 1.2:3.4}' | python -m simplejson.tool + Expecting property name: line 1 column 2 (char 2) + +""" +from __future__ import with_statement +import sys +import simplejson as json + +def main(): + if len(sys.argv) == 1: + infile = sys.stdin + outfile = sys.stdout + elif len(sys.argv) == 2: + infile = open(sys.argv[1], 'r') + outfile = sys.stdout + elif len(sys.argv) == 3: + infile = open(sys.argv[1], 'r') + outfile = open(sys.argv[2], 'w') + else: + raise SystemExit(sys.argv[0] + " [infile [outfile]]") + with infile: + try: + obj = json.load(infile, + object_pairs_hook=json.OrderedDict, + use_decimal=True) + except ValueError: + raise SystemExit(sys.exc_info()[1]) + with outfile: + json.dump(obj, outfile, sort_keys=True, indent=' ', use_decimal=True) + outfile.write('\n') + + +if __name__ == '__main__': + main() diff --git a/test_func/save_targer_keys.py b/test_func/save_targer_keys.py new file mode 100644 index 0000000..b2d7545 --- /dev/null +++ b/test_func/save_targer_keys.py @@ -0,0 +1,108 @@ +import os +import sys +import json +import torch +import imageio +import numpy as np +import os.path as osp +sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2])) +from thop import profile +from ptflops import get_model_complexity_info + +import artist.data as data +from tools.modules.config import cfg +from tools.modules.unet.util import * +from utils.config import Config as pConfig +from utils.registry_class import ENGINE, MODEL + + +def save_temporal_key(): + cfg_update = pConfig(load=True) + + for k, v in cfg_update.cfg_dict.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + model = MODEL.build(cfg.UNet) + + temp_name = '' + temp_key_list = [] + spth = 'workspace/module_list/UNetSD_I2V_vs_Text_temporal_key_list.json' + for name, module in model.named_modules(): + if isinstance(module, (TemporalTransformer, TemporalTransformer_attemask, TemporalAttentionBlock, TemporalAttentionMultiBlock, TemporalConvBlock_v2, TemporalConvBlock)): + temp_name = name + print(f'Model: {name}') + elif isinstance(module, (ResidualBlock, ResBlock, SpatialTransformer, Upsample, Downsample)): + temp_name = '' + + if hasattr(module, 'weight'): + if temp_name != '' and (temp_name in name): + temp_key_list.append(name) + print(f'{name}') + # print(name) + + save_module_list = [] + for k, p in model.named_parameters(): + for item in temp_key_list: + if item in k: + print(f'{item} --> {k}') + save_module_list.append(k) + + print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters') + + # spth = 'workspace/module_list/{}' + json.dump(save_module_list, open(spth, 'w')) + a = 0 + + +def save_spatial_key(): + cfg_update = pConfig(load=True) + + for k, v in cfg_update.cfg_dict.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + model = MODEL.build(cfg.UNet) + temp_name = '' + temp_key_list = [] + spth = 'workspace/module_list/UNetSD_I2V_HQ_P_spatial_key_list.json' + for name, module in model.named_modules(): + if isinstance(module, (ResidualBlock, ResBlock, SpatialTransformer, Upsample, Downsample)): + temp_name = name + print(f'Model: {name}') + elif isinstance(module, (TemporalTransformer, TemporalTransformer_attemask, TemporalAttentionBlock, TemporalAttentionMultiBlock, TemporalConvBlock_v2, TemporalConvBlock)): + temp_name = '' + + if hasattr(module, 'weight'): + if temp_name != '' and (temp_name in name): + temp_key_list.append(name) + print(f'{name}') + # print(name) + + save_module_list = [] + for k, p in model.named_parameters(): + for item in temp_key_list: + if item in k: + print(f'{item} --> {k}') + save_module_list.append(k) + + print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters') + + # spth = 'workspace/module_list/{}' + json.dump(save_module_list, open(spth, 'w')) + a = 0 + + +if __name__ == '__main__': + # save_temporal_key() + save_spatial_key() + + + +# print([k for (k, _) in self.input_blocks.named_parameters()]) + + diff --git a/test_func/test_EndDec.py b/test_func/test_EndDec.py new file mode 100644 index 0000000..0f256b6 --- /dev/null +++ b/test_func/test_EndDec.py @@ -0,0 +1,95 @@ +import os +import sys +import torch +import imageio +import numpy as np +import os.path as osp +sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2])) +from PIL import Image, ImageDraw, ImageFont + +from einops import rearrange + +from tools import * +import utils.transforms as data +from utils.seed import setup_seed +from tools.modules.config import cfg +from utils.config import Config as pConfig +from utils.registry_class import ENGINE, DATASETS, AUTO_ENCODER + + +def test_enc_dec(gpu=0): + setup_seed(0) + cfg_update = pConfig(load=True) + + for k, v in cfg_update.cfg_dict.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + save_dir = os.path.join('workspace/test_data/autoencoder', cfg.auto_encoder['type']) + os.system('rm -rf %s' % (save_dir)) + os.makedirs(save_dir, exist_ok=True) + + train_trans = data.Compose([ + data.CenterCropWide(size=cfg.resolution), + data.ToTensor(), + data.Normalize(mean=cfg.mean, std=cfg.std)]) + + vit_trans = data.Compose([ + data.CenterCropWide(size=(cfg.resolution[0], cfg.resolution[0])) if cfg.resolution[0]>cfg.vit_resolution[0] else data.CenterCropWide(size=cfg.vit_resolution), + data.Resize(cfg.vit_resolution), + data.ToTensor(), + data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)]) + + video_mean = torch.tensor(cfg.mean).view(1, -1, 1, 1) #n c f h w + video_std = torch.tensor(cfg.std).view(1, -1, 1, 1) #n c f h w + + txt_size = cfg.resolution[1] + nc = int(38 * (txt_size / 256)) + font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=13) + + dataset = DATASETS.build(cfg.vid_dataset, sample_fps=4, transforms=train_trans, vit_transforms=vit_trans) + print('There are %d videos' % (len(dataset))) + + autoencoder = AUTO_ENCODER.build(cfg.auto_encoder) + autoencoder.eval() # freeze + for param in autoencoder.parameters(): + param.requires_grad = False + autoencoder.to(gpu) + for idx, item in enumerate(dataset): + local_path = os.path.join(save_dir, '%04d.mp4' % idx) + # ref_frame, video_data, caption = item + ref_frame, vit_frame, video_data = item[:3] + video_data = video_data.to(gpu) + + image_list = [] + video_data_list = torch.chunk(video_data, video_data.shape[0]//cfg.chunk_size,dim=0) + with torch.no_grad(): + decode_data = [] + for chunk_data in video_data_list: + latent_z = autoencoder.encode_firsr_stage(chunk_data).detach() + # latent_z = get_first_stage_encoding(encoder_posterior).detach() + kwargs = {"timesteps": chunk_data.shape[0]} + recons_data = autoencoder.decode(latent_z, **kwargs) + + vis_data = torch.cat([chunk_data, recons_data], dim=2).cpu() + vis_data = vis_data.mul_(video_std).add_(video_mean) # 8x3x16x256x384 + vis_data = vis_data.cpu() + vis_data.clamp_(0, 1) + vis_data = vis_data.permute(0, 2, 3, 1) + vis_data = [(image.numpy() * 255).astype('uint8') for image in vis_data] + image_list.extend(vis_data) + + num_image = len(image_list) + frame_dir = os.path.join(save_dir, 'temp') + os.makedirs(frame_dir, exist_ok=True) + for idx in range(num_image): + tpth = os.path.join(frame_dir, '%04d.png' % (idx+1)) + cv2.imwrite(tpth, image_list[idx][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate 8 -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}' + os.system(cmd); os.system(f'rm -rf {frame_dir}') + + +if __name__ == '__main__': + test_enc_dec() diff --git a/test_func/test_dataset.py b/test_func/test_dataset.py new file mode 100644 index 0000000..33f2a35 --- /dev/null +++ b/test_func/test_dataset.py @@ -0,0 +1,152 @@ +import os +import sys +import imageio +import numpy as np +import os.path as osp +sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2])) +from PIL import Image, ImageDraw, ImageFont +import torchvision.transforms as T + +import utils.transforms as data +from tools.modules.config import cfg +from utils.config import Config as pConfig +from utils.registry_class import ENGINE, DATASETS + +from tools import * + +def test_video_dataset(): + cfg_update = pConfig(load=True) + + for k, v in cfg_update.cfg_dict.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + exp_name = os.path.basename(cfg.cfg_file).split('.')[0] + save_dir = os.path.join('workspace', 'test_data/datasets', cfg.vid_dataset['type'], exp_name) + os.system('rm -rf %s' % (save_dir)) + os.makedirs(save_dir, exist_ok=True) + + train_trans = data.Compose([ + data.CenterCropWide(size=cfg.resolution), + data.ToTensor(), + data.Normalize(mean=cfg.mean, std=cfg.std)]) + vit_trans = T.Compose([ + data.CenterCropWide(cfg.vit_resolution), + T.ToTensor(), + T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)]) + + video_mean = torch.tensor(cfg.mean).view(1, -1, 1, 1) #n c f h w + video_std = torch.tensor(cfg.std).view(1, -1, 1, 1) #n c f h w + + img_mean = torch.tensor(cfg.mean).view(-1, 1, 1) # c f h w + img_std = torch.tensor(cfg.std).view(-1, 1, 1) # c f h w + + vit_mean = torch.tensor(cfg.vit_mean).view(-1, 1, 1) # c f h w + vit_std = torch.tensor(cfg.vit_std).view(-1, 1, 1) # c f h w + + txt_size = cfg.resolution[1] + nc = int(38 * (txt_size / 256)) + font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=13) + + dataset = DATASETS.build(cfg.vid_dataset, sample_fps=cfg.sample_fps[0], transforms=train_trans, vit_transforms=vit_trans) + print('There are %d videos' % (len(dataset))) + for idx, item in enumerate(dataset): + ref_frame, vit_frame, video_data, caption, video_key = item + + video_data = video_data.mul_(video_std).add_(video_mean) + video_data.clamp_(0, 1) + video_data = video_data.permute(0, 2, 3, 1) + video_data = [(image.numpy() * 255).astype('uint8') for image in video_data] + + # Single Image + ref_frame = ref_frame.mul_(img_mean).add_(img_std) + ref_frame.clamp_(0, 1) + ref_frame = ref_frame.permute(1, 2, 0) + ref_frame = (ref_frame.numpy() * 255).astype('uint8') + + # Text image + txt_img = Image.new("RGB", (txt_size, txt_size), color="white") + draw = ImageDraw.Draw(txt_img) + lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc)) + draw.text((0, 0), lines, fill="black", font=font) + txt_img = np.array(txt_img) + + video_data = [np.concatenate([ref_frame, u, txt_img], axis=1) for u in video_data] + spath = os.path.join(save_dir, '%04d.gif' % (idx)) + imageio.mimwrite(spath, video_data, fps =8) + + # if idx > 100: break + + +def test_vit_image(test_video_flag=True): + cfg_update = pConfig(load=True) + + for k, v in cfg_update.cfg_dict.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + exp_name = os.path.basename(cfg.cfg_file).split('.')[0] + save_dir = os.path.join('workspace', 'test_data/datasets', cfg.img_dataset['type'], exp_name) + os.system('rm -rf %s' % (save_dir)) + os.makedirs(save_dir, exist_ok=True) + + train_trans = data.Compose([ + data.CenterCropWide(size=cfg.resolution), + data.ToTensor(), + data.Normalize(mean=cfg.mean, std=cfg.std)]) + vit_trans = data.Compose([ + data.CenterCropWide(cfg.resolution), + data.Resize(cfg.vit_resolution), + data.ToTensor(), + data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)]) + + img_mean = torch.tensor(cfg.mean).view(-1, 1, 1) # c f h w + img_std = torch.tensor(cfg.std).view(-1, 1, 1) # c f h w + + vit_mean = torch.tensor(cfg.vit_mean).view(-1, 1, 1) # c f h w + vit_std = torch.tensor(cfg.vit_std).view(-1, 1, 1) # c f h w + + txt_size = cfg.resolution[1] + nc = int(38 * (txt_size / 256)) + font = ImageFont.truetype('artist/font/DejaVuSans.ttf', size=13) + + dataset = DATASETS.build(cfg.img_dataset, transforms=train_trans, vit_transforms=vit_trans) + print('There are %d videos' % (len(dataset))) + for idx, item in enumerate(dataset): + ref_frame, vit_frame, video_data, caption, video_key = item + video_data = video_data.mul_(img_std).add_(img_mean) + video_data.clamp_(0, 1) + video_data = video_data.permute(0, 2, 3, 1) + video_data = [(image.numpy() * 255).astype('uint8') for image in video_data] + + # Single Image + vit_frame = vit_frame.mul_(vit_std).add_(vit_mean) + vit_frame.clamp_(0, 1) + vit_frame = vit_frame.permute(1, 2, 0) + vit_frame = (vit_frame.numpy() * 255).astype('uint8') + + zero_frame = np.zeros((cfg.resolution[1], cfg.resolution[1], 3), dtype=np.uint8) + zero_frame[:vit_frame.shape[0], :vit_frame.shape[1], :] = vit_frame + + # Text image + txt_img = Image.new("RGB", (txt_size, txt_size), color="white") + draw = ImageDraw.Draw(txt_img) + lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc)) + draw.text((0, 0), lines, fill="black", font=font) + txt_img = np.array(txt_img) + + video_data = [np.concatenate([zero_frame, u, txt_img], axis=1) for u in video_data] + spath = os.path.join(save_dir, '%04d.gif' % (idx)) + imageio.mimwrite(spath, video_data, fps =8) + + # if idx > 100: break + + +if __name__ == '__main__': + # test_video_dataset() + test_vit_image() + diff --git a/test_func/test_models.py b/test_func/test_models.py new file mode 100644 index 0000000..aa7f78f --- /dev/null +++ b/test_func/test_models.py @@ -0,0 +1,56 @@ +import os +import sys +import torch +import imageio +import numpy as np +import os.path as osp +sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2])) +from thop import profile +from ptflops import get_model_complexity_info + +import artist.data as data +from tools.modules.config import cfg +from utils.config import Config as pConfig +from utils.registry_class import ENGINE, MODEL + + +def test_model(): + cfg_update = pConfig(load=True) + + for k, v in cfg_update.cfg_dict.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + model = MODEL.build(cfg.UNet) + print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters') + + # state_dict = torch.load('cache/pretrain_model/jiuniu_0600000.pth', map_location='cpu') + # model.load_state_dict(state_dict, strict=False) + model = model.cuda() + + x = torch.Tensor(1, 4, 16, 32, 56).cuda() + t = torch.Tensor(1).cuda() + sims = torch.Tensor(1, 32).cuda() + fps = torch.Tensor([8]).cuda() + y = torch.Tensor(1, 1, 1024).cuda() + image = torch.Tensor(1, 3, 256, 448).cuda() + + ret = model(x=x, t=t, y=y, ori_img=image, sims=sims, fps=fps) + print('Out shape if {}'.format(ret.shape)) + + # flops, params = profile(model=model, inputs=(x, t, y, image, sims, fps)) + # print('Model: {:.2f} GFLOPs and {:.2f}M parameters'.format(flops/1e9, params/1e6)) + + def prepare_input(resolution): + return dict(x=[x, t, y, image, sims, fps]) + + flops, params = get_model_complexity_info(model, (1, 4, 16, 32, 56), + input_constructor = prepare_input, + as_strings=True, print_per_layer_stat=True) + print(' - Flops: ' + flops) + print(' - Params: ' + params) + +if __name__ == '__main__': + test_model() diff --git a/test_func/test_save_video.py b/test_func/test_save_video.py new file mode 100644 index 0000000..926308f --- /dev/null +++ b/test_func/test_save_video.py @@ -0,0 +1,24 @@ +import numpy as np +import cv2 + +cap = cv2.VideoCapture('workspace/img_dir/tst.mp4') + +fourcc = cv2.VideoWriter_fourcc(*'H264') + +ret, frame = cap.read() +vid_size = frame.shape[:2][::-1] + +out = cv2.VideoWriter('workspace/img_dir/testwrite.mp4',fourcc, 8, vid_size) +out.write(frame) + +while(cap.isOpened()): + ret, frame = cap.read() + if not ret: break + out.write(frame) + + +cap.release() +out.release() + + +