diff --git a/configs/UniAnimate_infer.yaml b/configs/UniAnimate_infer.yaml
new file mode 100644
index 0000000..d44fc89
--- /dev/null
+++ b/configs/UniAnimate_infer.yaml
@@ -0,0 +1,96 @@
+# manual setting
+max_frames: 1
+resolution: [512, 768]  # or resolution: [768, 1216]
+# resolution: [768, 1216]
+round: 1
+ddim_timesteps: 30  # choose between 25 and 50
+seed: 11  # 7
+test_list_path: [
+    # Format: [frame_interval, reference image, driving pose sequence]
+    # [2, "data/images/WOMEN-Blouses_Shirts-id_00004955-01_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00004955-01_4_full"],
+    # [2, "data/images/musk.jpg", "data/saved_pose/musk"],
+    # [2, "data/images/WOMEN-Blouses_Shirts-id_00005125-03_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00005125-03_4_full"],
+    # [2, "data/images/IMG_20240514_104337.jpg", "data/saved_pose/IMG_20240514_104337"]
+    [1, "data/images/i.png", "data/saved_pose/dancePoses"]
+]
+partial_keys: [
+    # ['image', 'local_image', "dwpose"],  # reference image as the first frame of the generated video (optional)
+    ['image', 'randomref', "dwpose"],
+]
+
+
+# default settings
+TASK_TYPE: inference_unianimate_entrance
+use_fp16: True
+guide_scale: 2.5
+vit_resolution: [224, 224]
+batch_size: 1
+latent_random_ref: True
+chunk_size: 2
+decoder_bs: 2
+scale: 8
+use_fps_condition: False
+test_model: checkpoints/unianimate_16f_32f_non_ema_223000.pth
+embedder: {
+    'type': 'FrozenOpenCLIPTextVisualEmbedder',
+    'layer': 'penultimate',
+    'pretrained': 'checkpoints/open_clip_pytorch_model.bin'
+}
+
+auto_encoder: {
+    'type': 'AutoencoderKL',
+    'ddconfig': {
+        'double_z': True,
+        'z_channels': 4,
+        'resolution': 256,
+        'in_channels': 3,
+        'out_ch': 3,
+        'ch': 128,
+        'ch_mult': [1, 2, 4, 4],
+        'num_res_blocks': 2,
+        'attn_resolutions': [],
+        'dropout': 0.0,
+        'video_kernel_size': [3, 1, 1]
+    },
+    'embed_dim': 4,
+    'pretrained': 'checkpoints/v2-1_512-ema-pruned.ckpt'
+}
+
+UNet: {
+    'type': 'UNetSD_UniAnimate',
+    'config': None,
+    'in_dim': 4,
+    'dim': 320,
+    'y_dim': 1024,
+    'context_dim': 1024,
+    'out_dim': 4,
+    'dim_mult': [1, 2, 4, 4],
+    'num_heads': 8,
+    'head_dim': 64,
+    'num_res_blocks': 2,
+    'dropout': 0.1,
+    'temporal_attention': True,
+    'num_tokens': 4,
+    'temporal_attn_times': 1,
+    'use_checkpoint': True,
+    'use_fps_condition': False,
+    'use_sim_mask': False
+}
+video_compositions: ['image', 'local_image', 'dwpose', 'randomref', 'randomref_pose']
+Diffusion: {
+    'type': 'DiffusionDDIM',
+    'schedule': 'linear_sd',
+    'schedule_param': {
+        'num_timesteps': 1000,
+        "init_beta": 0.00085,
+        "last_beta": 0.0120,
+        'zero_terminal_snr': True,
+    },
+    'mean_type': 'v',
+    'loss_type': 'mse',
+    'var_type': 'fixed_small',  # 'fixed_large'
+    'rescale_timesteps': False,
+    'noise_strength': 0.1
+}
+use_DiffusionDPM: False
+CPU_CLIP_VAE: True
\ No newline at end of file
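Editor's note: each `test_list_path` entry above pairs a frame interval, a reference image, and a folder of driving poses. A minimal sketch of how such a config could be read (assuming PyYAML; the loop and variable names are illustrative, not the repo's API):

```python
# Sketch: load the inference config and walk its test list.
import yaml

with open("configs/UniAnimate_infer.yaml") as f:
    cfg = yaml.safe_load(f)

for frame_interval, ref_image, pose_dir in cfg["test_list_path"]:
    # One reference image animated by one pose sequence,
    # sampled every `frame_interval` frames.
    print(frame_interval, ref_image, pose_dir)
```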
diff --git a/configs/UniAnimate_infer_long.yaml b/configs/UniAnimate_infer_long.yaml
new file mode 100644
index 0000000..9b5f368
--- /dev/null
+++ b/configs/UniAnimate_infer_long.yaml
@@ -0,0 +1,99 @@
+# manual setting
+# resolution: [512, 768]  # or [768, 1216]
+resolution: [768, 1216]
+round: 1
+ddim_timesteps: 30  # choose between 25 and 50
+context_size: 32
+context_stride: 1
+context_overlap: 8
+seed: 7
+max_frames: "None"  # 64, 96, or "None"; "None" means use the full length of the original pose sequence
+test_list_path: [
+    # Format: [frame_interval, reference image, driving pose sequence]
+    [2, "data/images/WOMEN-Blouses_Shirts-id_00004955-01_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00004955-01_4_full"],
+    [2, "data/images/musk.jpg", "data/saved_pose/musk"],
+    [2, "data/images/WOMEN-Blouses_Shirts-id_00005125-03_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00005125-03_4_full"],
+    [2, "data/images/IMG_20240514_104337.jpg", "data/saved_pose/IMG_20240514_104337"],
+    [2, "data/images/IMG_20240514_104337.jpg", "data/saved_pose/IMG_20240514_104337_dance"],
+    [2, "data/images/WOMEN-Blouses_Shirts-id_00005125-03_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00005125-03_4_full_dance"]
+]
+
+
+# default settings
+TASK_TYPE: inference_unianimate_long_entrance
+use_fp16: True
+guide_scale: 2.5
+vit_resolution: [224, 224]
+batch_size: 1
+latent_random_ref: True
+chunk_size: 2
+decoder_bs: 2
+scale: 8
+use_fps_condition: False
+test_model: checkpoints/unianimate_16f_32f_non_ema_223000.pth
+partial_keys: [
+    ['image', 'randomref', "dwpose"],
+]
+embedder: {
+    'type': 'FrozenOpenCLIPTextVisualEmbedder',
+    'layer': 'penultimate',
+    'pretrained': 'checkpoints/open_clip_pytorch_model.bin'
+}
+
+auto_encoder: {
+    'type': 'AutoencoderKL',
+    'ddconfig': {
+        'double_z': True,
+        'z_channels': 4,
+        'resolution': 256,
+        'in_channels': 3,
+        'out_ch': 3,
+        'ch': 128,
+        'ch_mult': [1, 2, 4, 4],
+        'num_res_blocks': 2,
+        'attn_resolutions': [],
+        'dropout': 0.0,
+        'video_kernel_size': [3, 1, 1]
+    },
+    'embed_dim': 4,
+    'pretrained': 'checkpoints/v2-1_512-ema-pruned.ckpt'
+}
+
+UNet: {
+    'type': 'UNetSD_UniAnimate',
+    'config': None,
+    'in_dim': 4,
+    'dim': 320,
+    'y_dim': 1024,
+    'context_dim': 1024,
+    'out_dim': 4,
+    'dim_mult': [1, 2, 4, 4],
+    'num_heads': 8,
+    'head_dim': 64,
+    'num_res_blocks': 2,
+    'dropout': 0.1,
+    'temporal_attention': True,
+    'num_tokens': 4,
+    'temporal_attn_times': 1,
+    'use_checkpoint': True,
+    'use_fps_condition': False,
+    'use_sim_mask': False
+}
+video_compositions: ['image', 'local_image', 'dwpose', 'randomref', 'randomref_pose']
+Diffusion: {
+    'type': 'DiffusionDDIMLong',
+    'schedule': 'linear_sd',
+    'schedule_param': {
+        'num_timesteps': 1000,
+        "init_beta": 0.00085,
+        "last_beta": 0.0120,
+        'zero_terminal_snr': True,
+    },
+    'mean_type': 'v',
+    'loss_type': 'mse',
+    'var_type': 'fixed_small',
+    'rescale_timesteps': False,
+    'noise_strength': 0.1
+}
+CPU_CLIP_VAE: True
+context_batch_size: 1
diff --git a/dwpose/__init__.py b/dwpose/__init__.py
new file mode 100644
index 0000000..e69de29
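Editor's note: the long-video config adds `context_size`/`context_stride`/`context_overlap`, which carve the pose sequence into overlapping temporal windows that are denoised and then blended. The repo's actual scheduler lives elsewhere; the sketch below only illustrates the window arithmetic:

```python
# Illustrative only: how the three context_* settings could split a long
# pose sequence into overlapping windows (not the repo's scheduler).
def context_windows(num_frames, context_size=32, context_stride=1, context_overlap=8):
    step = (context_size - context_overlap) * context_stride
    windows, start = [], 0
    while start < num_frames:
        end = min(start + context_size * context_stride, num_frames)
        windows.append(list(range(start, end, context_stride)))
        if end == num_frames:
            break
        start += step
    return windows

print(context_windows(48))  # two 32-frame windows sharing 8 overlapping frames
```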
diff --git a/dwpose/onnxdet.py b/dwpose/onnxdet.py
new file mode 100644
index 0000000..4fab0c0
--- /dev/null
+++ b/dwpose/onnxdet.py
@@ -0,0 +1,127 @@
+import cv2
+import numpy as np
+
+import onnxruntime
+
+
+def nms(boxes, scores, nms_thr):
+    """Single class NMS implemented in Numpy."""
+    x1 = boxes[:, 0]
+    y1 = boxes[:, 1]
+    x2 = boxes[:, 2]
+    y2 = boxes[:, 3]
+
+    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+    order = scores.argsort()[::-1]
+
+    keep = []
+    while order.size > 0:
+        i = order[0]
+        keep.append(i)
+        xx1 = np.maximum(x1[i], x1[order[1:]])
+        yy1 = np.maximum(y1[i], y1[order[1:]])
+        xx2 = np.minimum(x2[i], x2[order[1:]])
+        yy2 = np.minimum(y2[i], y2[order[1:]])
+
+        w = np.maximum(0.0, xx2 - xx1 + 1)
+        h = np.maximum(0.0, yy2 - yy1 + 1)
+        inter = w * h
+        ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+        inds = np.where(ovr <= nms_thr)[0]
+        order = order[inds + 1]
+
+    return keep
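Editor's note: a toy run of `nms()` above (hypothetical boxes) makes the suppression concrete; the IoU between the first two boxes is about 0.7, above the 0.45 threshold:

```python
# Toy check of nms(), assuming dwpose.onnxdet is importable as a package.
import numpy as np
from dwpose.onnxdet import nms

boxes = np.array([[0., 0., 10., 10.],
                  [1., 1., 11., 11.],     # IoU ~0.7 with the first box
                  [50., 50., 60., 60.]])
scores = np.array([0.9, 0.8, 0.7])
print(nms(boxes, scores, nms_thr=0.45))   # [0, 2]: the second box is suppressed
```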
+
+def multiclass_nms(boxes, scores, nms_thr, score_thr):
+    """Multiclass NMS implemented in Numpy.
+    Class-aware version."""
+    final_dets = []
+    num_classes = scores.shape[1]
+    for cls_ind in range(num_classes):
+        cls_scores = scores[:, cls_ind]
+        valid_score_mask = cls_scores > score_thr
+        if valid_score_mask.sum() == 0:
+            continue
+        else:
+            valid_scores = cls_scores[valid_score_mask]
+            valid_boxes = boxes[valid_score_mask]
+            keep = nms(valid_boxes, valid_scores, nms_thr)
+            if len(keep) > 0:
+                cls_inds = np.ones((len(keep), 1)) * cls_ind
+                dets = np.concatenate(
+                    [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
+                )
+                final_dets.append(dets)
+    if len(final_dets) == 0:
+        return None
+    return np.concatenate(final_dets, 0)
+
+
+def demo_postprocess(outputs, img_size, p6=False):
+    grids = []
+    expanded_strides = []
+    strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
+
+    hsizes = [img_size[0] // stride for stride in strides]
+    wsizes = [img_size[1] // stride for stride in strides]
+
+    for hsize, wsize, stride in zip(hsizes, wsizes, strides):
+        xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
+        grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
+        grids.append(grid)
+        shape = grid.shape[:2]
+        expanded_strides.append(np.full((*shape, 1), stride))
+
+    grids = np.concatenate(grids, 1)
+    expanded_strides = np.concatenate(expanded_strides, 1)
+    outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
+    outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
+
+    return outputs
+
+
+def preprocess(img, input_size, swap=(2, 0, 1)):
+    if len(img.shape) == 3:
+        padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
+    else:
+        padded_img = np.ones(input_size, dtype=np.uint8) * 114
+
+    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
+    resized_img = cv2.resize(
+        img,
+        (int(img.shape[1] * r), int(img.shape[0] * r)),
+        interpolation=cv2.INTER_LINEAR,
+    ).astype(np.uint8)
+    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
+
+    padded_img = padded_img.transpose(swap)
+    padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
+    return padded_img, r
+
+
+def inference_detector(session, oriImg):
+    input_shape = (640, 640)
+    img, ratio = preprocess(oriImg, input_shape)
+
+    ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
+
+    output = session.run(None, ort_inputs)
+
+    predictions = demo_postprocess(output[0], input_shape)[0]
+
+    boxes = predictions[:, :4]
+    scores = predictions[:, 4:5] * predictions[:, 5:]
+
+    boxes_xyxy = np.ones_like(boxes)
+    boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
+    boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
+    boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
+    boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
+    boxes_xyxy /= ratio
+    dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
+    if dets is not None:
+        final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
+        isscore = final_scores > 0.3
+        iscat = final_cls_inds == 0  # keep the person class only
+        isbbox = [i and j for (i, j) in zip(isscore, iscat)]
+        final_boxes = final_boxes[isbbox]
+    else:
+        final_boxes = np.array([])
+
+    return final_boxes
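Editor's note: end to end, the detector above can be driven like this; the checkpoint path mirrors the one used in wholebody.py below, and the import path assumes `dwpose` is importable as a package (a sketch, not the repo's entry point):

```python
import cv2
import onnxruntime as ort
from dwpose.onnxdet import inference_detector

# CPUExecutionProvider keeps the sketch portable; wholebody.py defaults to CUDA.
session = ort.InferenceSession("checkpoints/yolox_l.onnx",
                               providers=["CPUExecutionProvider"])
img = cv2.imread("data/images/i.png")      # BGR image of any resolution
boxes = inference_detector(session, img)   # person boxes, (N, 4) xyxy in pixels
print(boxes)
```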
diff --git a/dwpose/onnxpose.py b/dwpose/onnxpose.py
new file mode 100644
index 0000000..72c1b00
--- /dev/null
+++ b/dwpose/onnxpose.py
@@ -0,0 +1,361 @@
+from typing import List, Tuple
+
+import cv2
+import numpy as np
+import onnxruntime as ort
+
+
+def preprocess(
+    img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """Do preprocessing for RTMPose model inference.
+
+    Args:
+        img (np.ndarray): Input image in shape (H, W, C).
+        out_bbox: Detected person bboxes in (x1, y1, x2, y2) format.
+        input_size (tuple): Input image size in shape (w, h).
+
+    Returns:
+        tuple:
+        - resized_img (np.ndarray): Preprocessed image.
+        - center (np.ndarray): Center of image.
+        - scale (np.ndarray): Scale of image.
+    """
+    # get shape of image
+    img_shape = img.shape[:2]
+    out_img, out_center, out_scale = [], [], []
+    if len(out_bbox) == 0:
+        out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
+    for i in range(len(out_bbox)):
+        x0 = out_bbox[i][0]
+        y0 = out_bbox[i][1]
+        x1 = out_bbox[i][2]
+        y1 = out_bbox[i][3]
+        bbox = np.array([x0, y0, x1, y1])
+
+        # get center and scale
+        center, scale = bbox_xyxy2cs(bbox, padding=1.25)
+
+        # do affine transformation
+        resized_img, scale = top_down_affine(input_size, scale, center, img)
+
+        # normalize image
+        mean = np.array([123.675, 116.28, 103.53])
+        std = np.array([58.395, 57.12, 57.375])
+        resized_img = (resized_img - mean) / std
+
+        out_img.append(resized_img)
+        out_center.append(center)
+        out_scale.append(scale)
+
+    return out_img, out_center, out_scale
+
+
+def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
+    """Inference RTMPose model.
+
+    Args:
+        sess (ort.InferenceSession): ONNXRuntime session.
+        img (np.ndarray): Preprocessed image(s) from preprocess().
+
+    Returns:
+        outputs (np.ndarray): Output of RTMPose model.
+    """
+    all_out = []
+    # build input
+    for i in range(len(img)):
+        input = [img[i].transpose(2, 0, 1)]
+
+        # build output
+        sess_input = {sess.get_inputs()[0].name: input}
+        sess_output = []
+        for out in sess.get_outputs():
+            sess_output.append(out.name)
+
+        # run model
+        outputs = sess.run(sess_output, sess_input)
+        all_out.append(outputs)
+
+    return all_out
+
+
+def postprocess(outputs: List[np.ndarray],
+                model_input_size: Tuple[int, int],
+                center: Tuple[int, int],
+                scale: Tuple[int, int],
+                simcc_split_ratio: float = 2.0
+                ) -> Tuple[np.ndarray, np.ndarray]:
+    """Postprocess for RTMPose model output.
+
+    Args:
+        outputs (np.ndarray): Output of RTMPose model.
+        model_input_size (tuple): RTMPose model Input image size.
+        center (tuple): Center of bbox in shape (x, y).
+        scale (tuple): Scale of bbox in shape (w, h).
+        simcc_split_ratio (float): Split ratio of simcc.
+
+    Returns:
+        tuple:
+        - keypoints (np.ndarray): Rescaled keypoints.
+        - scores (np.ndarray): Model predict scores.
+    """
+    all_key = []
+    all_score = []
+    for i in range(len(outputs)):
+        # use simcc to decode
+        simcc_x, simcc_y = outputs[i]
+        keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
+
+        # rescale keypoints
+        keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
+        all_key.append(keypoints[0])
+        all_score.append(scores[0])
+
+    return np.array(all_key), np.array(all_score)
+
+
+def bbox_xyxy2cs(bbox: np.ndarray,
+                 padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
+    """Transform the bbox format from (x1, y1, x2, y2) into (center, scale).
+
+    Args:
+        bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
+            as (left, top, right, bottom)
+        padding (float): BBox padding factor that will be multiplied to scale.
+            Default: 1.0
+
+    Returns:
+        tuple: A tuple containing center and scale.
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or + (n, 2) + - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or + (n, 2) + """ + # convert single bbox from (4, ) to (1, 4) + dim = bbox.ndim + if dim == 1: + bbox = bbox[None, :] + + # get bbox center and scale + x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3]) + center = np.hstack([x1 + x2, y1 + y2]) * 0.5 + scale = np.hstack([x2 - x1, y2 - y1]) * padding + + if dim == 1: + center = center[0] + scale = scale[0] + + return center, scale + + +def _fix_aspect_ratio(bbox_scale: np.ndarray, + aspect_ratio: float) -> np.ndarray: + """Extend the scale to match the given aspect ratio. + + Args: + scale (np.ndarray): The image scale (w, h) in shape (2, ) + aspect_ratio (float): The ratio of ``w/h`` + + Returns: + np.ndarray: The reshaped image scale in (2, ) + """ + w, h = np.hsplit(bbox_scale, [1]) + bbox_scale = np.where(w > h * aspect_ratio, + np.hstack([w, w / aspect_ratio]), + np.hstack([h * aspect_ratio, h])) + return bbox_scale + + +def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray: + """Rotate a point by an angle. + + Args: + pt (np.ndarray): 2D point coordinates (x, y) in shape (2, ) + angle_rad (float): rotation angle in radian + + Returns: + np.ndarray: Rotated point in shape (2, ) + """ + sn, cs = np.sin(angle_rad), np.cos(angle_rad) + rot_mat = np.array([[cs, -sn], [sn, cs]]) + return rot_mat @ pt + + +def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: + """To calculate the affine matrix, three pairs of points are required. This + function is used to get the 3rd point, given 2D points a & b. + + The 3rd point is defined by rotating vector `a - b` by 90 degrees + anticlockwise, using b as the rotation center. + + Args: + a (np.ndarray): The 1st point (x,y) in shape (2, ) + b (np.ndarray): The 2nd point (x,y) in shape (2, ) + + Returns: + np.ndarray: The 3rd point. + """ + direction = a - b + c = b + np.r_[-direction[1], direction[0]] + return c + + +def get_warp_matrix(center: np.ndarray, + scale: np.ndarray, + rot: float, + output_size: Tuple[int, int], + shift: Tuple[float, float] = (0., 0.), + inv: bool = False) -> np.ndarray: + """Calculate the affine transformation matrix that can warp the bbox area + in the input image to the output size. + + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + rot (float): Rotation angle (degree). + output_size (np.ndarray[2, ] | list(2,)): Size of the + destination heatmaps. + shift (0-100%): Shift translation ratio wrt the width/height. + Default (0., 0.). + inv (bool): Option to inverse the affine transform direction. 
+ (inv=False: src->dst or inv=True: dst->src) + + Returns: + np.ndarray: A 2x3 transformation matrix + """ + shift = np.array(shift) + src_w = scale[0] + dst_w = output_size[0] + dst_h = output_size[1] + + # compute transformation matrix + rot_rad = np.deg2rad(rot) + src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad) + dst_dir = np.array([0., dst_w * -0.5]) + + # get four corners of the src rectangle in the original image + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale * shift + src[1, :] = center + src_dir + scale * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + # get four corners of the dst rectangle in the input image + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return warp_mat + + +def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict, + img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get the bbox image as the model input by affine transform. + + Args: + input_size (dict): The input size of the model. + bbox_scale (dict): The bbox scale of the img. + bbox_center (dict): The bbox center of the img. + img (np.ndarray): The original image. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: img after affine transform. + - np.ndarray[float32]: bbox scale after affine transform. + """ + w, h = input_size + warp_size = (int(w), int(h)) + + # reshape bbox to fixed aspect ratio + bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h) + + # get the affine matrix + center = bbox_center + scale = bbox_scale + rot = 0 + warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) + + # do affine transform + img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) + + return img, bbox_scale + + +def get_simcc_maximum(simcc_x: np.ndarray, + simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get maximum response location and value from simcc representations. + + Note: + instance number: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) + simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) + + Returns: + tuple: + - locs (np.ndarray): locations of maximum heatmap responses in shape + (K, 2) or (N, K, 2) + - vals (np.ndarray): values of maximum heatmap responses in shape + (K,) or (N, K) + """ + N, K, Wx = simcc_x.shape + simcc_x = simcc_x.reshape(N * K, -1) + simcc_y = simcc_y.reshape(N * K, -1) + + # get maximum value locations + x_locs = np.argmax(simcc_x, axis=1) + y_locs = np.argmax(simcc_y, axis=1) + locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) + max_val_x = np.amax(simcc_x, axis=1) + max_val_y = np.amax(simcc_y, axis=1) + + # get maximum value across x and y axis + mask = max_val_x > max_val_y + max_val_x[mask] = max_val_y[mask] + vals = max_val_x + locs[vals <= 0.] = -1 + + # reshape + locs = locs.reshape(N, K, 2) + vals = vals.reshape(N, K) + + return locs, vals + + +def decode(simcc_x: np.ndarray, simcc_y: np.ndarray, + simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]: + """Modulate simcc distribution with Gaussian. + + Args: + simcc_x (np.ndarray[K, Wx]): model predicted simcc in x. 
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y. + simcc_split_ratio (int): The split ratio of simcc. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2) + - np.ndarray[float32]: scores in shape (K,) or (n, K) + """ + keypoints, scores = get_simcc_maximum(simcc_x, simcc_y) + keypoints /= simcc_split_ratio + + return keypoints, scores + + +def inference_pose(session, out_bbox, oriImg): + h, w = session.get_inputs()[0].shape[2:] + model_input_size = (w, h) + resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size) + outputs = inference(session, resized_img) + keypoints, scores = postprocess(outputs, model_input_size, center, scale) + + return keypoints, scores \ No newline at end of file diff --git a/dwpose/util.py b/dwpose/util.py new file mode 100644 index 0000000..94d9d9d --- /dev/null +++ b/dwpose/util.py @@ -0,0 +1,336 @@ +import math +import numpy as np +import matplotlib +import cv2 + + +eps = 0.01 + + +def smart_resize(x, s): + Ht, Wt = s + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) + + +def smart_resize_k(x, fx, fy): + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + Ht, Wt = Ho * fy, Wo * fx + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) + + +def padRightDownCorner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] + return transfered_model_weights + + +def draw_bodypose(canvas, candidate, subset): + H, W, C = canvas.shape + candidate = np.array(candidate) + subset = np.array(subset) + + stickwidth = 4 + + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + for 
i in range(17):
+        for n in range(len(subset)):
+            index = subset[n][np.array(limbSeq[i]) - 1]
+            if -1 in index:
+                continue
+            Y = candidate[index.astype(int), 0] * float(W)
+            X = candidate[index.astype(int), 1] * float(H)
+            mX = np.mean(X)
+            mY = np.mean(Y)
+            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
+            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+            polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+            cv2.fillConvexPoly(canvas, polygon, colors[i])
+
+    canvas = (canvas * 0.6).astype(np.uint8)
+
+    for i in range(18):
+        for n in range(len(subset)):
+            index = int(subset[n][i])
+            if index == -1:
+                continue
+            x, y = candidate[index][0:2]
+            x = int(x * W)
+            y = int(y * H)
+            cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
+
+    return canvas
+
+
+def draw_body_and_foot(canvas, candidate, subset):
+    H, W, C = canvas.shape
+    candidate = np.array(candidate)
+    subset = np.array(subset)
+
+    stickwidth = 4
+
+    limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+               [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+               [1, 16], [16, 18], [14, 19], [11, 20]]
+
+    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
+              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
+              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85], [170, 255, 255], [255, 255, 0]]
+
+    for i in range(19):
+        for n in range(len(subset)):
+            index = subset[n][np.array(limbSeq[i]) - 1]
+            if -1 in index:
+                continue
+            Y = candidate[index.astype(int), 0] * float(W)
+            X = candidate[index.astype(int), 1] * float(H)
+            mX = np.mean(X)
+            mY = np.mean(Y)
+            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
+            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+            polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+            cv2.fillConvexPoly(canvas, polygon, colors[i])
+
+    canvas = (canvas * 0.6).astype(np.uint8)
+
+    for i in range(20):
+        for n in range(len(subset)):
+            index = int(subset[n][i])
+            if index == -1:
+                continue
+            x, y = candidate[index][0:2]
+            x = int(x * W)
+            y = int(y * H)
+            cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
+
+    return canvas
+
+
+def draw_handpose(canvas, all_hand_peaks):
+    H, W, C = canvas.shape
+
+    edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
+             [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
+
+    for peaks in all_hand_peaks:
+        peaks = np.array(peaks)
+
+        for ie, e in enumerate(edges):
+            x1, y1 = peaks[e[0]]
+            x2, y2 = peaks[e[1]]
+            x1 = int(x1 * W)
+            y1 = int(y1 * H)
+            x2 = int(x2 * W)
+            y2 = int(y2 * H)
+            if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
+                cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
+
+        for i, keypoint in enumerate(peaks):
+            x, y = keypoint
+            x = int(x * W)
+            y = int(y * H)
+            if x > eps and y > eps:
+                cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
+    return canvas
+
+
+def draw_facepose(canvas, all_lmks):
+    H, W, C = canvas.shape
+    for lmks in all_lmks:
+        lmks = np.array(lmks)
+        for lmk in lmks:
+            x, y = lmk
+            x = int(x * W)
+            y = int(y * H)
+            if x > eps and y > eps:
+                cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
+    return canvas
+
+
+# detect hand according to body pose keypoints
+# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
+def handDetect(candidate, subset, oriImg):
+    # right hand: wrist 4, elbow 3, shoulder 2
+    # left hand: wrist 7, elbow 6, shoulder 5
+    ratioWristElbow = 0.33
+    detect_result = []
+    image_height, image_width = oriImg.shape[0:2]
+    for person in subset.astype(int):
+        # skip this person unless all three joints of at least one arm were detected
+        has_left = np.sum(person[[5, 6, 7]] == -1) == 0
+        has_right = np.sum(person[[2, 3, 4]] == -1) == 0
+        if not (has_left or has_right):
+            continue
+        hands = []
+        # left hand
+        if has_left:
+            left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
+            x1, y1 = candidate[left_shoulder_index][:2]
+            x2, y2 = candidate[left_elbow_index][:2]
+            x3, y3 = candidate[left_wrist_index][:2]
+            hands.append([x1, y1, x2, y2, x3, y3, True])
+        # right hand
+        if has_right:
+            right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
+            x1, y1 = candidate[right_shoulder_index][:2]
+            x2, y2 = candidate[right_elbow_index][:2]
+            x3, y3 = candidate[right_wrist_index][:2]
+            hands.append([x1, y1, x2, y2, x3, y3, False])
+
+        for x1, y1, x2, y2, x3, y3, is_left in hands:
+            # extrapolate past the wrist along the elbow->wrist direction
+            x = x3 + ratioWristElbow * (x3 - x2)
+            y = y3 + ratioWristElbow * (y3 - y2)
+            distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
+            distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
+            width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
+            # x-y refers to the center --> offset to topLeft point
+            # handRectangle.x -= handRectangle.width / 2.f;
+            # handRectangle.y -= handRectangle.height / 2.f;
+            x -= width / 2
+            y -= width / 2  # width = height
+            # clip the box to the image
+            if x < 0: x = 0
+            if y < 0: y = 0
+            width1 = width
+            width2 = width
+            if x + width > image_width: width1 = image_width - x
+            if y + width > image_height: width2 = image_height - y
+            width = min(width1, width2)
+            # discard hand boxes smaller than 20 pixels
+            if width >= 20:
+                detect_result.append([int(x), int(y), int(width), is_left])
+
+    '''
+    return value: [[x, y, w, True if left hand else False]].
+    width = height since the network requires square input.
+    x, y are the coordinates of the top-left corner.
+    '''
+    return detect_result
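Editor's note: the box construction in `handDetect()` is easier to see with numbers. The sketch below reproduces just the geometry, with hypothetical keypoint coordinates:

```python
# Standalone sketch of the hand-box geometry used in handDetect():
# extend the elbow->wrist vector past the wrist by ratioWristElbow,
# then size the square box from the arm segment lengths.
import math

ratioWristElbow = 0.33
x1, y1 = 100.0, 120.0   # shoulder (hypothetical)
x2, y2 = 140.0, 180.0   # elbow
x3, y3 = 170.0, 230.0   # wrist

cx = x3 + ratioWristElbow * (x3 - x2)
cy = y3 + ratioWristElbow * (y3 - y2)
wrist_elbow = math.hypot(x3 - x2, y3 - y2)
elbow_shoulder = math.hypot(x2 - x1, y2 - y1)
width = 1.5 * max(wrist_elbow, 0.9 * elbow_shoulder)

# shift from the box center to its top-left corner
print(cx - width / 2, cy - width / 2, width)
```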
+
+
+# Written by Lvmin
+def faceDetect(candidate, subset, oriImg):
+    # keypoint indices: nose 0, eyes 14/15, ears 16/17
+    detect_result = []
+    image_height, image_width = oriImg.shape[0:2]
+    for person in subset.astype(int):
+        has_head = person[0] > -1
+        if not has_head:
+            continue
+
+        has_left_eye = person[14] > -1
+        has_right_eye = person[15] > -1
+        has_left_ear = person[16] > -1
+        has_right_ear = person[17] > -1
+
+        if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
+            continue
+
+        head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
+
+        width = 0.0
+        x0, y0 = candidate[head][:2]
+
+        if has_left_eye:
+            x1, y1 = candidate[left_eye][:2]
+            d = max(abs(x0 - x1), abs(y0 - y1))
+            width = max(width, d * 3.0)
+
+        if has_right_eye:
+            x1, y1 = candidate[right_eye][:2]
+            d = max(abs(x0 - x1), abs(y0 - y1))
+            width = max(width, d * 3.0)
+
+        if has_left_ear:
+            x1, y1 = candidate[left_ear][:2]
+            d = max(abs(x0 - x1), abs(y0 - y1))
+            width = max(width, d * 1.5)
+
+        if has_right_ear:
+            x1, y1 = candidate[right_ear][:2]
+            d = max(abs(x0 - x1), abs(y0 - y1))
+            width = max(width, d * 1.5)
+
+        x, y = x0, y0
+
+        x -= width
+        y -= width
+
+        if x < 0:
+            x = 0
+
+        if y < 0:
+            y = 0
+
+        width1 = width * 2
+        width2 = width * 2
+
+        if x + width > image_width:
+            width1 = image_width - x
+
+        if y + width > image_height:
+            width2 = image_height - y
+
+        width = min(width1, width2)
+
+        if width >= 20:
+            detect_result.append([int(x), int(y), int(width)])
+
+    return detect_result
+
+
+# get the (row, column) index of the max entry of a 2d array
+def npmax(array):
+    arrayindex = array.argmax(1)
+    arrayvalue = array.max(1)
+    i = arrayvalue.argmax()
+    j = arrayindex[i]
+    return i, j
diff --git a/dwpose/wholebody.py b/dwpose/wholebody.py
new file mode 100644
index 0000000..ddaec2b
--- /dev/null
+++ b/dwpose/wholebody.py
@@ -0,0 +1,58 @@
+import os
+import cv2
+import numpy as np
+
+import onnxruntime as ort
+from ..dwpose.onnxdet import inference_detector
+from ..dwpose.onnxpose import inference_pose
+
+
+class Wholebody:
+    def __init__(self):
+        device = 'cuda'  # or 'cpu'
+        providers = ['CPUExecutionProvider'] if device == 'cpu' else ['CUDAExecutionProvider']
+
+        current_directory = os.path.dirname(os.path.abspath(__file__))
+        print("This file is located at " + os.path.dirname(os.path.abspath(__file__)))
+        parent_directory = os.path.dirname(current_directory)
+
+        onnx_det = os.path.join(parent_directory, 'checkpoints/yolox_l.onnx')
+        onnx_pose = os.path.join(parent_directory, 'checkpoints/dw-ll_ucoco_384.onnx')
+
+        # onnx_det = "../../checkpoints/yolox_l.onnx"
+        # onnx_pose = "../../checkpoints/dw-ll_ucoco_384.onnx"
+
+        self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
+        self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
+
+    def __call__(self, oriImg):
+        det_result = inference_detector(self.session_det, oriImg)
+        keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
+
+        keypoints_info = np.concatenate(
+            (keypoints, scores[..., None]), axis=-1)
+        # compute neck joint
+        neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
+        # neck score when visualizing pred
+        neck[:, 2:4] = np.logical_and(
+            keypoints_info[:, 5, 2:4] > 0.3,
+            keypoints_info[:, 6, 2:4] > 0.3).astype(int)
+        new_keypoints_info = np.insert(
+            keypoints_info, 17, neck, axis=1)
+        mmpose_idx = [
+            17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
+        ]
+        openpose_idx = [
+            1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15,
16, 17 + ] + new_keypoints_info[:, openpose_idx] = \ + new_keypoints_info[:, mmpose_idx] + keypoints_info = new_keypoints_info + + keypoints, scores = keypoints_info[ + ..., :2], keypoints_info[..., 2] + + return keypoints, scores + + diff --git a/lib/rotary_embedding_torch/__init__.py b/lib/rotary_embedding_torch/__init__.py new file mode 100644 index 0000000..3a2cfef --- /dev/null +++ b/lib/rotary_embedding_torch/__init__.py @@ -0,0 +1,6 @@ +from ..rotary_embedding_torch.rotary_embedding_torch import ( + apply_rotary_emb, + RotaryEmbedding, + apply_learned_rotations, + broadcat +) diff --git a/lib/rotary_embedding_torch/__pycache__/__init__.cpython-310.pyc b/lib/rotary_embedding_torch/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..e532687 Binary files /dev/null and b/lib/rotary_embedding_torch/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/rotary_embedding_torch/__pycache__/__init__.cpython-39.pyc b/lib/rotary_embedding_torch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..bbd17f7 Binary files /dev/null and b/lib/rotary_embedding_torch/__pycache__/__init__.cpython-39.pyc differ diff --git a/lib/rotary_embedding_torch/__pycache__/rotary_embedding_torch.cpython-310.pyc b/lib/rotary_embedding_torch/__pycache__/rotary_embedding_torch.cpython-310.pyc new file mode 100644 index 0000000..73761f6 Binary files /dev/null and b/lib/rotary_embedding_torch/__pycache__/rotary_embedding_torch.cpython-310.pyc differ diff --git a/lib/rotary_embedding_torch/__pycache__/rotary_embedding_torch.cpython-39.pyc b/lib/rotary_embedding_torch/__pycache__/rotary_embedding_torch.cpython-39.pyc new file mode 100644 index 0000000..37902f1 Binary files /dev/null and b/lib/rotary_embedding_torch/__pycache__/rotary_embedding_torch.cpython-39.pyc differ diff --git a/lib/rotary_embedding_torch/rotary_embedding_torch.py b/lib/rotary_embedding_torch/rotary_embedding_torch.py new file mode 100644 index 0000000..8937875 --- /dev/null +++ b/lib/rotary_embedding_torch/rotary_embedding_torch.py @@ -0,0 +1,291 @@ +from __future__ import annotations +from math import pi, log + +import torch +from torch.nn import Module, ModuleList +from torch.cuda.amp import autocast +from torch import nn, einsum, broadcast_tensors, Tensor + +from einops import rearrange, repeat + +from typing import Literal + +# helper functions + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +# broadcat, as tortoise-tts was using it + +def broadcat(tensors, dim = -1): + broadcasted_tensors = broadcast_tensors(*tensors) + return torch.cat(broadcasted_tensors, dim = dim) + +# rotary embedding helper functions + +def rotate_half(x): + x = rearrange(x, '... (d r) -> ... d r', r = 2) + x1, x2 = x.unbind(dim = -1) + x = torch.stack((-x2, x1), dim = -1) + return rearrange(x, '... d r -> ... 
(d r)') + +@autocast(enabled = False) +def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2): + dtype = t.dtype + + if t.ndim == 3: + seq_len = t.shape[seq_dim] + freqs = freqs[-seq_len:] + + rot_dim = freqs.shape[-1] + end_index = start_index + rot_dim + + assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' + + t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + out = torch.cat((t_left, t, t_right), dim = -1) + + return out.type(dtype) + +# learned rotation helpers + +def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None): + if exists(freq_ranges): + rotations = einsum('..., f -> ... f', rotations, freq_ranges) + rotations = rearrange(rotations, '... r f -> ... (r f)') + + rotations = repeat(rotations, '... n -> ... (n r)', r = 2) + return apply_rotary_emb(rotations, t, start_index = start_index) + +# classes + +class RotaryEmbedding(Module): + def __init__( + self, + dim, + custom_freqs: Tensor | None = None, + freqs_for: Literal['lang', 'pixel', 'constant'] = 'lang', + theta = 10000, + max_freq = 10, + num_freqs = 1, + learned_freq = False, + use_xpos = False, + xpos_scale_base = 512, + interpolate_factor = 1., + theta_rescale_factor = 1., + seq_before_head_dim = False, + cache_if_possible = True + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + self.freqs_for = freqs_for + + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + + self.cache_if_possible = cache_if_possible + + self.tmp_store('cached_freqs', None) + self.tmp_store('cached_scales', None) + + self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) + + self.learned_freq = learned_freq + + # dummy for device + + self.tmp_store('dummy', torch.tensor(0)) + + # default sequence dimension + + self.seq_before_head_dim = seq_before_head_dim + self.default_seq_dim = -3 if seq_before_head_dim else -2 + + # interpolation factors + + assert interpolate_factor >= 1. 
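+        # editor's note: interpolate_factor > 1 "stretches" positions; get_seq_pos()
+        # below divides positions by this factor, so a model trained at sequence
+        # length L can attend over roughly L * interpolate_factor positions
+        # (positional interpolation).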
+ self.interpolate_factor = interpolate_factor + + # xpos + + self.use_xpos = use_xpos + if not use_xpos: + self.tmp_store('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.scale_base = xpos_scale_base + self.tmp_store('scale', scale) + + # add apply_rotary_emb as static method + + self.apply_rotary_emb = staticmethod(apply_rotary_emb) + + @property + def device(self): + return self.dummy.device + + def tmp_store(self, key, value): + self.register_buffer(key, value, persistent = False) + + def get_seq_pos(self, seq_len, device, dtype, offset = 0): + return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor + + def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, scale = None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert not self.use_xpos or exists(scale), 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings' + + device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] + + seq = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset) + + freqs = self.forward(seq, seq_len = seq_len, offset = offset) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + + return apply_rotary_emb(freqs, t, scale = default(scale, 1.), seq_dim = seq_dim) + + def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0): + dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim) + + q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] + assert q_len <= k_len + + q_scale = k_scale = 1. + + if self.use_xpos: + seq = self.get_seq_pos(k_len, dtype = dtype, device = device) + + q_scale = self.get_scale(seq[-q_len:]).type(dtype) + k_scale = self.get_scale(seq).type(dtype) + + rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, scale = q_scale, offset = k_len - q_len + offset) + rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim, scale = k_scale ** -1) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def rotate_queries_and_keys(self, q, k, seq_dim = None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert self.use_xpos + device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] + + seq = self.get_seq_pos(seq_len, dtype = dtype, device = device) + + freqs = self.forward(seq, seq_len = seq_len) + scale = self.get_scale(seq, seq_len = seq_len).to(dtype) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + scale = rearrange(scale, 'n d -> n 1 d') + + rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim) + rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def get_scale( + self, + t: Tensor, + seq_len: int | None = None, + offset = 0 + ): + assert self.use_xpos + + should_cache = ( + self.cache_if_possible and + exists(seq_len) + ) + + if ( + should_cache and \ + exists(self.cached_scales) and \ + (seq_len + offset) <= self.cached_scales.shape[0] + ): + return self.cached_scales[offset:(offset + seq_len)] + + scale = 1. 
+ if self.use_xpos: + power = (t - len(t) // 2) / self.scale_base + scale = self.scale ** rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim = -1) + + if should_cache: + self.tmp_store('cached_scales', scale) + + return scale + + def get_axial_freqs(self, *dims): + Colon = slice(None) + all_freqs = [] + + for ind, dim in enumerate(dims): + if self.freqs_for == 'pixel': + pos = torch.linspace(-1, 1, steps = dim, device = self.device) + else: + pos = torch.arange(dim, device = self.device) + + freqs = self.forward(pos, seq_len = dim) + + all_axis = [None] * len(dims) + all_axis[ind] = Colon + + new_axis_slice = (Ellipsis, *all_axis, Colon) + all_freqs.append(freqs[new_axis_slice]) + + all_freqs = broadcast_tensors(*all_freqs) + return torch.cat(all_freqs, dim = -1) + + @autocast(enabled = False) + def forward( + self, + t: Tensor, + seq_len = None, + offset = 0 + ): + should_cache = ( + self.cache_if_possible and \ + not self.learned_freq and \ + exists(seq_len) and \ + self.freqs_for != 'pixel' + ) + + if ( + should_cache and \ + exists(self.cached_freqs) and \ + (offset + seq_len) <= self.cached_freqs.shape[0] + ): + return self.cached_freqs[offset:(offset + seq_len)].detach() + + freqs = self.freqs + + freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs) + freqs = repeat(freqs, '... n -> ... (n r)', r = 2) + + if should_cache: + self.tmp_store('cached_freqs', freqs.detach()) + + return freqs diff --git a/lib/simplejson/__init__.py b/lib/simplejson/__init__.py new file mode 100644 index 0000000..807b9da --- /dev/null +++ b/lib/simplejson/__init__.py @@ -0,0 +1,562 @@ +r"""JSON (JavaScript Object Notation) is a subset of +JavaScript syntax (ECMA-262 3rd edition) used as a lightweight data +interchange format. + +:mod:`simplejson` exposes an API familiar to users of the standard library +:mod:`marshal` and :mod:`pickle` modules. It is the externally maintained +version of the :mod:`json` library contained in Python 2.6, but maintains +compatibility back to Python 2.5 and (currently) has significant performance +advantages, even without using the optional C extension for speedups. + +Encoding basic Python object hierarchies:: + + >>> import simplejson as json + >>> json.dumps(['foo', {'bar': ('baz', None, 1.0, 2)}]) + '["foo", {"bar": ["baz", null, 1.0, 2]}]' + >>> print(json.dumps("\"foo\bar")) + "\"foo\bar" + >>> print(json.dumps(u'\u1234')) + "\u1234" + >>> print(json.dumps('\\')) + "\\" + >>> print(json.dumps({"c": 0, "b": 0, "a": 0}, sort_keys=True)) + {"a": 0, "b": 0, "c": 0} + >>> from simplejson.compat import StringIO + >>> io = StringIO() + >>> json.dump(['streaming API'], io) + >>> io.getvalue() + '["streaming API"]' + +Compact encoding:: + + >>> import simplejson as json + >>> obj = [1,2,3,{'4': 5, '6': 7}] + >>> json.dumps(obj, separators=(',',':'), sort_keys=True) + '[1,2,3,{"4":5,"6":7}]' + +Pretty printing:: + + >>> import simplejson as json + >>> print(json.dumps({'4': 5, '6': 7}, sort_keys=True, indent=' ')) + { + "4": 5, + "6": 7 + } + +Decoding JSON:: + + >>> import simplejson as json + >>> obj = [u'foo', {u'bar': [u'baz', None, 1.0, 2]}] + >>> json.loads('["foo", {"bar":["baz", null, 1.0, 2]}]') == obj + True + >>> json.loads('"\\"foo\\bar"') == u'"foo\x08ar' + True + >>> from simplejson.compat import StringIO + >>> io = StringIO('["streaming API"]') + >>> json.load(io)[0] == 'streaming API' + True + +Specializing JSON object decoding:: + + >>> import simplejson as json + >>> def as_complex(dct): + ... if '__complex__' in dct: + ... 
return complex(dct['real'], dct['imag']) + ... return dct + ... + >>> json.loads('{"__complex__": true, "real": 1, "imag": 2}', + ... object_hook=as_complex) + (1+2j) + >>> from decimal import Decimal + >>> json.loads('1.1', parse_float=Decimal) == Decimal('1.1') + True + +Specializing JSON object encoding:: + + >>> import simplejson as json + >>> def encode_complex(obj): + ... if isinstance(obj, complex): + ... return [obj.real, obj.imag] + ... raise TypeError('Object of type %s is not JSON serializable' % + ... obj.__class__.__name__) + ... + >>> json.dumps(2 + 1j, default=encode_complex) + '[2.0, 1.0]' + >>> json.JSONEncoder(default=encode_complex).encode(2 + 1j) + '[2.0, 1.0]' + >>> ''.join(json.JSONEncoder(default=encode_complex).iterencode(2 + 1j)) + '[2.0, 1.0]' + +Using simplejson.tool from the shell to validate and pretty-print:: + + $ echo '{"json":"obj"}' | python -m simplejson.tool + { + "json": "obj" + } + $ echo '{ 1.2:3.4}' | python -m simplejson.tool + Expecting property name: line 1 column 3 (char 2) + +Parsing multiple documents serialized as JSON lines (newline-delimited JSON):: + + >>> import simplejson as json + >>> def loads_lines(docs): + ... for doc in docs.splitlines(): + ... yield json.loads(doc) + ... + >>> sum(doc["count"] for doc in loads_lines('{"count":1}\n{"count":2}\n{"count":3}\n')) + 6 + +Serializing multiple objects to JSON lines (newline-delimited JSON):: + + >>> import simplejson as json + >>> def dumps_lines(objs): + ... for obj in objs: + ... yield json.dumps(obj, separators=(',',':')) + '\n' + ... + >>> ''.join(dumps_lines([{'count': 1}, {'count': 2}, {'count': 3}])) + '{"count":1}\n{"count":2}\n{"count":3}\n' + +""" +from __future__ import absolute_import +__version__ = '3.19.2' +__all__ = [ + 'dump', 'dumps', 'load', 'loads', + 'JSONDecoder', 'JSONDecodeError', 'JSONEncoder', + 'OrderedDict', 'simple_first', 'RawJSON' +] + +__author__ = 'Bob Ippolito ' + +from decimal import Decimal + +from .errors import JSONDecodeError +from .raw_json import RawJSON +from .decoder import JSONDecoder +from .encoder import JSONEncoder, JSONEncoderForHTML +def _import_OrderedDict(): + import collections + try: + return collections.OrderedDict + except AttributeError: + from . import ordered_dict + return ordered_dict.OrderedDict +OrderedDict = _import_OrderedDict() + +def _import_c_make_encoder(): + try: + from ._speedups import make_encoder + return make_encoder + except ImportError: + return None + +_default_encoder = JSONEncoder() + +def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True, + allow_nan=False, cls=None, indent=None, separators=None, + encoding='utf-8', default=None, use_decimal=True, + namedtuple_as_object=True, tuple_as_array=True, + bigint_as_string=False, sort_keys=False, item_sort_key=None, + for_json=False, ignore_nan=False, int_as_string_bitcount=None, + iterable_as_array=False, **kw): + """Serialize ``obj`` as a JSON formatted stream to ``fp`` (a + ``.write()``-supporting file-like object). + + If *skipkeys* is true then ``dict`` keys that are not basic types + (``str``, ``int``, ``long``, ``float``, ``bool``, ``None``) + will be skipped instead of raising a ``TypeError``. + + If *ensure_ascii* is false (default: ``True``), then the output may + contain non-ASCII characters, so long as they do not need to be escaped + by JSON. When it is true, all non-ASCII characters are escaped. 
+ + If *allow_nan* is true (default: ``False``), then out of range ``float`` + values (``nan``, ``inf``, ``-inf``) will be serialized to + their JavaScript equivalents (``NaN``, ``Infinity``, ``-Infinity``) + instead of raising a ValueError. See + *ignore_nan* for ECMA-262 compliant behavior. + + If *indent* is a string, then JSON array elements and object members + will be pretty-printed with a newline followed by that string repeated + for each level of nesting. ``None`` (the default) selects the most compact + representation without any newlines. + + If specified, *separators* should be an + ``(item_separator, key_separator)`` tuple. The default is ``(', ', ': ')`` + if *indent* is ``None`` and ``(',', ': ')`` otherwise. To get the most + compact JSON representation, you should specify ``(',', ':')`` to eliminate + whitespace. + + *encoding* is the character encoding for str instances, default is UTF-8. + + *default(obj)* is a function that should return a serializable version + of obj or raise ``TypeError``. The default simply raises ``TypeError``. + + If *use_decimal* is true (default: ``True``) then decimal.Decimal + will be natively serialized to JSON with full precision. + + If *namedtuple_as_object* is true (default: ``True``), + :class:`tuple` subclasses with ``_asdict()`` methods will be encoded + as JSON objects. + + If *tuple_as_array* is true (default: ``True``), + :class:`tuple` (and subclasses) will be encoded as JSON arrays. + + If *iterable_as_array* is true (default: ``False``), + any object not in the above table that implements ``__iter__()`` + will be encoded as a JSON array. + + If *bigint_as_string* is true (default: ``False``), ints 2**53 and higher + or lower than -2**53 will be encoded as strings. This is to avoid the + rounding that happens in Javascript otherwise. Note that this is still a + lossy operation that will not round-trip correctly and should be used + sparingly. + + If *int_as_string_bitcount* is a positive number (n), then int of size + greater than or equal to 2**n or lower than or equal to -2**n will be + encoded as strings. + + If specified, *item_sort_key* is a callable used to sort the items in + each dictionary. This is useful if you want to sort items other than + in alphabetical order by key. This option takes precedence over + *sort_keys*. + + If *sort_keys* is true (default: ``False``), the output of dictionaries + will be sorted by item. + + If *for_json* is true (default: ``False``), objects with a ``for_json()`` + method will use the return value of that method for encoding as JSON + instead of the object. + + If *ignore_nan* is true (default: ``False``), then out of range + :class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized as + ``null`` in compliance with the ECMA-262 specification. If true, this will + override *allow_nan*. + + To use a custom ``JSONEncoder`` subclass (e.g. one that overrides the + ``.default()`` method to serialize additional types), specify it with + the ``cls`` kwarg. NOTE: You should use *default* or *for_json* instead + of subclassing whenever possible. 
+ + """ + # cached encoder + if (not skipkeys and ensure_ascii and + check_circular and not allow_nan and + cls is None and indent is None and separators is None and + encoding == 'utf-8' and default is None and use_decimal + and namedtuple_as_object and tuple_as_array and not iterable_as_array + and not bigint_as_string and not sort_keys + and not item_sort_key and not for_json + and not ignore_nan and int_as_string_bitcount is None + and not kw + ): + iterable = _default_encoder.iterencode(obj) + else: + if cls is None: + cls = JSONEncoder + iterable = cls(skipkeys=skipkeys, ensure_ascii=ensure_ascii, + check_circular=check_circular, allow_nan=allow_nan, indent=indent, + separators=separators, encoding=encoding, + default=default, use_decimal=use_decimal, + namedtuple_as_object=namedtuple_as_object, + tuple_as_array=tuple_as_array, + iterable_as_array=iterable_as_array, + bigint_as_string=bigint_as_string, + sort_keys=sort_keys, + item_sort_key=item_sort_key, + for_json=for_json, + ignore_nan=ignore_nan, + int_as_string_bitcount=int_as_string_bitcount, + **kw).iterencode(obj) + # could accelerate with writelines in some versions of Python, at + # a debuggability cost + for chunk in iterable: + fp.write(chunk) + + +def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True, + allow_nan=False, cls=None, indent=None, separators=None, + encoding='utf-8', default=None, use_decimal=True, + namedtuple_as_object=True, tuple_as_array=True, + bigint_as_string=False, sort_keys=False, item_sort_key=None, + for_json=False, ignore_nan=False, int_as_string_bitcount=None, + iterable_as_array=False, **kw): + """Serialize ``obj`` to a JSON formatted ``str``. + + If ``skipkeys`` is true then ``dict`` keys that are not basic types + (``str``, ``int``, ``long``, ``float``, ``bool``, ``None``) + will be skipped instead of raising a ``TypeError``. + + If *ensure_ascii* is false (default: ``True``), then the output may + contain non-ASCII characters, so long as they do not need to be escaped + by JSON. When it is true, all non-ASCII characters are escaped. + + If ``check_circular`` is false, then the circular reference check + for container types will be skipped and a circular reference will + result in an ``OverflowError`` (or worse). + + If *allow_nan* is true (default: ``False``), then out of range ``float`` + values (``nan``, ``inf``, ``-inf``) will be serialized to + their JavaScript equivalents (``NaN``, ``Infinity``, ``-Infinity``) + instead of raising a ValueError. See + *ignore_nan* for ECMA-262 compliant behavior. + + If ``indent`` is a string, then JSON array elements and object members + will be pretty-printed with a newline followed by that string repeated + for each level of nesting. ``None`` (the default) selects the most compact + representation without any newlines. For backwards compatibility with + versions of simplejson earlier than 2.1.0, an integer is also accepted + and is converted to a string with that many spaces. + + If specified, ``separators`` should be an + ``(item_separator, key_separator)`` tuple. The default is ``(', ', ': ')`` + if *indent* is ``None`` and ``(',', ': ')`` otherwise. To get the most + compact JSON representation, you should specify ``(',', ':')`` to eliminate + whitespace. + + ``encoding`` is the character encoding for bytes instances, default is + UTF-8. + + ``default(obj)`` is a function that should return a serializable version + of obj or raise TypeError. The default simply raises TypeError. 
+ + If *use_decimal* is true (default: ``True``) then decimal.Decimal + will be natively serialized to JSON with full precision. + + If *namedtuple_as_object* is true (default: ``True``), + :class:`tuple` subclasses with ``_asdict()`` methods will be encoded + as JSON objects. + + If *tuple_as_array* is true (default: ``True``), + :class:`tuple` (and subclasses) will be encoded as JSON arrays. + + If *iterable_as_array* is true (default: ``False``), + any object not in the above table that implements ``__iter__()`` + will be encoded as a JSON array. + + If *bigint_as_string* is true (not the default), ints 2**53 and higher + or lower than -2**53 will be encoded as strings. This is to avoid the + rounding that happens in Javascript otherwise. + + If *int_as_string_bitcount* is a positive number (n), then int of size + greater than or equal to 2**n or lower than or equal to -2**n will be + encoded as strings. + + If specified, *item_sort_key* is a callable used to sort the items in + each dictionary. This is useful if you want to sort items other than + in alphabetical order by key. This option takes precedence over + *sort_keys*. + + If *sort_keys* is true (default: ``False``), the output of dictionaries + will be sorted by item. + + If *for_json* is true (default: ``False``), objects with a ``for_json()`` + method will use the return value of that method for encoding as JSON + instead of the object. + + If *ignore_nan* is true (default: ``False``), then out of range + :class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized as + ``null`` in compliance with the ECMA-262 specification. If true, this will + override *allow_nan*. + + To use a custom ``JSONEncoder`` subclass (e.g. one that overrides the + ``.default()`` method to serialize additional types), specify it with + the ``cls`` kwarg. NOTE: You should use *default* instead of subclassing + whenever possible. + + """ + # cached encoder + if (not skipkeys and ensure_ascii and + check_circular and not allow_nan and + cls is None and indent is None and separators is None and + encoding == 'utf-8' and default is None and use_decimal + and namedtuple_as_object and tuple_as_array and not iterable_as_array + and not bigint_as_string and not sort_keys + and not item_sort_key and not for_json + and not ignore_nan and int_as_string_bitcount is None + and not kw + ): + return _default_encoder.encode(obj) + if cls is None: + cls = JSONEncoder + return cls( + skipkeys=skipkeys, ensure_ascii=ensure_ascii, + check_circular=check_circular, allow_nan=allow_nan, indent=indent, + separators=separators, encoding=encoding, default=default, + use_decimal=use_decimal, + namedtuple_as_object=namedtuple_as_object, + tuple_as_array=tuple_as_array, + iterable_as_array=iterable_as_array, + bigint_as_string=bigint_as_string, + sort_keys=sort_keys, + item_sort_key=item_sort_key, + for_json=for_json, + ignore_nan=ignore_nan, + int_as_string_bitcount=int_as_string_bitcount, + **kw).encode(obj) + + +_default_decoder = JSONDecoder() + + +def load(fp, encoding=None, cls=None, object_hook=None, parse_float=None, + parse_int=None, parse_constant=None, object_pairs_hook=None, + use_decimal=False, allow_nan=False, **kw): + """Deserialize ``fp`` (a ``.read()``-supporting file-like object containing + a JSON document as `str` or `bytes`) to a Python object. + + *encoding* determines the encoding used to interpret any + `bytes` objects decoded by this instance (``'utf-8'`` by + default). It has no effect when decoding `str` objects. 
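As a sketch of the *use_decimal* parity between ``dumps`` and ``loads`` mentioned above (the literal is arbitrary):

    import simplejson as json
    from decimal import Decimal

    s = json.dumps(Decimal('3.14159265358979323846'))  # use_decimal=True by default
    json.loads(s, use_decimal=True)
    # Decimal('3.14159265358979323846') -- full precision, no float rounding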
+ + *object_hook*, if specified, will be called with the result of every + JSON object decoded and its return value will be used in place of the + given :class:`dict`. This can be used to provide custom + deserializations (e.g. to support JSON-RPC class hinting). + + *object_pairs_hook* is an optional function that will be called with + the result of any object literal decode with an ordered list of pairs. + The return value of *object_pairs_hook* will be used instead of the + :class:`dict`. This feature can be used to implement custom decoders + that rely on the order that the key and value pairs are decoded (for + example, :func:`collections.OrderedDict` will remember the order of + insertion). If *object_hook* is also defined, the *object_pairs_hook* + takes priority. + + *parse_float*, if specified, will be called with the string of every + JSON float to be decoded. By default, this is equivalent to + ``float(num_str)``. This can be used to use another datatype or parser + for JSON floats (e.g. :class:`decimal.Decimal`). + + *parse_int*, if specified, will be called with the string of every + JSON int to be decoded. By default, this is equivalent to + ``int(num_str)``. This can be used to use another datatype or parser + for JSON integers (e.g. :class:`float`). + + *allow_nan*, if True (default false), will allow the parser to + accept the non-standard floats ``NaN``, ``Infinity``, and ``-Infinity`` + and enable the use of the deprecated *parse_constant*. + + If *use_decimal* is true (default: ``False``) then it implies + parse_float=decimal.Decimal for parity with ``dump``. + + *parse_constant*, if specified, will be + called with one of the following strings: ``'-Infinity'``, + ``'Infinity'``, ``'NaN'``. It is not recommended to use this feature, + as it is rare to parse non-compliant JSON containing these values. + + To use a custom ``JSONDecoder`` subclass, specify it with the ``cls`` + kwarg. NOTE: You should use *object_hook* or *object_pairs_hook* instead + of subclassing whenever possible. + + """ + return loads(fp.read(), + encoding=encoding, cls=cls, object_hook=object_hook, + parse_float=parse_float, parse_int=parse_int, + parse_constant=parse_constant, object_pairs_hook=object_pairs_hook, + use_decimal=use_decimal, allow_nan=allow_nan, **kw) + + +def loads(s, encoding=None, cls=None, object_hook=None, parse_float=None, + parse_int=None, parse_constant=None, object_pairs_hook=None, + use_decimal=False, allow_nan=False, **kw): + """Deserialize ``s`` (a ``str`` or ``unicode`` instance containing a JSON + document) to a Python object. + + *encoding* determines the encoding used to interpret any + :class:`bytes` objects decoded by this instance (``'utf-8'`` by + default). It has no effect when decoding :class:`unicode` objects. + + *object_hook*, if specified, will be called with the result of every + JSON object decoded and its return value will be used in place of the + given :class:`dict`. This can be used to provide custom + deserializations (e.g. to support JSON-RPC class hinting). + + *object_pairs_hook* is an optional function that will be called with + the result of any object literal decode with an ordered list of pairs. + The return value of *object_pairs_hook* will be used instead of the + :class:`dict`. This feature can be used to implement custom decoders + that rely on the order that the key and value pairs are decoded (for + example, :func:`collections.OrderedDict` will remember the order of + insertion). 
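A minimal sketch of *object_pairs_hook*, using :class:`collections.OrderedDict` as in the example above (input invented for illustration):

    import simplejson as json
    from collections import OrderedDict

    json.loads('{"b": 1, "a": 2}', object_pairs_hook=OrderedDict)
    # OrderedDict([('b', 1), ('a', 2)]) -- pairs arrive in document order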
If *object_hook* is also defined, the *object_pairs_hook* + takes priority. + + *parse_float*, if specified, will be called with the string of every + JSON float to be decoded. By default, this is equivalent to + ``float(num_str)``. This can be used to use another datatype or parser + for JSON floats (e.g. :class:`decimal.Decimal`). + + *parse_int*, if specified, will be called with the string of every + JSON int to be decoded. By default, this is equivalent to + ``int(num_str)``. This can be used to use another datatype or parser + for JSON integers (e.g. :class:`float`). + + *allow_nan*, if True (default false), will allow the parser to + accept the non-standard floats ``NaN``, ``Infinity``, and ``-Infinity`` + and enable the use of the deprecated *parse_constant*. + + If *use_decimal* is true (default: ``False``) then it implies + parse_float=decimal.Decimal for parity with ``dump``. + + *parse_constant*, if specified, will be + called with one of the following strings: ``'-Infinity'``, + ``'Infinity'``, ``'NaN'``. It is not recommended to use this feature, + as it is rare to parse non-compliant JSON containing these values. + + To use a custom ``JSONDecoder`` subclass, specify it with the ``cls`` + kwarg. NOTE: You should use *object_hook* or *object_pairs_hook* instead + of subclassing whenever possible. + + """ + if (cls is None and encoding is None and object_hook is None and + parse_int is None and parse_float is None and + parse_constant is None and object_pairs_hook is None + and not use_decimal and not allow_nan and not kw): + return _default_decoder.decode(s) + if cls is None: + cls = JSONDecoder + if object_hook is not None: + kw['object_hook'] = object_hook + if object_pairs_hook is not None: + kw['object_pairs_hook'] = object_pairs_hook + if parse_float is not None: + kw['parse_float'] = parse_float + if parse_int is not None: + kw['parse_int'] = parse_int + if parse_constant is not None: + kw['parse_constant'] = parse_constant + if use_decimal: + if parse_float is not None: + raise TypeError("use_decimal=True implies parse_float=Decimal") + kw['parse_float'] = Decimal + if allow_nan: + kw['allow_nan'] = True + return cls(encoding=encoding, **kw).decode(s) + + +def _toggle_speedups(enabled): + from . import decoder as dec + from . import encoder as enc + from . import scanner as scan + c_make_encoder = _import_c_make_encoder() + if enabled: + dec.scanstring = dec.c_scanstring or dec.py_scanstring + enc.c_make_encoder = c_make_encoder + enc.encode_basestring_ascii = (enc.c_encode_basestring_ascii or + enc.py_encode_basestring_ascii) + scan.make_scanner = scan.c_make_scanner or scan.py_make_scanner + else: + dec.scanstring = dec.py_scanstring + enc.c_make_encoder = None + enc.encode_basestring_ascii = enc.py_encode_basestring_ascii + scan.make_scanner = scan.py_make_scanner + dec.make_scanner = scan.make_scanner + global _default_decoder + _default_decoder = JSONDecoder() + global _default_encoder + _default_encoder = JSONEncoder() + +def simple_first(kv): + """Helper function to pass to item_sort_key to sort simple + elements to the top, then container elements. 
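A sketch of ``simple_first`` in use (data invented for the example):

    import simplejson as json

    json.dumps({'zeta': [3, 4], 'alpha': 1}, item_sort_key=json.simple_first)
    # '{"alpha": 1, "zeta": [3, 4]}' -- scalar values sort ahead of containers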
+ """ + return (isinstance(kv[1], (list, dict, tuple)), kv[0]) diff --git a/lib/simplejson/_speedups.cp39-win_amd64.pyd b/lib/simplejson/_speedups.cp39-win_amd64.pyd new file mode 100644 index 0000000..f6e8451 Binary files /dev/null and b/lib/simplejson/_speedups.cp39-win_amd64.pyd differ diff --git a/lib/simplejson/compat.py b/lib/simplejson/compat.py new file mode 100644 index 0000000..7d7d76a --- /dev/null +++ b/lib/simplejson/compat.py @@ -0,0 +1,34 @@ +"""Python 3 compatibility shims +""" +import sys +if sys.version_info[0] < 3: + PY3 = False + def b(s): + return s + try: + from cStringIO import StringIO + except ImportError: + from StringIO import StringIO + BytesIO = StringIO + text_type = unicode + binary_type = str + string_types = (basestring,) + integer_types = (int, long) + unichr = unichr + reload_module = reload +else: + PY3 = True + if sys.version_info[:2] >= (3, 4): + from importlib import reload as reload_module + else: + from imp import reload as reload_module + def b(s): + return bytes(s, 'latin1') + from io import StringIO, BytesIO + text_type = str + binary_type = bytes + string_types = (str,) + integer_types = (int,) + unichr = chr + +long_type = integer_types[-1] diff --git a/lib/simplejson/decoder.py b/lib/simplejson/decoder.py new file mode 100644 index 0000000..7ed5ea8 --- /dev/null +++ b/lib/simplejson/decoder.py @@ -0,0 +1,416 @@ +"""Implementation of JSONDecoder +""" +from __future__ import absolute_import +import re +import sys +import struct +from .compat import PY3, unichr +from .scanner import make_scanner, JSONDecodeError + +def _import_c_scanstring(): + try: + from ._speedups import scanstring + return scanstring + except ImportError: + return None +c_scanstring = _import_c_scanstring() + +# NOTE (3.1.0): JSONDecodeError may still be imported from this module for +# compatibility, but it was never in the __all__ +__all__ = ['JSONDecoder'] + +FLAGS = re.VERBOSE | re.MULTILINE | re.DOTALL + +def _floatconstants(): + if sys.version_info < (2, 6): + _BYTES = '7FF80000000000007FF0000000000000'.decode('hex') + nan, inf = struct.unpack('>dd', _BYTES) + else: + nan = float('nan') + inf = float('inf') + return nan, inf, -inf + +NaN, PosInf, NegInf = _floatconstants() + +_CONSTANTS = { + '-Infinity': NegInf, + 'Infinity': PosInf, + 'NaN': NaN, +} + +STRINGCHUNK = re.compile(r'(.*?)(["\\\x00-\x1f])', FLAGS) +BACKSLASH = { + '"': u'"', '\\': u'\\', '/': u'/', + 'b': u'\b', 'f': u'\f', 'n': u'\n', 'r': u'\r', 't': u'\t', +} + +DEFAULT_ENCODING = "utf-8" + +if hasattr(sys, 'get_int_max_str_digits'): + bounded_int = int +else: + def bounded_int(s, INT_MAX_STR_DIGITS=4300): + """Backport of the integer string length conversion limitation + + https://docs.python.org/3/library/stdtypes.html#int-max-str-digits + """ + if len(s) > INT_MAX_STR_DIGITS: + raise ValueError("Exceeds the limit (%s) for integer string conversion: value has %s digits" % (INT_MAX_STR_DIGITS, len(s))) + return int(s) + + +def scan_four_digit_hex(s, end, _m=re.compile(r'^[0-9a-fA-F]{4}$').match): + """Scan a four digit hex number from s[end:end + 4] + """ + msg = "Invalid \\uXXXX escape sequence" + esc = s[end:end + 4] + if not _m(esc): + raise JSONDecodeError(msg, s, end - 2) + try: + return int(esc, 16), end + 4 + except ValueError: + raise JSONDecodeError(msg, s, end - 2) + +def py_scanstring(s, end, encoding=None, strict=True, + _b=BACKSLASH, _m=STRINGCHUNK.match, _join=u''.join, + _PY3=PY3, _maxunicode=sys.maxunicode, + _scan_four_digit_hex=scan_four_digit_hex): + """Scan the string s for a JSON 
string. End is the index of the + character in s after the quote that started the JSON string. + Unescapes all valid JSON string escape sequences and raises ValueError + on attempt to decode an invalid string. If strict is False then literal + control characters are allowed in the string. + + Returns a tuple of the decoded string and the index of the character in s + after the end quote.""" + if encoding is None: + encoding = DEFAULT_ENCODING + chunks = [] + _append = chunks.append + begin = end - 1 + while 1: + chunk = _m(s, end) + if chunk is None: + raise JSONDecodeError( + "Unterminated string starting at", s, begin) + prev_end = end + end = chunk.end() + content, terminator = chunk.groups() + # Content contains zero or more unescaped string characters + if content: + if not _PY3 and not isinstance(content, unicode): + content = unicode(content, encoding) + _append(content) + # Terminator is the end of string, a literal control character, + # or a backslash denoting that an escape sequence follows + if terminator == '"': + break + elif terminator != '\\': + if strict: + msg = "Invalid control character %r at" + raise JSONDecodeError(msg, s, prev_end) + else: + _append(terminator) + continue + try: + esc = s[end] + except IndexError: + raise JSONDecodeError( + "Unterminated string starting at", s, begin) + # If not a unicode escape sequence, must be in the lookup table + if esc != 'u': + try: + char = _b[esc] + except KeyError: + msg = "Invalid \\X escape sequence %r" + raise JSONDecodeError(msg, s, end) + end += 1 + else: + # Unicode escape sequence + uni, end = _scan_four_digit_hex(s, end + 1) + # Check for surrogate pair on UCS-4 systems + # Note that this will join high/low surrogate pairs + # but will also pass unpaired surrogates through + if (_maxunicode > 65535 and + uni & 0xfc00 == 0xd800 and + s[end:end + 2] == '\\u'): + uni2, end2 = _scan_four_digit_hex(s, end + 2) + if uni2 & 0xfc00 == 0xdc00: + uni = 0x10000 + (((uni - 0xd800) << 10) | + (uni2 - 0xdc00)) + end = end2 + char = unichr(uni) + # Append the unescaped character + _append(char) + return _join(chunks), end + + +# Use speedup if available +scanstring = c_scanstring or py_scanstring + +WHITESPACE = re.compile(r'[ \t\n\r]*', FLAGS) +WHITESPACE_STR = ' \t\n\r' + +def JSONObject(state, encoding, strict, scan_once, object_hook, + object_pairs_hook, memo=None, + _w=WHITESPACE.match, _ws=WHITESPACE_STR): + (s, end) = state + # Backwards compatibility + if memo is None: + memo = {} + memo_get = memo.setdefault + pairs = [] + # Use a slice to prevent IndexError from being raised, the following + # check will raise a more specific ValueError if the string is empty + nextchar = s[end:end + 1] + # Normally we expect nextchar == '"' + if nextchar != '"': + if nextchar in _ws: + end = _w(s, end).end() + nextchar = s[end:end + 1] + # Trivial empty object + if nextchar == '}': + if object_pairs_hook is not None: + result = object_pairs_hook(pairs) + return result, end + 1 + pairs = {} + if object_hook is not None: + pairs = object_hook(pairs) + return pairs, end + 1 + elif nextchar != '"': + raise JSONDecodeError( + "Expecting property name enclosed in double quotes or '}'", + s, end) + end += 1 + while True: + key, end = scanstring(s, end, encoding, strict) + key = memo_get(key, key) + + # To skip some function call overhead we optimize the fast paths where + # the JSON key separator is ": " or just ":".
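To make the contract of ``py_scanstring`` above concrete, a small sketch (it is an internal helper, shown here for illustration only; the input string is invented):

    from simplejson.decoder import py_scanstring

    # end=1 is the index just past the opening quote at index 0
    py_scanstring('"he\\nllo" tail', 1)
    # ('he\nllo', 9) -- decoded text plus the index one past the closing quote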
+ if s[end:end + 1] != ':': + end = _w(s, end).end() + if s[end:end + 1] != ':': + raise JSONDecodeError("Expecting ':' delimiter", s, end) + + end += 1 + + try: + if s[end] in _ws: + end += 1 + if s[end] in _ws: + end = _w(s, end + 1).end() + except IndexError: + pass + + value, end = scan_once(s, end) + pairs.append((key, value)) + + try: + nextchar = s[end] + if nextchar in _ws: + end = _w(s, end + 1).end() + nextchar = s[end] + except IndexError: + nextchar = '' + end += 1 + + if nextchar == '}': + break + elif nextchar != ',': + raise JSONDecodeError("Expecting ',' delimiter or '}'", s, end - 1) + + try: + nextchar = s[end] + if nextchar in _ws: + end += 1 + nextchar = s[end] + if nextchar in _ws: + end = _w(s, end + 1).end() + nextchar = s[end] + except IndexError: + nextchar = '' + + end += 1 + if nextchar != '"': + raise JSONDecodeError( + "Expecting property name enclosed in double quotes", + s, end - 1) + + if object_pairs_hook is not None: + result = object_pairs_hook(pairs) + return result, end + pairs = dict(pairs) + if object_hook is not None: + pairs = object_hook(pairs) + return pairs, end + +def JSONArray(state, scan_once, _w=WHITESPACE.match, _ws=WHITESPACE_STR): + (s, end) = state + values = [] + nextchar = s[end:end + 1] + if nextchar in _ws: + end = _w(s, end + 1).end() + nextchar = s[end:end + 1] + # Look-ahead for trivial empty array + if nextchar == ']': + return values, end + 1 + elif nextchar == '': + raise JSONDecodeError("Expecting value or ']'", s, end) + _append = values.append + while True: + value, end = scan_once(s, end) + _append(value) + nextchar = s[end:end + 1] + if nextchar in _ws: + end = _w(s, end + 1).end() + nextchar = s[end:end + 1] + end += 1 + if nextchar == ']': + break + elif nextchar != ',': + raise JSONDecodeError("Expecting ',' delimiter or ']'", s, end - 1) + + try: + if s[end] in _ws: + end += 1 + if s[end] in _ws: + end = _w(s, end + 1).end() + except IndexError: + pass + + return values, end + +class JSONDecoder(object): + """Simple JSON decoder + + Performs the following translations in decoding by default: + + +---------------+-------------------+ + | JSON | Python | + +===============+===================+ + | object | dict | + +---------------+-------------------+ + | array | list | + +---------------+-------------------+ + | string | str, unicode | + +---------------+-------------------+ + | number (int) | int, long | + +---------------+-------------------+ + | number (real) | float | + +---------------+-------------------+ + | true | True | + +---------------+-------------------+ + | false | False | + +---------------+-------------------+ + | null | None | + +---------------+-------------------+ + + When allow_nan=True, it also understands + ``NaN``, ``Infinity``, and ``-Infinity`` as + their corresponding ``float`` values, which is outside the JSON spec. + + """ + + def __init__(self, encoding=None, object_hook=None, parse_float=None, + parse_int=None, parse_constant=None, strict=True, + object_pairs_hook=None, allow_nan=False): + """ + *encoding* determines the encoding used to interpret any + :class:`str` objects decoded by this instance (``'utf-8'`` by + default). It has no effect when decoding :class:`unicode` objects. + + Note that currently only encodings that are a superset of ASCII work, + strings of other encodings should be passed in as :class:`unicode`. + + *object_hook*, if specified, will be called with the result of every + JSON object decoded and its return value will be used in place of the + given :class:`dict`. 
This can be used to provide custom + deserializations (e.g. to support JSON-RPC class hinting). + + *object_pairs_hook* is an optional function that will be called with + the result of any object literal decode with an ordered list of pairs. + The return value of *object_pairs_hook* will be used instead of the + :class:`dict`. This feature can be used to implement custom decoders + that rely on the order that the key and value pairs are decoded (for + example, :func:`collections.OrderedDict` will remember the order of + insertion). If *object_hook* is also defined, the *object_pairs_hook* + takes priority. + + *parse_float*, if specified, will be called with the string of every + JSON float to be decoded. By default, this is equivalent to + ``float(num_str)``. This can be used to use another datatype or parser + for JSON floats (e.g. :class:`decimal.Decimal`). + + *parse_int*, if specified, will be called with the string of every + JSON int to be decoded. By default, this is equivalent to + ``int(num_str)``. This can be used to use another datatype or parser + for JSON integers (e.g. :class:`float`). + + *allow_nan*, if True (default false), will allow the parser to + accept the non-standard floats ``NaN``, ``Infinity``, and ``-Infinity``. + + *parse_constant*, if specified, will be + called with one of the following strings: ``'-Infinity'``, + ``'Infinity'``, ``'NaN'``. It is not recommended to use this feature, + as it is rare to parse non-compliant JSON containing these values. + + *strict* controls the parser's behavior when it encounters an + invalid control character in a string. The default setting of + ``True`` means that unescaped control characters are parse errors, if + ``False`` then control characters will be allowed in strings. + + """ + if encoding is None: + encoding = DEFAULT_ENCODING + self.encoding = encoding + self.object_hook = object_hook + self.object_pairs_hook = object_pairs_hook + self.parse_float = parse_float or float + self.parse_int = parse_int or bounded_int + self.parse_constant = parse_constant or (allow_nan and _CONSTANTS.__getitem__ or None) + self.strict = strict + self.parse_object = JSONObject + self.parse_array = JSONArray + self.parse_string = scanstring + self.memo = {} + self.scan_once = make_scanner(self) + + def decode(self, s, _w=WHITESPACE.match, _PY3=PY3): + """Return the Python representation of ``s`` (a ``str`` or ``unicode`` + instance containing a JSON document) + + """ + if _PY3 and isinstance(s, bytes): + s = str(s, self.encoding) + obj, end = self.raw_decode(s) + end = _w(s, end).end() + if end != len(s): + raise JSONDecodeError("Extra data", s, end, len(s)) + return obj + + def raw_decode(self, s, idx=0, _w=WHITESPACE.match, _PY3=PY3): + """Decode a JSON document from ``s`` (a ``str`` or ``unicode`` + beginning with a JSON document) and return a 2-tuple of the Python + representation and the index in ``s`` where the document ended. + Optionally, ``idx`` can be used to specify an offset in ``s`` where + the JSON document begins. + + This can be used to decode a JSON document from a string that may + have extraneous data at the end. + + """ + if idx < 0: + # Ensure that raw_decode bails on negative indexes, the regex + # would otherwise mask this behavior. 
#98 + raise JSONDecodeError('Expecting value', s, idx) + if _PY3 and not isinstance(s, str): + raise TypeError("Input string must be text, not bytes") + # strip UTF-8 bom + if len(s) > idx: + ord0 = ord(s[idx]) + if ord0 == 0xfeff: + idx += 1 + elif ord0 == 0xef and s[idx:idx + 3] == '\xef\xbb\xbf': + idx += 3 + return self.scan_once(s, idx=_w(s, idx).end()) diff --git a/lib/simplejson/encoder.py b/lib/simplejson/encoder.py new file mode 100644 index 0000000..ed3f281 --- /dev/null +++ b/lib/simplejson/encoder.py @@ -0,0 +1,740 @@ +"""Implementation of JSONEncoder +""" +from __future__ import absolute_import +import re +from operator import itemgetter +# Do not import Decimal directly to avoid reload issues +import decimal +from .compat import binary_type, text_type, string_types, integer_types, PY3 +def _import_speedups(): + try: + from . import _speedups + return _speedups.encode_basestring_ascii, _speedups.make_encoder + except ImportError: + return None, None +c_encode_basestring_ascii, c_make_encoder = _import_speedups() + +from .decoder import PosInf +from .raw_json import RawJSON + +ESCAPE = re.compile(r'[\x00-\x1f\\"]') +ESCAPE_ASCII = re.compile(r'([\\"]|[^\ -~])') +HAS_UTF8 = re.compile(r'[\x80-\xff]') +ESCAPE_DCT = { + '\\': '\\\\', + '"': '\\"', + '\b': '\\b', + '\f': '\\f', + '\n': '\\n', + '\r': '\\r', + '\t': '\\t', +} +for i in range(0x20): + #ESCAPE_DCT.setdefault(chr(i), '\\u{0:04x}'.format(i)) + ESCAPE_DCT.setdefault(chr(i), '\\u%04x' % (i,)) +del i + +FLOAT_REPR = repr + +def encode_basestring(s, _PY3=PY3, _q=u'"'): + """Return a JSON representation of a Python string + + """ + if _PY3: + if isinstance(s, bytes): + s = str(s, 'utf-8') + elif type(s) is not str: + # convert an str subclass instance to exact str + # raise a TypeError otherwise + s = str.__str__(s) + else: + if isinstance(s, str) and HAS_UTF8.search(s) is not None: + s = unicode(s, 'utf-8') + elif type(s) not in (str, unicode): + # convert an str subclass instance to exact str + # convert a unicode subclass instance to exact unicode + # raise a TypeError otherwise + if isinstance(s, str): + s = str.__str__(s) + else: + s = unicode.__getnewargs__(s)[0] + def replace(match): + return ESCAPE_DCT[match.group(0)] + return _q + ESCAPE.sub(replace, s) + _q + + +def py_encode_basestring_ascii(s, _PY3=PY3): + """Return an ASCII-only JSON representation of a Python string + + """ + if _PY3: + if isinstance(s, bytes): + s = str(s, 'utf-8') + elif type(s) is not str: + # convert an str subclass instance to exact str + # raise a TypeError otherwise + s = str.__str__(s) + else: + if isinstance(s, str) and HAS_UTF8.search(s) is not None: + s = unicode(s, 'utf-8') + elif type(s) not in (str, unicode): + # convert an str subclass instance to exact str + # convert a unicode subclass instance to exact unicode + # raise a TypeError otherwise + if isinstance(s, str): + s = str.__str__(s) + else: + s = unicode.__getnewargs__(s)[0] + def replace(match): + s = match.group(0) + try: + return ESCAPE_DCT[s] + except KeyError: + n = ord(s) + if n < 0x10000: + #return '\\u{0:04x}'.format(n) + return '\\u%04x' % (n,) + else: + # surrogate pair + n -= 0x10000 + s1 = 0xd800 | ((n >> 10) & 0x3ff) + s2 = 0xdc00 | (n & 0x3ff) + #return '\\u{0:04x}\\u{1:04x}'.format(s1, s2) + return '\\u%04x\\u%04x' % (s1, s2) + return '"' + str(ESCAPE_ASCII.sub(replace, s)) + '"' + + +encode_basestring_ascii = ( + c_encode_basestring_ascii or py_encode_basestring_ascii) + +class JSONEncoder(object): + """Extensible JSON encoder for Python data structures. 
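Referring back to ``raw_decode`` shown just above: it tolerates trailing data and reports where the document ended, which plain ``decode`` would reject as "Extra data". A sketch:

    import simplejson as json

    obj, end = json.JSONDecoder().raw_decode('{"a": 1} trailing garbage')
    # obj == {'a': 1}, end == 8 -- the offset where the JSON document stopped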
+ + Supports the following objects and types by default: + + +-------------------+---------------+ + | Python | JSON | + +===================+===============+ + | dict, namedtuple | object | + +-------------------+---------------+ + | list, tuple | array | + +-------------------+---------------+ + | str, unicode | string | + +-------------------+---------------+ + | int, long, float | number | + +-------------------+---------------+ + | True | true | + +-------------------+---------------+ + | False | false | + +-------------------+---------------+ + | None | null | + +-------------------+---------------+ + + To extend this to recognize other objects, subclass and implement a + ``.default()`` method with another method that returns a serializable + object for ``o`` if possible, otherwise it should call the superclass + implementation (to raise ``TypeError``). + + """ + item_separator = ', ' + key_separator = ': ' + + def __init__(self, skipkeys=False, ensure_ascii=True, + check_circular=True, allow_nan=False, sort_keys=False, + indent=None, separators=None, encoding='utf-8', default=None, + use_decimal=True, namedtuple_as_object=True, + tuple_as_array=True, bigint_as_string=False, + item_sort_key=None, for_json=False, ignore_nan=False, + int_as_string_bitcount=None, iterable_as_array=False): + """Constructor for JSONEncoder, with sensible defaults. + + If skipkeys is false, then it is a TypeError to attempt + encoding of keys that are not str, int, long, float or None. If + skipkeys is True, such items are simply skipped. + + If ensure_ascii is true, the output is guaranteed to be str + objects with all incoming unicode characters escaped. If + ensure_ascii is false, the output will be unicode object. + + If check_circular is true, then lists, dicts, and custom encoded + objects will be checked for circular references during encoding to + prevent an infinite recursion (which would cause an OverflowError). + Otherwise, no such check takes place. + + If allow_nan is true (default: False), then out of range float + values (nan, inf, -inf) will be serialized to + their JavaScript equivalents (NaN, Infinity, -Infinity) + instead of raising a ValueError. See + ignore_nan for ECMA-262 compliant behavior. + + If sort_keys is true, then the output of dictionaries will be + sorted by key; this is useful for regression tests to ensure + that JSON serializations can be compared on a day-to-day basis. + + If indent is a string, then JSON array elements and object members + will be pretty-printed with a newline followed by that string repeated + for each level of nesting. ``None`` (the default) selects the most compact + representation without any newlines. For backwards compatibility with + versions of simplejson earlier than 2.1.0, an integer is also accepted + and is converted to a string with that many spaces. + + If specified, separators should be an (item_separator, key_separator) + tuple. The default is (', ', ': ') if *indent* is ``None`` and + (',', ': ') otherwise. To get the most compact JSON representation, + you should specify (',', ':') to eliminate whitespace. + + If specified, default is a function that gets called for objects + that can't otherwise be serialized. It should return a JSON encodable + version of the object or raise a ``TypeError``. + + If encoding is not None, then all input strings will be + transformed into unicode using that encoding prior to JSON-encoding. + The default is UTF-8. 
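A sketch of constructing the encoder directly with the options described above (sample data invented):

    import simplejson as json

    enc = json.JSONEncoder(indent='  ', sort_keys=True)
    enc.encode({'b': 1, 'a': [True, None]})
    # '{\n  "a": [\n    true,\n    null\n  ],\n  "b": 1\n}'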
+ + If use_decimal is true (default: ``True``), ``decimal.Decimal`` will + be supported directly by the encoder. For the inverse, decode JSON + with ``parse_float=decimal.Decimal``. + + If namedtuple_as_object is true (the default), objects with + ``_asdict()`` methods will be encoded as JSON objects. + + If tuple_as_array is true (the default), tuple (and subclasses) will + be encoded as JSON arrays. + + If *iterable_as_array* is true (default: ``False``), + any object not in the above table that implements ``__iter__()`` + will be encoded as a JSON array. + + If bigint_as_string is true (not the default), ints 2**53 and higher + or lower than -2**53 will be encoded as strings. This is to avoid the + rounding that happens in Javascript otherwise. + + If int_as_string_bitcount is a positive number (n), then int of size + greater than or equal to 2**n or lower than or equal to -2**n will be + encoded as strings. + + If specified, item_sort_key is a callable used to sort the items in + each dictionary. This is useful if you want to sort items other than + in alphabetical order by key. + + If for_json is true (not the default), objects with a ``for_json()`` + method will use the return value of that method for encoding as JSON + instead of the object. + + If *ignore_nan* is true (default: ``False``), then out of range + :class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized + as ``null`` in compliance with the ECMA-262 specification. If true, + this will override *allow_nan*. + + """ + + self.skipkeys = skipkeys + self.ensure_ascii = ensure_ascii + self.check_circular = check_circular + self.allow_nan = allow_nan + self.sort_keys = sort_keys + self.use_decimal = use_decimal + self.namedtuple_as_object = namedtuple_as_object + self.tuple_as_array = tuple_as_array + self.iterable_as_array = iterable_as_array + self.bigint_as_string = bigint_as_string + self.item_sort_key = item_sort_key + self.for_json = for_json + self.ignore_nan = ignore_nan + self.int_as_string_bitcount = int_as_string_bitcount + if indent is not None and not isinstance(indent, string_types): + indent = indent * ' ' + self.indent = indent + if separators is not None: + self.item_separator, self.key_separator = separators + elif indent is not None: + self.item_separator = ',' + if default is not None: + self.default = default + self.encoding = encoding + + def default(self, o): + """Implement this method in a subclass such that it returns + a serializable object for ``o``, or calls the base implementation + (to raise a ``TypeError``). + + For example, to support arbitrary iterators, you could + implement default like this:: + + def default(self, o): + try: + iterable = iter(o) + except TypeError: + pass + else: + return list(iterable) + return JSONEncoder.default(self, o) + + """ + raise TypeError('Object of type %s is not JSON serializable' % + o.__class__.__name__) + + def encode(self, o): + """Return a JSON string representation of a Python data structure. + + >>> from simplejson import JSONEncoder + >>> JSONEncoder().encode({"foo": ["bar", "baz"]}) + '{"foo": ["bar", "baz"]}' + + """ + # This is for extremely simple cases and benchmarks. 
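The ``iterencode`` streaming pattern from the docstring above, as a runnable sketch (the ``io.StringIO`` sink stands in for a socket or file):

    import io
    import simplejson as json

    sink = io.StringIO()
    for chunk in json.JSONEncoder().iterencode({'big': list(range(5))}):
        sink.write(chunk)
    sink.getvalue()  # '{"big": [0, 1, 2, 3, 4]}'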
+ if isinstance(o, binary_type): + _encoding = self.encoding + if (_encoding is not None and not (_encoding == 'utf-8')): + o = text_type(o, _encoding) + if isinstance(o, string_types): + if self.ensure_ascii: + return encode_basestring_ascii(o) + else: + return encode_basestring(o) + # This doesn't pass the iterator directly to ''.join() because the + # exceptions aren't as detailed. The list call should be roughly + # equivalent to the PySequence_Fast that ''.join() would do. + chunks = self.iterencode(o) + if not isinstance(chunks, (list, tuple)): + chunks = list(chunks) + if self.ensure_ascii: + return ''.join(chunks) + else: + return u''.join(chunks) + + def iterencode(self, o): + """Encode the given object and yield each string + representation as available. + + For example:: + + for chunk in JSONEncoder().iterencode(bigobject): + mysocket.write(chunk) + + """ + if self.check_circular: + markers = {} + else: + markers = None + if self.ensure_ascii: + _encoder = encode_basestring_ascii + else: + _encoder = encode_basestring + if self.encoding != 'utf-8' and self.encoding is not None: + def _encoder(o, _orig_encoder=_encoder, _encoding=self.encoding): + if isinstance(o, binary_type): + o = text_type(o, _encoding) + return _orig_encoder(o) + + def floatstr(o, allow_nan=self.allow_nan, ignore_nan=self.ignore_nan, + _repr=FLOAT_REPR, _inf=PosInf, _neginf=-PosInf): + # Check for specials. Note that this type of test is processor + # and/or platform-specific, so do tests which don't depend on + # the internals. + + if o != o: + text = 'NaN' + elif o == _inf: + text = 'Infinity' + elif o == _neginf: + text = '-Infinity' + else: + if type(o) != float: + # See #118, do not trust custom str/repr + o = float(o) + return _repr(o) + + if ignore_nan: + text = 'null' + elif not allow_nan: + raise ValueError( + "Out of range float values are not JSON compliant: " + + repr(o)) + + return text + + key_memo = {} + int_as_string_bitcount = ( + 53 if self.bigint_as_string else self.int_as_string_bitcount) + if (c_make_encoder is not None and self.indent is None): + _iterencode = c_make_encoder( + markers, self.default, _encoder, self.indent, + self.key_separator, self.item_separator, self.sort_keys, + self.skipkeys, self.allow_nan, key_memo, self.use_decimal, + self.namedtuple_as_object, self.tuple_as_array, + int_as_string_bitcount, + self.item_sort_key, self.encoding, self.for_json, + self.ignore_nan, decimal.Decimal, self.iterable_as_array) + else: + _iterencode = _make_iterencode( + markers, self.default, _encoder, self.indent, floatstr, + self.key_separator, self.item_separator, self.sort_keys, + self.skipkeys, self.use_decimal, + self.namedtuple_as_object, self.tuple_as_array, + int_as_string_bitcount, + self.item_sort_key, self.encoding, self.for_json, + self.iterable_as_array, Decimal=decimal.Decimal) + try: + return _iterencode(o, 0) + finally: + key_memo.clear() + + +class JSONEncoderForHTML(JSONEncoder): + """An encoder that produces JSON safe to embed in HTML. + + To embed JSON content in, say, a script tag on a web page, the + characters &, < and > should be escaped. They cannot be escaped + with the usual entities (e.g. 
&amp;) because they are not expanded + within <script> tags. + + This class also escapes the line separator and paragraph separator + characters U+2028 and U+2029, irrespective of the ensure_ascii setting, + as these characters are not valid in JavaScript strings (see + http://timelessrepo.com/json-isnt-a-javascript-subset). + """ + bad_string = '</script><script>alert("gotcha")</script>' + self.assertEqual( + r'"\u003c/script\u003e\u003cscript\u003e' + r'alert(\"gotcha\")\u003c/script\u003e"', + self.encoder.encode(bad_string)) + self.assertEqual( + bad_string, self.decoder.decode( + self.encoder.encode(bad_string))) diff --git a/lib/simplejson/tests/test_errors.py b/lib/simplejson/tests/test_errors.py new file mode 100644 index 0000000..c6e8688 --- /dev/null +++ b/lib/simplejson/tests/test_errors.py @@ -0,0 +1,68 @@ +import sys, pickle +from unittest import TestCase + +import simplejson as json +from simplejson.compat import text_type, b + +class TestErrors(TestCase): + def test_string_keys_error(self): + data = [{'a': 'A', 'b': (2, 4), 'c': 3.0, ('d',): 'D tuple'}] + try: + json.dumps(data) + except TypeError: + err = sys.exc_info()[1] + else: + self.fail('Expected TypeError') + self.assertEqual(str(err), + 'keys must be str, int, float, bool or None, not tuple') + + def test_not_serializable(self): + try: + json.dumps(json) + except TypeError: + err = sys.exc_info()[1] + else: + self.fail('Expected TypeError') + self.assertEqual(str(err), + 'Object of type module is not JSON serializable') + + def test_decode_error(self): + err = None + try: + json.loads('{}\na\nb') + except json.JSONDecodeError: + err = sys.exc_info()[1] + else: + self.fail('Expected JSONDecodeError') + self.assertEqual(err.lineno, 2) + self.assertEqual(err.colno, 1) + self.assertEqual(err.endlineno, 3) + self.assertEqual(err.endcolno, 2) + + def test_scan_error(self): + err = None + for t in (text_type, b): + try: + json.loads(t('{"asdf": "')) + except json.JSONDecodeError: + err = sys.exc_info()[1] + else: + self.fail('Expected JSONDecodeError') + self.assertEqual(err.lineno, 1) + self.assertEqual(err.colno, 10) + + def test_error_is_pickable(self): + err = None + try: + json.loads('{}\na\nb') + except json.JSONDecodeError: + err = sys.exc_info()[1] + else: + self.fail('Expected JSONDecodeError') + s = pickle.dumps(err) + e = pickle.loads(s) + + self.assertEqual(err.msg, e.msg) + self.assertEqual(err.doc, e.doc) + self.assertEqual(err.pos, e.pos) + self.assertEqual(err.end, e.end) diff --git a/lib/simplejson/tests/test_fail.py b/lib/simplejson/tests/test_fail.py new file mode 100644 index 0000000..54b7414 --- /dev/null +++ b/lib/simplejson/tests/test_fail.py @@ -0,0 +1,178 @@ +import sys +from unittest import TestCase + +import simplejson as json + +# 2007-10-05 +JSONDOCS = [ + # http://json.org/JSON_checker/test/fail1.json + '"A JSON payload should be an object or array, not a string."', + # http://json.org/JSON_checker/test/fail2.json + '["Unclosed array"', + # http://json.org/JSON_checker/test/fail3.json + '{unquoted_key: "keys must be quoted"}', + # http://json.org/JSON_checker/test/fail4.json + '["extra comma",]', + # http://json.org/JSON_checker/test/fail5.json + '["double extra comma",,]', + # http://json.org/JSON_checker/test/fail6.json + '[ , "<-- missing value"]', + # http://json.org/JSON_checker/test/fail7.json + '["Comma after the close"],', + # http://json.org/JSON_checker/test/fail8.json + '["Extra close"]]', + # http://json.org/JSON_checker/test/fail9.json + '{"Extra comma": true,}', + # http://json.org/JSON_checker/test/fail10.json + '{"Extra value after close": true} "misplaced quoted value"', + # http://json.org/JSON_checker/test/fail11.json + '{"Illegal expression": 1 + 2}', + # http://json.org/JSON_checker/test/fail12.json + '{"Illegal invocation": alert()}', + # http://json.org/JSON_checker/test/fail13.json + '{"Numbers cannot have
leading zeroes": 013}', + # http://json.org/JSON_checker/test/fail14.json + '{"Numbers cannot be hex": 0x14}', + # http://json.org/JSON_checker/test/fail15.json + '["Illegal backslash escape: \\x15"]', + # http://json.org/JSON_checker/test/fail16.json + '[\\naked]', + # http://json.org/JSON_checker/test/fail17.json + '["Illegal backslash escape: \\017"]', + # http://json.org/JSON_checker/test/fail18.json + '[[[[[[[[[[[[[[[[[[[["Too deep"]]]]]]]]]]]]]]]]]]]]', + # http://json.org/JSON_checker/test/fail19.json + '{"Missing colon" null}', + # http://json.org/JSON_checker/test/fail20.json + '{"Double colon":: null}', + # http://json.org/JSON_checker/test/fail21.json + '{"Comma instead of colon", null}', + # http://json.org/JSON_checker/test/fail22.json + '["Colon instead of comma": false]', + # http://json.org/JSON_checker/test/fail23.json + '["Bad value", truth]', + # http://json.org/JSON_checker/test/fail24.json + "['single quote']", + # http://json.org/JSON_checker/test/fail25.json + '["\ttab\tcharacter\tin\tstring\t"]', + # http://json.org/JSON_checker/test/fail26.json + '["tab\\ character\\ in\\ string\\ "]', + # http://json.org/JSON_checker/test/fail27.json + '["line\nbreak"]', + # http://json.org/JSON_checker/test/fail28.json + '["line\\\nbreak"]', + # http://json.org/JSON_checker/test/fail29.json + '[0e]', + # http://json.org/JSON_checker/test/fail30.json + '[0e+]', + # http://json.org/JSON_checker/test/fail31.json + '[0e+-1]', + # http://json.org/JSON_checker/test/fail32.json + '{"Comma instead if closing brace": true,', + # http://json.org/JSON_checker/test/fail33.json + '["mismatch"}', + # http://code.google.com/p/simplejson/issues/detail?id=3 + u'["A\u001FZ control characters in string"]', + # misc based on coverage + '{', + '{]', + '{"foo": "bar"]', + '{"foo": "bar"', + 'nul', + 'nulx', + '-', + '-x', + '-e', + '-e0', + '-Infinite', + '-Inf', + 'Infinit', + 'Infinite', + 'NaM', + 'NuN', + 'falsy', + 'fal', + 'trug', + 'tru', + '1e', + '1ex', + '1e-', + '1e-x', +] + +SKIPS = { + 1: "why not have a string payload?", + 18: "spec doesn't specify any nesting limitations", +} + +class TestFail(TestCase): + def test_failures(self): + for idx, doc in enumerate(JSONDOCS): + idx = idx + 1 + if idx in SKIPS: + json.loads(doc) + continue + try: + json.loads(doc) + except json.JSONDecodeError: + pass + else: + self.fail("Expected failure for fail%d.json: %r" % (idx, doc)) + + def test_array_decoder_issue46(self): + # http://code.google.com/p/simplejson/issues/detail?id=46 + for doc in [u'[,]', '[,]']: + try: + json.loads(doc) + except json.JSONDecodeError: + e = sys.exc_info()[1] + self.assertEqual(e.pos, 1) + self.assertEqual(e.lineno, 1) + self.assertEqual(e.colno, 2) + except Exception: + e = sys.exc_info()[1] + self.fail("Unexpected exception raised %r %s" % (e, e)) + else: + self.fail("Unexpected success parsing '[,]'") + + def test_truncated_input(self): + test_cases = [ + ('', 'Expecting value', 0), + ('[', "Expecting value or ']'", 1), + ('[42', "Expecting ',' delimiter", 3), + ('[42,', 'Expecting value', 4), + ('["', 'Unterminated string starting at', 1), + ('["spam', 'Unterminated string starting at', 1), + ('["spam"', "Expecting ',' delimiter", 7), + ('["spam",', 'Expecting value', 8), + ('{', "Expecting property name enclosed in double quotes or '}'", 1), + ('{"', 'Unterminated string starting at', 1), + ('{"spam', 'Unterminated string starting at', 1), + ('{"spam"', "Expecting ':' delimiter", 7), + ('{"spam":', 'Expecting value', 8), + ('{"spam":42', "Expecting ',' delimiter", 
10), + ('{"spam":42,', 'Expecting property name enclosed in double quotes', + 11), + ('"', 'Unterminated string starting at', 0), + ('"spam', 'Unterminated string starting at', 0), + ('[,', "Expecting value", 1), + ('--', 'Expecting value', 0), + ('"\x18d', "Invalid control character %r", 1), + ] + for data, msg, idx in test_cases: + try: + json.loads(data) + except json.JSONDecodeError: + e = sys.exc_info()[1] + self.assertEqual( + e.msg[:len(msg)], + msg, + "%r doesn't start with %r for %r" % (e.msg, msg, data)) + self.assertEqual( + e.pos, idx, + "pos %r != %r for %r" % (e.pos, idx, data)) + except Exception: + e = sys.exc_info()[1] + self.fail("Unexpected exception raised %r %s" % (e, e)) + else: + self.fail("Unexpected success parsing '%r'" % (data,)) diff --git a/lib/simplejson/tests/test_float.py b/lib/simplejson/tests/test_float.py new file mode 100644 index 0000000..494d999 --- /dev/null +++ b/lib/simplejson/tests/test_float.py @@ -0,0 +1,38 @@ +import math +from unittest import TestCase +from simplejson.compat import long_type, text_type +import simplejson as json +from simplejson.decoder import NaN, PosInf, NegInf + +class TestFloat(TestCase): + def test_degenerates_allow(self): + for inf in (PosInf, NegInf): + self.assertEqual(json.loads(json.dumps(inf, allow_nan=True), allow_nan=True), inf) + # Python 2.5 doesn't have math.isnan + nan = json.loads(json.dumps(NaN, allow_nan=True), allow_nan=True) + self.assertTrue((0 + nan) != nan) + + def test_degenerates_ignore(self): + for f in (PosInf, NegInf, NaN): + self.assertEqual(json.loads(json.dumps(f, ignore_nan=True)), None) + + def test_degenerates_deny(self): + for f in (PosInf, NegInf, NaN): + self.assertRaises(ValueError, json.dumps, f, allow_nan=False) + for s in ('Infinity', '-Infinity', 'NaN'): + self.assertRaises(ValueError, json.loads, s, allow_nan=False) + self.assertRaises(ValueError, json.loads, s) + + def test_floats(self): + for num in [1617161771.7650001, math.pi, math.pi**100, + math.pi**-100, 3.1]: + self.assertEqual(float(json.dumps(num)), num) + self.assertEqual(json.loads(json.dumps(num)), num) + self.assertEqual(json.loads(text_type(json.dumps(num))), num) + + def test_ints(self): + for num in [1, long_type(1), 1<<32, 1<<64]: + self.assertEqual(json.dumps(num), str(num)) + self.assertEqual(int(json.dumps(num)), num) + self.assertEqual(json.loads(json.dumps(num)), num) + self.assertEqual(json.loads(text_type(json.dumps(num))), num) diff --git a/lib/simplejson/tests/test_for_json.py b/lib/simplejson/tests/test_for_json.py new file mode 100644 index 0000000..4c153fd --- /dev/null +++ b/lib/simplejson/tests/test_for_json.py @@ -0,0 +1,97 @@ +import unittest +import simplejson as json + + +class ForJson(object): + def for_json(self): + return {'for_json': 1} + + +class NestedForJson(object): + def for_json(self): + return {'nested': ForJson()} + + +class ForJsonList(object): + def for_json(self): + return ['list'] + + +class DictForJson(dict): + def for_json(self): + return {'alpha': 1} + + +class ListForJson(list): + def for_json(self): + return ['list'] + + +class TestForJson(unittest.TestCase): + def assertRoundTrip(self, obj, other, for_json=True): + if for_json is None: + # None will use the default + s = json.dumps(obj) + else: + s = json.dumps(obj, for_json=for_json) + self.assertEqual( + json.loads(s), + other) + + def test_for_json_encodes_stand_alone_object(self): + self.assertRoundTrip( + ForJson(), + ForJson().for_json()) + + def test_for_json_encodes_object_nested_in_dict(self): + self.assertRoundTrip( 
+ {'hooray': ForJson()}, + {'hooray': ForJson().for_json()}) + + def test_for_json_encodes_object_nested_in_list_within_dict(self): + self.assertRoundTrip( + {'list': [0, ForJson(), 2, 3]}, + {'list': [0, ForJson().for_json(), 2, 3]}) + + def test_for_json_encodes_object_nested_within_object(self): + self.assertRoundTrip( + NestedForJson(), + {'nested': {'for_json': 1}}) + + def test_for_json_encodes_list(self): + self.assertRoundTrip( + ForJsonList(), + ForJsonList().for_json()) + + def test_for_json_encodes_list_within_object(self): + self.assertRoundTrip( + {'nested': ForJsonList()}, + {'nested': ForJsonList().for_json()}) + + def test_for_json_encodes_dict_subclass(self): + self.assertRoundTrip( + DictForJson(a=1), + DictForJson(a=1).for_json()) + + def test_for_json_encodes_list_subclass(self): + self.assertRoundTrip( + ListForJson(['l']), + ListForJson(['l']).for_json()) + + def test_for_json_ignored_if_not_true_with_dict_subclass(self): + for for_json in (None, False): + self.assertRoundTrip( + DictForJson(a=1), + {'a': 1}, + for_json=for_json) + + def test_for_json_ignored_if_not_true_with_list_subclass(self): + for for_json in (None, False): + self.assertRoundTrip( + ListForJson(['l']), + ['l'], + for_json=for_json) + + def test_raises_typeerror_if_for_json_not_true_with_object(self): + self.assertRaises(TypeError, json.dumps, ForJson()) + self.assertRaises(TypeError, json.dumps, ForJson(), for_json=False) diff --git a/lib/simplejson/tests/test_indent.py b/lib/simplejson/tests/test_indent.py new file mode 100644 index 0000000..32326a6 --- /dev/null +++ b/lib/simplejson/tests/test_indent.py @@ -0,0 +1,86 @@ +from unittest import TestCase +import textwrap + +import simplejson as json +from simplejson.compat import StringIO + +class TestIndent(TestCase): + def test_indent(self): + h = [['blorpie'], ['whoops'], [], 'd-shtaeou', 'd-nthiouh', + 'i-vhbjkhnth', + {'nifty': 87}, {'field': 'yes', 'morefield': False} ] + + expect = textwrap.dedent("""\ + [ + \t[ + \t\t"blorpie" + \t], + \t[ + \t\t"whoops" + \t], + \t[], + \t"d-shtaeou", + \t"d-nthiouh", + \t"i-vhbjkhnth", + \t{ + \t\t"nifty": 87 + \t}, + \t{ + \t\t"field": "yes", + \t\t"morefield": false + \t} + ]""") + + + d1 = json.dumps(h) + d2 = json.dumps(h, indent='\t', sort_keys=True, separators=(',', ': ')) + d3 = json.dumps(h, indent=' ', sort_keys=True, separators=(',', ': ')) + d4 = json.dumps(h, indent=2, sort_keys=True, separators=(',', ': ')) + + h1 = json.loads(d1) + h2 = json.loads(d2) + h3 = json.loads(d3) + h4 = json.loads(d4) + + self.assertEqual(h1, h) + self.assertEqual(h2, h) + self.assertEqual(h3, h) + self.assertEqual(h4, h) + self.assertEqual(d3, expect.replace('\t', ' ')) + self.assertEqual(d4, expect.replace('\t', ' ')) + # NOTE: Python 2.4 textwrap.dedent converts tabs to spaces, + # so the following is expected to fail. Python 2.4 is not a + # supported platform in simplejson 2.1.0+. 
+ self.assertEqual(d2, expect) + + def test_indent0(self): + h = {3: 1} + def check(indent, expected): + d1 = json.dumps(h, indent=indent) + self.assertEqual(d1, expected) + + sio = StringIO() + json.dump(h, sio, indent=indent) + self.assertEqual(sio.getvalue(), expected) + + # indent=0 should emit newlines + check(0, '{\n"3": 1\n}') + # indent=None is more compact + check(None, '{"3": 1}') + + def test_separators(self): + lst = [1,2,3,4] + expect = '[\n1,\n2,\n3,\n4\n]' + expect_spaces = '[\n1, \n2, \n3, \n4\n]' + # Ensure that separators still works + self.assertEqual( + expect_spaces, + json.dumps(lst, indent=0, separators=(', ', ': '))) + # Force the new defaults + self.assertEqual( + expect, + json.dumps(lst, indent=0, separators=(',', ': '))) + # Added in 2.1.4 + self.assertEqual( + expect, + json.dumps(lst, indent=0)) diff --git a/lib/simplejson/tests/test_item_sort_key.py b/lib/simplejson/tests/test_item_sort_key.py new file mode 100644 index 0000000..c913304 --- /dev/null +++ b/lib/simplejson/tests/test_item_sort_key.py @@ -0,0 +1,27 @@ +from unittest import TestCase + +import simplejson as json +from operator import itemgetter + +class TestItemSortKey(TestCase): + def test_simple_first(self): + a = {'a': 1, 'c': 5, 'jack': 'jill', 'pick': 'axe', 'array': [1, 5, 6, 9], 'tuple': (83, 12, 3), 'crate': 'dog', 'zeak': 'oh'} + self.assertEqual( + '{"a": 1, "c": 5, "crate": "dog", "jack": "jill", "pick": "axe", "zeak": "oh", "array": [1, 5, 6, 9], "tuple": [83, 12, 3]}', + json.dumps(a, item_sort_key=json.simple_first)) + + def test_case(self): + a = {'a': 1, 'c': 5, 'Jack': 'jill', 'pick': 'axe', 'Array': [1, 5, 6, 9], 'tuple': (83, 12, 3), 'crate': 'dog', 'zeak': 'oh'} + self.assertEqual( + '{"Array": [1, 5, 6, 9], "Jack": "jill", "a": 1, "c": 5, "crate": "dog", "pick": "axe", "tuple": [83, 12, 3], "zeak": "oh"}', + json.dumps(a, item_sort_key=itemgetter(0))) + self.assertEqual( + '{"a": 1, "Array": [1, 5, 6, 9], "c": 5, "crate": "dog", "Jack": "jill", "pick": "axe", "tuple": [83, 12, 3], "zeak": "oh"}', + json.dumps(a, item_sort_key=lambda kv: kv[0].lower())) + + def test_item_sort_key_value(self): + # https://github.com/simplejson/simplejson/issues/173 + a = {'a': 1, 'b': 0} + self.assertEqual( + '{"b": 0, "a": 1}', + json.dumps(a, item_sort_key=lambda kv: kv[1])) diff --git a/lib/simplejson/tests/test_iterable.py b/lib/simplejson/tests/test_iterable.py new file mode 100644 index 0000000..995e1fd --- /dev/null +++ b/lib/simplejson/tests/test_iterable.py @@ -0,0 +1,31 @@ +import unittest +from simplejson.compat import StringIO + +import simplejson as json + +def iter_dumps(obj, **kw): + return ''.join(json.JSONEncoder(**kw).iterencode(obj)) + +def sio_dump(obj, **kw): + sio = StringIO() + json.dump(obj, sio, **kw) + return sio.getvalue() + +class TestIterable(unittest.TestCase): + def test_iterable(self): + for l in ([], [1], [1, 2], [1, 2, 3]): + for opts in [{}, {'indent': 2}]: + for dumps in (json.dumps, iter_dumps, sio_dump): + expect = dumps(l, **opts) + default_expect = dumps(sum(l), **opts) + # Default is False + self.assertRaises(TypeError, dumps, iter(l), **opts) + self.assertRaises(TypeError, dumps, iter(l), iterable_as_array=False, **opts) + self.assertEqual(expect, dumps(iter(l), iterable_as_array=True, **opts)) + # Ensure that the "default" gets called + self.assertEqual(default_expect, dumps(iter(l), default=sum, **opts)) + self.assertEqual(default_expect, dumps(iter(l), iterable_as_array=False, default=sum, **opts)) + # Ensure that the "default" does not get called + 
self.assertEqual( + expect, + dumps(iter(l), iterable_as_array=True, default=sum, **opts)) diff --git a/lib/simplejson/tests/test_namedtuple.py b/lib/simplejson/tests/test_namedtuple.py new file mode 100644 index 0000000..e3aa310 --- /dev/null +++ b/lib/simplejson/tests/test_namedtuple.py @@ -0,0 +1,174 @@ +from __future__ import absolute_import +import unittest +import simplejson as json +from simplejson.compat import StringIO + +try: + from unittest import mock +except ImportError: + mock = None + +try: + from collections import namedtuple +except ImportError: + class Value(tuple): + def __new__(cls, *args): + return tuple.__new__(cls, args) + + def _asdict(self): + return {'value': self[0]} + class Point(tuple): + def __new__(cls, *args): + return tuple.__new__(cls, args) + + def _asdict(self): + return {'x': self[0], 'y': self[1]} +else: + Value = namedtuple('Value', ['value']) + Point = namedtuple('Point', ['x', 'y']) + +class DuckValue(object): + def __init__(self, *args): + self.value = Value(*args) + + def _asdict(self): + return self.value._asdict() + +class DuckPoint(object): + def __init__(self, *args): + self.point = Point(*args) + + def _asdict(self): + return self.point._asdict() + +class DeadDuck(object): + _asdict = None + +class DeadDict(dict): + _asdict = None + +CONSTRUCTORS = [ + lambda v: v, + lambda v: [v], + lambda v: [{'key': v}], +] + +class TestNamedTuple(unittest.TestCase): + def test_namedtuple_dumps(self): + for v in [Value(1), Point(1, 2), DuckValue(1), DuckPoint(1, 2)]: + d = v._asdict() + self.assertEqual(d, json.loads(json.dumps(v))) + self.assertEqual( + d, + json.loads(json.dumps(v, namedtuple_as_object=True))) + self.assertEqual(d, json.loads(json.dumps(v, tuple_as_array=False))) + self.assertEqual( + d, + json.loads(json.dumps(v, namedtuple_as_object=True, + tuple_as_array=False))) + + def test_namedtuple_dumps_false(self): + for v in [Value(1), Point(1, 2)]: + l = list(v) + self.assertEqual( + l, + json.loads(json.dumps(v, namedtuple_as_object=False))) + self.assertRaises(TypeError, json.dumps, v, + tuple_as_array=False, namedtuple_as_object=False) + + def test_namedtuple_dump(self): + for v in [Value(1), Point(1, 2), DuckValue(1), DuckPoint(1, 2)]: + d = v._asdict() + sio = StringIO() + json.dump(v, sio) + self.assertEqual(d, json.loads(sio.getvalue())) + sio = StringIO() + json.dump(v, sio, namedtuple_as_object=True) + self.assertEqual( + d, + json.loads(sio.getvalue())) + sio = StringIO() + json.dump(v, sio, tuple_as_array=False) + self.assertEqual(d, json.loads(sio.getvalue())) + sio = StringIO() + json.dump(v, sio, namedtuple_as_object=True, + tuple_as_array=False) + self.assertEqual( + d, + json.loads(sio.getvalue())) + + def test_namedtuple_dump_false(self): + for v in [Value(1), Point(1, 2)]: + l = list(v) + sio = StringIO() + json.dump(v, sio, namedtuple_as_object=False) + self.assertEqual( + l, + json.loads(sio.getvalue())) + self.assertRaises(TypeError, json.dump, v, StringIO(), + tuple_as_array=False, namedtuple_as_object=False) + + def test_asdict_not_callable_dump(self): + for f in CONSTRUCTORS: + self.assertRaises( + TypeError, + json.dump, + f(DeadDuck()), + StringIO(), + namedtuple_as_object=True + ) + sio = StringIO() + json.dump(f(DeadDict()), sio, namedtuple_as_object=True) + self.assertEqual( + json.dumps(f({})), + sio.getvalue()) + self.assertRaises( + TypeError, + json.dump, + f(Value), + StringIO(), + namedtuple_as_object=True + ) + + def test_asdict_not_callable_dumps(self): + for f in CONSTRUCTORS: + 
self.assertRaises(TypeError, + json.dumps, f(DeadDuck()), namedtuple_as_object=True) + self.assertRaises( + TypeError, + json.dumps, + f(Value), + namedtuple_as_object=True + ) + self.assertEqual( + json.dumps(f({})), + json.dumps(f(DeadDict()), namedtuple_as_object=True)) + + def test_asdict_unbound_method_dumps(self): + for f in CONSTRUCTORS: + self.assertEqual( + json.dumps(f(Value), default=lambda v: v.__name__), + json.dumps(f(Value.__name__)) + ) + + def test_asdict_does_not_return_dict(self): + if not mock: + if hasattr(unittest, "SkipTest"): + raise unittest.SkipTest("unittest.mock required") + else: + print("unittest.mock not available") + return + fake = mock.Mock() + self.assertTrue(hasattr(fake, '_asdict')) + self.assertTrue(callable(fake._asdict)) + self.assertFalse(isinstance(fake._asdict(), dict)) + # https://github.com/simplejson/simplejson/pull/284 + # when running under a debug build of CPython (COPTS=-UNDEBUG) + # a C assertion could fire due to an unchecked error of a PyDict + # API call on a non-dict internally in _speedups.c. Without a debug + # build of CPython this test likely passes either way despite the + # potential for internal data corruption. Getting it to crash in + # a debug build is not always easy either as it requires an + # assert(!PyErr_Occurred()) that could fire later on. + with self.assertRaises(TypeError): + json.dumps({23: fake}, namedtuple_as_object=True, for_json=False) diff --git a/lib/simplejson/tests/test_pass1.py b/lib/simplejson/tests/test_pass1.py new file mode 100644 index 0000000..7482833 --- /dev/null +++ b/lib/simplejson/tests/test_pass1.py @@ -0,0 +1,71 @@ +from unittest import TestCase + +import simplejson as json + +# from http://json.org/JSON_checker/test/pass1.json +JSON = r''' +[ + "JSON Test Pattern pass1", + {"object with 1 member":["array with 1 element"]}, + {}, + [], + -42, + true, + false, + null, + { + "integer": 1234567890, + "real": -9876.543210, + "e": 0.123456789e-12, + "E": 1.234567890E+34, + "": 23456789012E66, + "zero": 0, + "one": 1, + "space": " ", + "quote": "\"", + "backslash": "\\", + "controls": "\b\f\n\r\t", + "slash": "/ & \/", + "alpha": "abcdefghijklmnopqrstuvwyz", + "ALPHA": "ABCDEFGHIJKLMNOPQRSTUVWYZ", + "digit": "0123456789", + "special": "`1~!@#$%^&*()_+-={':[,]}|;.</>?", + "hex": "\u0123\u4567\u89AB\uCDEF\uabcd\uef4A", + "true": true, + "false": false, + "null": null, + "array":[ ], + "object":{ }, + "address": "50 St. James Street", + "url": "http://www.JSON.org/", + "comment": "// /* <!-- --", + "# -- --> */": " ", + " s p a c e d " :[1,2 , 3 + +, + +4 , 5 , 6 ,7 ],"compact": [1,2,3,4,5,6,7], + "jsontext": "{\"object with 1 member\":[\"array with 1 element\"]}", + "quotes": "&#34; \u0022 %22 0x22 034 &#34;", + "\/\\\"\uCAFE\uBABE\uAB98\uFCDE\ubcda\uef4A\b\f\n\r\t`1~!@#$%^&*()_+-=[]{}|;:',./<>?"
+: "A key can be any string" + }, + 0.5 ,98.6 +, +99.44 +, + +1066, +1e1, +0.1e1, +1e-1, +1e00,2e+00,2e-00 +,"rosebud"] +''' + +class TestPass1(TestCase): + def test_parse(self): + # test in/out equivalence and parsing + res = json.loads(JSON) + out = json.dumps(res) + self.assertEqual(res, json.loads(out)) diff --git a/lib/simplejson/tests/test_pass2.py b/lib/simplejson/tests/test_pass2.py new file mode 100644 index 0000000..5c8e9ef --- /dev/null +++ b/lib/simplejson/tests/test_pass2.py @@ -0,0 +1,14 @@ +from unittest import TestCase +import simplejson as json + +# from http://json.org/JSON_checker/test/pass2.json +JSON = r''' +[[[[[[[[[[[[[[[[[[["Not too deep"]]]]]]]]]]]]]]]]]]] +''' + +class TestPass2(TestCase): + def test_parse(self): + # test in/out equivalence and parsing + res = json.loads(JSON) + out = json.dumps(res) + self.assertEqual(res, json.loads(out)) diff --git a/lib/simplejson/tests/test_pass3.py b/lib/simplejson/tests/test_pass3.py new file mode 100644 index 0000000..7d9cd8c --- /dev/null +++ b/lib/simplejson/tests/test_pass3.py @@ -0,0 +1,20 @@ +from unittest import TestCase + +import simplejson as json + +# from http://json.org/JSON_checker/test/pass3.json +JSON = r''' +{ + "JSON Test Pattern pass3": { + "The outermost value": "must be an object or array.", + "In this test": "It is an object." + } +} +''' + +class TestPass3(TestCase): + def test_parse(self): + # test in/out equivalence and parsing + res = json.loads(JSON) + out = json.dumps(res) + self.assertEqual(res, json.loads(out)) diff --git a/lib/simplejson/tests/test_raw_json.py b/lib/simplejson/tests/test_raw_json.py new file mode 100644 index 0000000..aaa2724 --- /dev/null +++ b/lib/simplejson/tests/test_raw_json.py @@ -0,0 +1,47 @@ +import unittest +import simplejson as json + +dct1 = { + 'key1': 'value1' +} + +dct2 = { + 'key2': 'value2', + 'd1': dct1 +} + +dct3 = { + 'key2': 'value2', + 'd1': json.dumps(dct1) +} + +dct4 = { + 'key2': 'value2', + 'd1': json.RawJSON(json.dumps(dct1)) +} + + +class TestRawJson(unittest.TestCase): + + def test_normal_str(self): + self.assertNotEqual(json.dumps(dct2), json.dumps(dct3)) + + def test_raw_json_str(self): + self.assertEqual(json.dumps(dct2), json.dumps(dct4)) + self.assertEqual(dct2, json.loads(json.dumps(dct4))) + + def test_list(self): + self.assertEqual( + json.dumps([dct2]), + json.dumps([json.RawJSON(json.dumps(dct2))])) + self.assertEqual( + [dct2], + json.loads(json.dumps([json.RawJSON(json.dumps(dct2))]))) + + def test_direct(self): + self.assertEqual( + json.dumps(dct2), + json.dumps(json.RawJSON(json.dumps(dct2)))) + self.assertEqual( + dct2, + json.loads(json.dumps(json.RawJSON(json.dumps(dct2))))) diff --git a/lib/simplejson/tests/test_recursion.py b/lib/simplejson/tests/test_recursion.py new file mode 100644 index 0000000..76328fc --- /dev/null +++ b/lib/simplejson/tests/test_recursion.py @@ -0,0 +1,67 @@ +from unittest import TestCase + +import simplejson as json + +class JSONTestObject: + pass + + +class RecursiveJSONEncoder(json.JSONEncoder): + recurse = False + def default(self, o): + if o is JSONTestObject: + if self.recurse: + return [JSONTestObject] + else: + return 'JSONTestObject' + return json.JSONEncoder.default(o) + + +class TestRecursion(TestCase): + def test_listrecursion(self): + x = [] + x.append(x) + try: + json.dumps(x) + except ValueError: + pass + else: + self.fail("didn't raise ValueError on list recursion") + x = [] + y = [x] + x.append(y) + try: + json.dumps(x) + except ValueError: + pass + else: + self.fail("didn't raise 
ValueError on alternating list recursion") + y = [] + x = [y, y] + # ensure that the marker is cleared + json.dumps(x) + + def test_dictrecursion(self): + x = {} + x["test"] = x + try: + json.dumps(x) + except ValueError: + pass + else: + self.fail("didn't raise ValueError on dict recursion") + x = {} + y = {"a": x, "b": x} + # ensure that the marker is cleared + json.dumps(y) + + def test_defaultrecursion(self): + enc = RecursiveJSONEncoder() + self.assertEqual(enc.encode(JSONTestObject), '"JSONTestObject"') + enc.recurse = True + try: + enc.encode(JSONTestObject) + except ValueError: + pass + else: + self.fail("didn't raise ValueError on default recursion") diff --git a/lib/simplejson/tests/test_scanstring.py b/lib/simplejson/tests/test_scanstring.py new file mode 100644 index 0000000..89d83d6 --- /dev/null +++ b/lib/simplejson/tests/test_scanstring.py @@ -0,0 +1,200 @@ +import sys +from unittest import TestCase + +import simplejson as json +import simplejson.decoder +from simplejson.compat import b, PY3 + +class TestScanString(TestCase): + # The bytes type is intentionally not used in most of these tests + # under Python 3 because the decoder immediately coerces to str before + # calling scanstring. In Python 2 we are testing the code paths + # for both unicode and str. + # + # The reason this is done is because Python 3 would require + # entirely different code paths for parsing bytes and str. + # + def test_py_scanstring(self): + self._test_scanstring(simplejson.decoder.py_scanstring) + + def test_c_scanstring(self): + if not simplejson.decoder.c_scanstring: + return + self._test_scanstring(simplejson.decoder.c_scanstring) + + self.assertTrue(isinstance(simplejson.decoder.c_scanstring('""', 0)[0], str)) + + def _test_scanstring(self, scanstring): + if sys.maxunicode == 65535: + self.assertEqual( + scanstring(u'"z\U0001d120x"', 1, None, True), + (u'z\U0001d120x', 6)) + else: + self.assertEqual( + scanstring(u'"z\U0001d120x"', 1, None, True), + (u'z\U0001d120x', 5)) + + self.assertEqual( + scanstring('"\\u007b"', 1, None, True), + (u'{', 8)) + + self.assertEqual( + scanstring('"A JSON payload should be an object or array, not a string."', 1, None, True), + (u'A JSON payload should be an object or array, not a string.', 60)) + + self.assertEqual( + scanstring('["Unclosed array"', 2, None, True), + (u'Unclosed array', 17)) + + self.assertEqual( + scanstring('["extra comma",]', 2, None, True), + (u'extra comma', 14)) + + self.assertEqual( + scanstring('["double extra comma",,]', 2, None, True), + (u'double extra comma', 21)) + + self.assertEqual( + scanstring('["Comma after the close"],', 2, None, True), + (u'Comma after the close', 24)) + + self.assertEqual( + scanstring('["Extra close"]]', 2, None, True), + (u'Extra close', 14)) + + self.assertEqual( + scanstring('{"Extra comma": true,}', 2, None, True), + (u'Extra comma', 14)) + + self.assertEqual( + scanstring('{"Extra value after close": true} "misplaced quoted value"', 2, None, True), + (u'Extra value after close', 26)) + + self.assertEqual( + scanstring('{"Illegal expression": 1 + 2}', 2, None, True), + (u'Illegal expression', 21)) + + self.assertEqual( + scanstring('{"Illegal invocation": alert()}', 2, None, True), + (u'Illegal invocation', 21)) + + self.assertEqual( + scanstring('{"Numbers cannot have leading zeroes": 013}', 2, None, True), + (u'Numbers cannot have leading zeroes', 37)) + + self.assertEqual( + scanstring('{"Numbers cannot be hex": 0x14}', 2, None, True), + (u'Numbers cannot be hex', 24)) + + self.assertEqual( + 
scanstring('[[[[[[[[[[[[[[[[[[[["Too deep"]]]]]]]]]]]]]]]]]]]]', 21, None, True), + (u'Too deep', 30)) + + self.assertEqual( + scanstring('{"Missing colon" null}', 2, None, True), + (u'Missing colon', 16)) + + self.assertEqual( + scanstring('{"Double colon":: null}', 2, None, True), + (u'Double colon', 15)) + + self.assertEqual( + scanstring('{"Comma instead of colon", null}', 2, None, True), + (u'Comma instead of colon', 25)) + + self.assertEqual( + scanstring('["Colon instead of comma": false]', 2, None, True), + (u'Colon instead of comma', 25)) + + self.assertEqual( + scanstring('["Bad value", truth]', 2, None, True), + (u'Bad value', 12)) + + for c in map(chr, range(0x00, 0x1f)): + self.assertEqual( + scanstring(c + '"', 0, None, False), + (c, 2)) + self.assertRaises( + ValueError, + scanstring, c + '"', 0, None, True) + + self.assertRaises(ValueError, scanstring, '', 0, None, True) + self.assertRaises(ValueError, scanstring, 'a', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\u', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\u0', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\u01', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\u012', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\u0123', 0, None, True) + if sys.maxunicode > 65535: + self.assertRaises(ValueError, + scanstring, '\\ud834\\u"', 0, None, True) + self.assertRaises(ValueError, + scanstring, '\\ud834\\x0123"', 0, None, True) + + self.assertRaises(json.JSONDecodeError, scanstring, '\\u-123"', 0, None, True) + # SJ-PT-23-01: Invalid Handling of Broken Unicode Escape Sequences + self.assertRaises(json.JSONDecodeError, scanstring, '\\u EDD"', 0, None, True) + + def test_issue3623(self): + self.assertRaises(ValueError, json.decoder.scanstring, "xxx", 1, + "xxx") + self.assertRaises(UnicodeDecodeError, + json.encoder.encode_basestring_ascii, b("xx\xff")) + + def test_overflow(self): + # Python 2.5 does not have maxsize, Python 3 does not have maxint + maxsize = getattr(sys, 'maxsize', getattr(sys, 'maxint', None)) + assert maxsize is not None + self.assertRaises(OverflowError, json.decoder.scanstring, "xxx", + maxsize + 1) + + def test_surrogates(self): + scanstring = json.decoder.scanstring + + def assertScan(given, expect, test_utf8=True): + givens = [given] + if not PY3 and test_utf8: + givens.append(given.encode('utf8')) + for given in givens: + (res, count) = scanstring(given, 1, None, True) + self.assertEqual(len(given), count) + self.assertEqual(res, expect) + + assertScan( + u'"z\\ud834\\u0079x"', + u'z\ud834yx') + assertScan( + u'"z\\ud834\\udd20x"', + u'z\U0001d120x') + assertScan( + u'"z\\ud834\\ud834\\udd20x"', + u'z\ud834\U0001d120x') + assertScan( + u'"z\\ud834x"', + u'z\ud834x') + assertScan( + u'"z\\udd20x"', + u'z\udd20x') + assertScan( + u'"z\ud834x"', + u'z\ud834x') + # It may look strange to join strings together, but Python is drunk. 
+ # https://gist.github.com/etrepum/5538443 + assertScan( + u'"z\\ud834\udd20x12345"', + u''.join([u'z\ud834', u'\udd20x12345'])) + assertScan( + u'"z\ud834\\udd20x"', + u''.join([u'z\ud834', u'\udd20x'])) + # these have different behavior given UTF8 input, because the surrogate + # pair may be joined (in maxunicode > 65535 builds) + assertScan( + u''.join([u'"z\ud834', u'\udd20x"']), + u''.join([u'z\ud834', u'\udd20x']), + test_utf8=False) + + self.assertRaises(ValueError, + scanstring, u'"z\\ud83x"', 1, None, True) + self.assertRaises(ValueError, + scanstring, u'"z\\ud834\\udd2x"', 1, None, True) diff --git a/lib/simplejson/tests/test_separators.py b/lib/simplejson/tests/test_separators.py new file mode 100644 index 0000000..5a21176 --- /dev/null +++ b/lib/simplejson/tests/test_separators.py @@ -0,0 +1,42 @@ +import textwrap +from unittest import TestCase + +import simplejson as json + + +class TestSeparators(TestCase): + def test_separators(self): + h = [['blorpie'], ['whoops'], [], 'd-shtaeou', 'd-nthiouh', 'i-vhbjkhnth', + {'nifty': 87}, {'field': 'yes', 'morefield': False} ] + + expect = textwrap.dedent("""\ + [ + [ + "blorpie" + ] , + [ + "whoops" + ] , + [] , + "d-shtaeou" , + "d-nthiouh" , + "i-vhbjkhnth" , + { + "nifty" : 87 + } , + { + "field" : "yes" , + "morefield" : false + } + ]""") + + + d1 = json.dumps(h) + d2 = json.dumps(h, indent=' ', sort_keys=True, separators=(' ,', ' : ')) + + h1 = json.loads(d1) + h2 = json.loads(d2) + + self.assertEqual(h1, h) + self.assertEqual(h2, h) + self.assertEqual(d2, expect) diff --git a/lib/simplejson/tests/test_speedups.py b/lib/simplejson/tests/test_speedups.py new file mode 100644 index 0000000..c0a66f3 --- /dev/null +++ b/lib/simplejson/tests/test_speedups.py @@ -0,0 +1,114 @@ +from __future__ import with_statement + +import sys +import unittest +from unittest import TestCase + +import simplejson +from simplejson import encoder, decoder, scanner +from simplejson.compat import PY3, long_type, b + + +def has_speedups(): + return encoder.c_make_encoder is not None + + +def skip_if_speedups_missing(func): + def wrapper(*args, **kwargs): + if not has_speedups(): + if hasattr(unittest, 'SkipTest'): + raise unittest.SkipTest("C Extension not available") + else: + sys.stdout.write("C Extension not available") + return + return func(*args, **kwargs) + + return wrapper + + +class BadBool: + def __bool__(self): + 1/0 + __nonzero__ = __bool__ + + +class TestDecode(TestCase): + @skip_if_speedups_missing + def test_make_scanner(self): + self.assertRaises(AttributeError, scanner.c_make_scanner, 1) + + @skip_if_speedups_missing + def test_bad_bool_args(self): + def test(value): + decoder.JSONDecoder(strict=BadBool()).decode(value) + self.assertRaises(ZeroDivisionError, test, '""') + self.assertRaises(ZeroDivisionError, test, '{}') + if not PY3: + self.assertRaises(ZeroDivisionError, test, u'""') + self.assertRaises(ZeroDivisionError, test, u'{}') + +class TestEncode(TestCase): + @skip_if_speedups_missing + def test_make_encoder(self): + self.assertRaises( + TypeError, + encoder.c_make_encoder, + None, + ("\xCD\x7D\x3D\x4E\x12\x4C\xF9\x79\xD7" + "\x52\xBA\x82\xF2\x27\x4A\x7D\xA0\xCA\x75"), + None + ) + + @skip_if_speedups_missing + def test_bad_str_encoder(self): + # Issue #31505: There shouldn't be an assertion failure in case + # c_make_encoder() receives a bad encoder() argument. 
+ import decimal + def bad_encoder1(*args): + return None + enc = encoder.c_make_encoder( + None, lambda obj: str(obj), + bad_encoder1, None, ': ', ', ', + False, False, False, {}, False, False, False, + None, None, 'utf-8', False, False, decimal.Decimal, False) + self.assertRaises(TypeError, enc, 'spam', 4) + self.assertRaises(TypeError, enc, {'spam': 42}, 4) + + def bad_encoder2(*args): + 1/0 + enc = encoder.c_make_encoder( + None, lambda obj: str(obj), + bad_encoder2, None, ': ', ', ', + False, False, False, {}, False, False, False, + None, None, 'utf-8', False, False, decimal.Decimal, False) + self.assertRaises(ZeroDivisionError, enc, 'spam', 4) + + @skip_if_speedups_missing + def test_bad_bool_args(self): + def test(name): + encoder.JSONEncoder(**{name: BadBool()}).encode({}) + self.assertRaises(ZeroDivisionError, test, 'skipkeys') + self.assertRaises(ZeroDivisionError, test, 'ensure_ascii') + self.assertRaises(ZeroDivisionError, test, 'check_circular') + self.assertRaises(ZeroDivisionError, test, 'allow_nan') + self.assertRaises(ZeroDivisionError, test, 'sort_keys') + self.assertRaises(ZeroDivisionError, test, 'use_decimal') + self.assertRaises(ZeroDivisionError, test, 'namedtuple_as_object') + self.assertRaises(ZeroDivisionError, test, 'tuple_as_array') + self.assertRaises(ZeroDivisionError, test, 'bigint_as_string') + self.assertRaises(ZeroDivisionError, test, 'for_json') + self.assertRaises(ZeroDivisionError, test, 'ignore_nan') + self.assertRaises(ZeroDivisionError, test, 'iterable_as_array') + + @skip_if_speedups_missing + def test_int_as_string_bitcount_overflow(self): + long_count = long_type(2)**32+31 + def test(): + encoder.JSONEncoder(int_as_string_bitcount=long_count).encode(0) + self.assertRaises((TypeError, OverflowError), test) + + if PY3: + @skip_if_speedups_missing + def test_bad_encoding(self): + with self.assertRaises(UnicodeEncodeError): + encoder.JSONEncoder(encoding='\udcff').encode({b('key'): 123}) diff --git a/lib/simplejson/tests/test_str_subclass.py b/lib/simplejson/tests/test_str_subclass.py new file mode 100644 index 0000000..6bdd41f --- /dev/null +++ b/lib/simplejson/tests/test_str_subclass.py @@ -0,0 +1,21 @@ +from unittest import TestCase + +import simplejson +from simplejson.compat import text_type + +# Tests for issue demonstrated in https://github.com/simplejson/simplejson/issues/144 +class WonkyTextSubclass(text_type): + def __getslice__(self, start, end): + return self.__class__('not what you wanted!') + +class TestStrSubclass(TestCase): + def test_dump_load(self): + for s in ['', '"hello"', 'text', u'\u005c']: + self.assertEqual( + s, + simplejson.loads(simplejson.dumps(WonkyTextSubclass(s)))) + + self.assertEqual( + s, + simplejson.loads(simplejson.dumps(WonkyTextSubclass(s), + ensure_ascii=False))) diff --git a/lib/simplejson/tests/test_subclass.py b/lib/simplejson/tests/test_subclass.py new file mode 100644 index 0000000..a9ae318 --- /dev/null +++ b/lib/simplejson/tests/test_subclass.py @@ -0,0 +1,37 @@ +from unittest import TestCase +import simplejson as json + +from decimal import Decimal + +class AlternateInt(int): + def __repr__(self): + return 'invalid json' + __str__ = __repr__ + + +class AlternateFloat(float): + def __repr__(self): + return 'invalid json' + __str__ = __repr__ + + +# class AlternateDecimal(Decimal): +# def __repr__(self): +# return 'invalid json' + + +class TestSubclass(TestCase): + def test_int(self): + self.assertEqual(json.dumps(AlternateInt(1)), '1') + self.assertEqual(json.dumps(AlternateInt(-1)), '-1') + 
self.assertEqual(json.loads(json.dumps({AlternateInt(1): 1})), {'1': 1}) + + def test_float(self): + self.assertEqual(json.dumps(AlternateFloat(1.0)), '1.0') + self.assertEqual(json.dumps(AlternateFloat(-1.0)), '-1.0') + self.assertEqual(json.loads(json.dumps({AlternateFloat(1.0): 1})), {'1.0': 1}) + + # NOTE: Decimal subclasses are not supported as-is + # def test_decimal(self): + # self.assertEqual(json.dumps(AlternateDecimal('1.0')), '1.0') + # self.assertEqual(json.dumps(AlternateDecimal('-1.0')), '-1.0') diff --git a/lib/simplejson/tests/test_tool.py b/lib/simplejson/tests/test_tool.py new file mode 100644 index 0000000..59807ec --- /dev/null +++ b/lib/simplejson/tests/test_tool.py @@ -0,0 +1,114 @@ +from __future__ import with_statement +import os +import sys +import textwrap +import unittest +import subprocess +import tempfile +try: + # Python 3.x + from test.support import strip_python_stderr +except ImportError: + # Python 2.6+ + try: + from test.test_support import strip_python_stderr + except ImportError: + # Python 2.5 + import re + def strip_python_stderr(stderr): + return re.sub( + r"\[\d+ refs\]\r?\n?$".encode(), + "".encode(), + stderr).strip() + +def open_temp_file(): + if sys.version_info >= (2, 6): + file = tempfile.NamedTemporaryFile(delete=False) + filename = file.name + else: + fd, filename = tempfile.mkstemp() + file = os.fdopen(fd, 'w+b') + return file, filename + +class TestTool(unittest.TestCase): + data = """ + + [["blorpie"],[ "whoops" ] , [ + ],\t"d-shtaeou",\r"d-nthiouh", + "i-vhbjkhnth", {"nifty":87}, {"morefield" :\tfalse,"field" + :"yes"} ] + """ + + expect = textwrap.dedent("""\ + [ + [ + "blorpie" + ], + [ + "whoops" + ], + [], + "d-shtaeou", + "d-nthiouh", + "i-vhbjkhnth", + { + "nifty": 87 + }, + { + "field": "yes", + "morefield": false + } + ] + """) + + def runTool(self, args=None, data=None): + argv = [sys.executable, '-m', 'simplejson.tool'] + if args: + argv.extend(args) + proc = subprocess.Popen(argv, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE) + out, err = proc.communicate(data) + self.assertEqual(strip_python_stderr(err), ''.encode()) + self.assertEqual(proc.returncode, 0) + return out.decode('utf8').splitlines() + + def test_stdin_stdout(self): + self.assertEqual( + self.runTool(data=self.data.encode()), + self.expect.splitlines()) + + def test_infile_stdout(self): + infile, infile_name = open_temp_file() + try: + infile.write(self.data.encode()) + infile.close() + self.assertEqual( + self.runTool(args=[infile_name]), + self.expect.splitlines()) + finally: + os.unlink(infile_name) + + def test_infile_outfile(self): + infile, infile_name = open_temp_file() + try: + infile.write(self.data.encode()) + infile.close() + # outfile will get overwritten by tool, so the delete + # may not work on some platforms. Do it manually. 
+ outfile, outfile_name = open_temp_file() + try: + outfile.close() + self.assertEqual( + self.runTool(args=[infile_name, outfile_name]), + []) + with open(outfile_name, 'rb') as f: + self.assertEqual( + f.read().decode('utf8').splitlines(), + self.expect.splitlines() + ) + finally: + os.unlink(outfile_name) + finally: + os.unlink(infile_name) diff --git a/lib/simplejson/tests/test_tuple.py b/lib/simplejson/tests/test_tuple.py new file mode 100644 index 0000000..94ea139 --- /dev/null +++ b/lib/simplejson/tests/test_tuple.py @@ -0,0 +1,47 @@ +import unittest + +from simplejson.compat import StringIO +import simplejson as json + +class TestTuples(unittest.TestCase): + def test_tuple_array_dumps(self): + t = (1, 2, 3) + expect = json.dumps(list(t)) + # Default is True + self.assertEqual(expect, json.dumps(t)) + self.assertEqual(expect, json.dumps(t, tuple_as_array=True)) + self.assertRaises(TypeError, json.dumps, t, tuple_as_array=False) + # Ensure that the "default" does not get called + self.assertEqual(expect, json.dumps(t, default=repr)) + self.assertEqual(expect, json.dumps(t, tuple_as_array=True, + default=repr)) + # Ensure that the "default" gets called + self.assertEqual( + json.dumps(repr(t)), + json.dumps(t, tuple_as_array=False, default=repr)) + + def test_tuple_array_dump(self): + t = (1, 2, 3) + expect = json.dumps(list(t)) + # Default is True + sio = StringIO() + json.dump(t, sio) + self.assertEqual(expect, sio.getvalue()) + sio = StringIO() + json.dump(t, sio, tuple_as_array=True) + self.assertEqual(expect, sio.getvalue()) + self.assertRaises(TypeError, json.dump, t, StringIO(), + tuple_as_array=False) + # Ensure that the "default" does not get called + sio = StringIO() + json.dump(t, sio, default=repr) + self.assertEqual(expect, sio.getvalue()) + sio = StringIO() + json.dump(t, sio, tuple_as_array=True, default=repr) + self.assertEqual(expect, sio.getvalue()) + # Ensure that the "default" gets called + sio = StringIO() + json.dump(t, sio, tuple_as_array=False, default=repr) + self.assertEqual( + json.dumps(repr(t)), + sio.getvalue()) diff --git a/lib/simplejson/tests/test_unicode.py b/lib/simplejson/tests/test_unicode.py new file mode 100644 index 0000000..823aa9d --- /dev/null +++ b/lib/simplejson/tests/test_unicode.py @@ -0,0 +1,154 @@ +import sys +import codecs +from unittest import TestCase + +import simplejson as json +from simplejson.compat import unichr, text_type, b, BytesIO + +class TestUnicode(TestCase): + def test_encoding1(self): + encoder = json.JSONEncoder(encoding='utf-8') + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + s = u.encode('utf-8') + ju = encoder.encode(u) + js = encoder.encode(s) + self.assertEqual(ju, js) + + def test_encoding2(self): + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + s = u.encode('utf-8') + ju = json.dumps(u, encoding='utf-8') + js = json.dumps(s, encoding='utf-8') + self.assertEqual(ju, js) + + def test_encoding3(self): + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + j = json.dumps(u) + self.assertEqual(j, '"\\u03b1\\u03a9"') + + def test_encoding4(self): + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + j = json.dumps([u]) + self.assertEqual(j, '["\\u03b1\\u03a9"]') + + def test_encoding5(self): + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + j = json.dumps(u, ensure_ascii=False) + self.assertEqual(j, u'"' + u + u'"') + + def test_encoding6(self): + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + j = 
json.dumps([u], ensure_ascii=False) + self.assertEqual(j, u'["' + u + u'"]') + + def test_big_unicode_encode(self): + u = u'\U0001d120' + self.assertEqual(json.dumps(u), '"\\ud834\\udd20"') + self.assertEqual(json.dumps(u, ensure_ascii=False), u'"\U0001d120"') + + def test_big_unicode_decode(self): + u = u'z\U0001d120x' + self.assertEqual(json.loads('"' + u + '"'), u) + self.assertEqual(json.loads('"z\\ud834\\udd20x"'), u) + + def test_unicode_decode(self): + for i in range(0, 0xd7ff): + u = unichr(i) + #s = '"\\u{0:04x}"'.format(i) + s = '"\\u%04x"' % (i,) + self.assertEqual(json.loads(s), u) + + def test_object_pairs_hook_with_unicode(self): + s = u'{"xkd":1, "kcw":2, "art":3, "hxm":4, "qrt":5, "pad":6, "hoy":7}' + p = [(u"xkd", 1), (u"kcw", 2), (u"art", 3), (u"hxm", 4), + (u"qrt", 5), (u"pad", 6), (u"hoy", 7)] + self.assertEqual(json.loads(s), eval(s)) + self.assertEqual(json.loads(s, object_pairs_hook=lambda x: x), p) + od = json.loads(s, object_pairs_hook=json.OrderedDict) + self.assertEqual(od, json.OrderedDict(p)) + self.assertEqual(type(od), json.OrderedDict) + # the object_pairs_hook takes priority over the object_hook + self.assertEqual(json.loads(s, + object_pairs_hook=json.OrderedDict, + object_hook=lambda x: None), + json.OrderedDict(p)) + + + def test_default_encoding(self): + self.assertEqual(json.loads(u'{"a": "\xe9"}'.encode('utf-8')), + {'a': u'\xe9'}) + + def test_unicode_preservation(self): + self.assertEqual(type(json.loads(u'""')), text_type) + self.assertEqual(type(json.loads(u'"a"')), text_type) + self.assertEqual(type(json.loads(u'["a"]')[0]), text_type) + + def test_ensure_ascii_false_returns_unicode(self): + # http://code.google.com/p/simplejson/issues/detail?id=48 + self.assertEqual(type(json.dumps([], ensure_ascii=False)), text_type) + self.assertEqual(type(json.dumps(0, ensure_ascii=False)), text_type) + self.assertEqual(type(json.dumps({}, ensure_ascii=False)), text_type) + self.assertEqual(type(json.dumps("", ensure_ascii=False)), text_type) + + def test_ensure_ascii_false_bytestring_encoding(self): + # http://code.google.com/p/simplejson/issues/detail?id=48 + doc1 = {u'quux': b('Arr\xc3\xaat sur images')} + doc2 = {u'quux': u'Arr\xeat sur images'} + doc_ascii = '{"quux": "Arr\\u00eat sur images"}' + doc_unicode = u'{"quux": "Arr\xeat sur images"}' + self.assertEqual(json.dumps(doc1), doc_ascii) + self.assertEqual(json.dumps(doc2), doc_ascii) + self.assertEqual(json.dumps(doc1, ensure_ascii=False), doc_unicode) + self.assertEqual(json.dumps(doc2, ensure_ascii=False), doc_unicode) + + def test_ensure_ascii_linebreak_encoding(self): + # http://timelessrepo.com/json-isnt-a-javascript-subset + s1 = u'\u2029\u2028' + s2 = s1.encode('utf8') + expect = '"\\u2029\\u2028"' + expect_non_ascii = u'"\u2029\u2028"' + self.assertEqual(json.dumps(s1), expect) + self.assertEqual(json.dumps(s2), expect) + self.assertEqual(json.dumps(s1, ensure_ascii=False), expect_non_ascii) + self.assertEqual(json.dumps(s2, ensure_ascii=False), expect_non_ascii) + + def test_invalid_escape_sequences(self): + # incomplete escape sequence + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u12') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u123') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1234') + # invalid escape sequence + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u123x"') + self.assertRaises(json.JSONDecodeError, 
json.loads, '"\\u12x4"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1x34"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ux234"') + if sys.maxunicode > 65535: + # invalid escape sequence for low surrogate + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u0"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u00"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u000"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u000x"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u00x0"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u0x00"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\ux000"') + + def test_ensure_ascii_still_works(self): + # in the ascii range, ensure that everything is the same + for c in map(unichr, range(0, 127)): + self.assertEqual( + json.dumps(c, ensure_ascii=False), + json.dumps(c)) + snowman = u'\N{SNOWMAN}' + self.assertEqual( + json.dumps(c + snowman, ensure_ascii=False), + '"' + c + snowman + '"') + + def test_strip_bom(self): + content = u"\u3053\u3093\u306b\u3061\u308f" + json_doc = codecs.BOM_UTF8 + b(json.dumps(content)) + self.assertEqual(json.load(BytesIO(json_doc)), content) + for doc in json_doc, json_doc.decode('utf8'): + self.assertEqual(json.loads(doc), content) diff --git a/lib/simplejson/tool.py b/lib/simplejson/tool.py new file mode 100644 index 0000000..c91a01d --- /dev/null +++ b/lib/simplejson/tool.py @@ -0,0 +1,42 @@ +r"""Command-line tool to validate and pretty-print JSON + +Usage:: + + $ echo '{"json":"obj"}' | python -m simplejson.tool + { + "json": "obj" + } + $ echo '{ 1.2:3.4}' | python -m simplejson.tool + Expecting property name: line 1 column 2 (char 2) + +""" +from __future__ import with_statement +import sys +import simplejson as json + +def main(): + if len(sys.argv) == 1: + infile = sys.stdin + outfile = sys.stdout + elif len(sys.argv) == 2: + infile = open(sys.argv[1], 'r') + outfile = sys.stdout + elif len(sys.argv) == 3: + infile = open(sys.argv[1], 'r') + outfile = open(sys.argv[2], 'w') + else: + raise SystemExit(sys.argv[0] + " [infile [outfile]]") + with infile: + try: + obj = json.load(infile, + object_pairs_hook=json.OrderedDict, + use_decimal=True) + except ValueError: + raise SystemExit(sys.exc_info()[1]) + with outfile: + json.dump(obj, outfile, sort_keys=True, indent=' ', use_decimal=True) + outfile.write('\n') + + +if __name__ == '__main__': + main()
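The files above vendor simplejson's test suite and its tool module under lib/simplejson/. A minimal sketch of how to exercise the vendored copy in place, assuming the patch is applied at the repository root and that lib/ is not otherwise on sys.path (the path literals are taken from the paths in this diff, nothing else is assumed):

    import sys
    import unittest

    # Assumption: run from the repository root; the diff places the
    # vendored package at lib/simplejson, so make lib/ importable first.
    sys.path.insert(0, 'lib')

    import simplejson as json

    # Round-trip smoke test in the spirit of test_pass1/2/3.
    doc = {"nifty": 87, "field": "yes", "morefield": False}
    assert json.loads(json.dumps(doc)) == doc

    # Discover and run the vendored test modules.
    suite = unittest.defaultTestLoader.discover('lib/simplejson/tests')
    unittest.TextTestRunner(verbosity=2).run(suite)

Discovery via unittest keeps the vendored tests runnable without installing the package, since each test module imports simplejson absolutely and resolves it through the lib/ entry added to sys.path.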