From 1972d794d9d06e854e02f49bfcf1fc9a27dfe76e Mon Sep 17 00:00:00 2001 From: Isi <86603298+Isi-dev@users.noreply.github.com> Date: Thu, 1 Aug 2024 21:09:40 +0100 Subject: [PATCH 1/5] First commit --- .gitignore | 18 + README.md | 104 +++++ __init__.py | 3 + basicUniAnimateWorkflow.json | 427 +++++++++++++++++++++ environment.yaml | 236 ++++++++++++ modeldownloader.py | 2 + requirements.txt | 208 ++++++++++ run_align_pose.py | 709 +++++++++++++++++++++++++++++++++++ uniAnimate_Inference.py | 83 ++++ 9 files changed, 1790 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 __init__.py create mode 100644 basicUniAnimateWorkflow.json create mode 100644 environment.yaml create mode 100644 modeldownloader.py create mode 100644 requirements.txt create mode 100644 run_align_pose.py create mode 100644 uniAnimate_Inference.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..53200a9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +*.pkl +*.pt +*.mov +*.pth +*.mov +*.npz +*.npy +*.boj +*.onnx +*.tar +*.bin +cache* +.DS_Store +*DS_Store +outputs/ +**/__pycache__ +***/__pycache__ +*/__pycache__ diff --git a/README.md b/README.md new file mode 100644 index 0000000..6755761 --- /dev/null +++ b/README.md @@ -0,0 +1,104 @@ + + + +
+
+
+# This is my ComfyUI (Windows) implementation of the image animation project UniAnimate: Taming Unified Video Diffusion Models for Consistent Human Image Animation
+
+[🎨 Source Project Page](https://unianimate.github.io/)
+
+
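+Under the hood, the two nodes are ordinary Python classes defined in `uniAnimate_Inference.py`. The snippet below is only a rough sketch of how they chain together: the flat import and the standalone call are illustrative, since in practice the classes run inside ComfyUI as custom nodes and the IMAGE inputs are ComfyUI image tensors.
+
+```python
+# Illustrative sketch only -- not a supported standalone API.
+from uniAnimate_Inference import Gen_align_pose, UniAnimateImage  # hypothetical flat import
+
+def animate(reference_image, driving_video):
+    # Extract the driving video's poses and align them to the reference image.
+    ref_pose, pose_sequence = Gen_align_pose().process(reference_image, driving_video)
+    # Generate the animation frames (values mirror basicUniAnimateWorkflow.json).
+    frames, masks = UniAnimateImage().process(
+        steps=30, useFirstFrame=True, reference_image=reference_image,
+        ref_pose=ref_pose, pose_sequence=pose_sequence,
+        frame_interval=1, max_frames=32, resolution_x=512)
+    return frames
+```
+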
+
+
+## Getting Started
+
+This repository provides two ComfyUI nodes: "Align & Generate poses for UniAnimate" and "Animate image with UniAnimate".
+
+I tested the nodes with a ComfyUI_windows_portable installation on Windows 10 with 16GB RAM and a 12GB-VRAM NVIDIA graphics card.
+
+Download or clone this repository and place it in ComfyUI_windows_portable\ComfyUI\custom_nodes\
+
+You will need Python >= 3.9 in your ComfyUI environment.
+I tested the project with the following PyTorch versions, which you can install as follows:
+
+```
+conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
+```
+
+Or
+
+```
+conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=11.8 -c pytorch -c nvidia
+```
+
+Then install the following packages if they are not already present:
+pip install opencv-python
+pip install pytorch_lightning
+pip install lightning_utilities
+pip install lightning_fabric
+pip install torchmetrics
+pip install xformers==0.0.20 (for pytorch==2.0.1; alternatively, copy torch 2.0.1, its supporting libraries, and xFormers from an A1111 installation into your Environment\Lib\site-packages), or
+pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118 (for pytorch==2.3.1)
+pip install oss2
+pip install einops
+pip install args
+pip install modelscope
+
+After installing modelscope, download the required models (around 14GB):
+
+```
+python modeldownloader.py
+```
+
+After downloading all the models, move them manually from 'checkpoints/iic/unianimate/' to the 'checkpoints' directory,
+or move them via your command-line interface:
+
+```
+mv ./checkpoints/iic/unianimate/* ./checkpoints/
+```
+
+(In the Windows Command Prompt, use `move` instead of `mv`.)
+
+All the models should end up in the '\Path-to-UniAnimate\checkpoints' folder as follows:
+
+```
+./checkpoints/
+|---- dw-ll_ucoco_384.onnx
+|---- open_clip_pytorch_model.bin
+|---- unianimate_16f_32f_non_ema_223000.pth
+|---- v2-1_512-ema-pruned.ckpt
+└---- yolox_l.onnx
+```
+
+You can now load the workflow 'basicUniAnimateWorkflow.json' from your '\Path-to-UniAnimate\' folder, install any missing custom nodes with the ComfyUI Manager if necessary, upload a picture and a video (you can use those in the 'assets' folder), and run!
+
+
+**✔ Some tips**:
+
+- > In the 'Align & Generate poses for UniAnimate' node, the first frame of the target pose sequence is used to calculate the scale coefficient for the alignment. If that first frame contains the entire face and pose (hands and feet), the estimation is more accurate and the generated video is usually better.
+
+- > Generating 32 frames of video at a resolution of [512, 768] usually takes about 7 minutes.
+
+- > Running the "Animate image with UniAnimate" node uses about **~12GB** of GPU memory. If your GPU has less memory than this, reduce `max_frames` from 32 to a smaller value, e.g., 24, 16, or 8.
+
+You can also generate a video first, and then upload the last frame of that video as the picture to generate the next frames.
+
+
+@article{
+  project={ComfyUI-windows implementation of the image animation project UniAnimate: Taming Unified Video Diffusion Models for Consistent Human Image Animation},
+  developer={Isimemen Omoifo Jnr},
+  year={2024}
+}
+
+
+## Disclaimer
+
+I explicitly disclaim any responsibility for user-generated content. Users are solely liable for their actions while using these nodes and the generative model. I and the source project contributors have no legal affiliation with, nor accountability for, users' behaviors. 
It is imperative to use these nodes and the generative model responsibly, adhering to both ethical and legal standards. diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..c24e4d6 --- /dev/null +++ b/__init__.py @@ -0,0 +1,3 @@ +from .uniAnimate_Inference import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS + +__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] \ No newline at end of file diff --git a/basicUniAnimateWorkflow.json b/basicUniAnimateWorkflow.json new file mode 100644 index 0000000..7e19d87 --- /dev/null +++ b/basicUniAnimateWorkflow.json @@ -0,0 +1,427 @@ +{ + "last_node_id": 16, + "last_link_id": 18, + "nodes": [ + { + "id": 11, + "type": "VHS_LoadVideo", + "pos": [ + 580, + 465 + ], + "size": [ + 235.1999969482422, + 620.5687812408271 + ], + "flags": {}, + "order": 0, + "mode": 0, + "inputs": [ + { + "name": "meta_batch", + "type": "VHS_BatchManager", + "link": null + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 11 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "frame_count", + "type": "INT", + "links": null, + "shape": 3 + }, + { + "name": "audio", + "type": "VHS_AUDIO", + "links": null, + "shape": 3 + }, + { + "name": "video_info", + "type": "VHS_VIDEOINFO", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "VHS_LoadVideo" + }, + "widgets_values": { + "video": "0001-0032.mp4", + "force_rate": 0, + "force_size": "Disabled", + "custom_width": 512, + "custom_height": 512, + "frame_load_cap": 0, + "skip_first_frames": 0, + "select_every_nth": 1, + "choose video to upload": "image", + "videopreview": { + "hidden": false, + "paused": false, + "params": { + "frame_load_cap": 0, + "skip_first_frames": 0, + "force_rate": 0, + "filename": "0001-0032.mp4", + "type": "input", + "format": "video/mp4", + "select_every_nth": 1 + } + } + } + }, + { + "id": 13, + "type": "PreviewImage", + "pos": [ + 900, + 207 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 3, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 12 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 10, + "type": "LoadImage", + "pos": [ + 538, + 95 + ], + "size": [ + 315, + 314.00000762939453 + ], + "flags": {}, + "order": 1, + "mode": 0, + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 10, + 15 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "LoadImage" + }, + "widgets_values": [ + "untitled.png", + "image" + ] + }, + { + "id": 12, + "type": "Gen_align_pose", + "pos": [ + 880, + 493 + ], + "size": { + "0": 267, + "1": 46 + }, + "flags": {}, + "order": 2, + "mode": 0, + "inputs": [ + { + "name": "reference_image", + "type": "IMAGE", + "link": 10 + }, + { + "name": "video", + "type": "IMAGE", + "link": 11 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 12, + 16 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 14, + 17 + ], + "shape": 3, + "slot_index": 1 + } + ], + "properties": { + "Node name for S&R": "Gen_align_pose" + } + }, + { + "id": 14, + "type": "PreviewImage", + "pos": [ + 904, + 595 + ], + "size": [ + 210, + 246 + ], + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 14 + } + ], + "properties": { + "Node name for S&R": "PreviewImage" + } + }, + { + "id": 15, 
+ "type": "UniAnimateImage", + "pos": [ + 1178, + 470 + ], + "size": { + "0": 315, + "1": 194 + }, + "flags": {}, + "order": 5, + "mode": 0, + "inputs": [ + { + "name": "reference_image", + "type": "IMAGE", + "link": 15 + }, + { + "name": "ref_pose", + "type": "IMAGE", + "link": 16 + }, + { + "name": "pose_sequence", + "type": "IMAGE", + "link": 17 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 18 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "MASK", + "type": "MASK", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "UniAnimateImage" + }, + "widgets_values": [ + 30, + true, + 1, + 32, + 512 + ] + }, + { + "id": 16, + "type": "VHS_VideoCombine", + "pos": [ + 1522, + 225 + ], + "size": [ + 315, + 746.5 + ], + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 18 + }, + { + "name": "audio", + "type": "VHS_AUDIO", + "link": null + }, + { + "name": "meta_batch", + "type": "VHS_BatchManager", + "link": null + } + ], + "outputs": [ + { + "name": "Filenames", + "type": "VHS_FILENAMES", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "VHS_VideoCombine" + }, + "widgets_values": { + "frame_rate": 16, + "loop_count": 0, + "filename_prefix": "UniAnimate/vid", + "format": "video/h264-mp4", + "pix_fmt": "yuv420p", + "crf": 19, + "save_metadata": true, + "pingpong": false, + "save_output": true, + "videopreview": { + "hidden": false, + "paused": false, + "params": { + "filename": "vid_00001.mp4", + "subfolder": "UniAnimate", + "type": "output", + "format": "video/h264-mp4" + } + } + } + } + ], + "links": [ + [ + 10, + 10, + 0, + 12, + 0, + "IMAGE" + ], + [ + 11, + 11, + 0, + 12, + 1, + "IMAGE" + ], + [ + 12, + 12, + 0, + 13, + 0, + "IMAGE" + ], + [ + 14, + 12, + 1, + 14, + 0, + "IMAGE" + ], + [ + 15, + 10, + 0, + 15, + 0, + "IMAGE" + ], + [ + 16, + 12, + 0, + 15, + 1, + "IMAGE" + ], + [ + 17, + 12, + 1, + 15, + 2, + "IMAGE" + ], + [ + 18, + 15, + 0, + 16, + 0, + "IMAGE" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 1, + "offset": [ + -94.68478436382884, + 8.13263156125987 + ] + } + }, + "version": 0.4 +} \ No newline at end of file diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000..a60b160 --- /dev/null +++ b/environment.yaml @@ -0,0 +1,236 @@ +name: /mnt/user/miniconda3/envs/dtrans +channels: + - http://mirrors.aliyun.com/anaconda/pkgs/main + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - ca-certificates=2023.12.12=h06a4308_0 + - ld_impl_linux-64=2.38=h1181459_1 + - libffi=3.4.4=h6a678d5_0 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - ncurses=6.4=h6a678d5_0 + - openssl=3.0.12=h7f8727e_0 + - pip=23.3.1=py39h06a4308_0 + - python=3.9.18=h955ad1f_0 + - readline=8.2=h5eee18b_0 + - setuptools=68.2.2=py39h06a4308_0 + - sqlite=3.41.2=h5eee18b_0 + - tk=8.6.12=h1ccaba5_0 + - wheel=0.41.2=py39h06a4308_0 + - xz=5.4.5=h5eee18b_0 + - zlib=1.2.13=h5eee18b_0 + - pip: + - aiofiles==23.2.1 + - aiohttp==3.9.1 + - aiosignal==1.3.1 + - aliyun-python-sdk-core==2.14.0 + - aliyun-python-sdk-kms==2.16.2 + - altair==5.2.0 + - annotated-types==0.6.0 + - antlr4-python3-runtime==4.9.3 + - anyio==4.2.0 + - argparse==1.4.0 + - asttokens==2.4.1 + - async-timeout==4.0.3 + - attrs==23.2.0 + - automat==22.10.0 + - beartype==0.16.4 + - blessed==1.20.0 + - buildtools==1.0.6 + - causal-conv1d==1.1.3.post1 + - certifi==2023.11.17 
+ - cffi==1.16.0 + - chardet==5.2.0 + - charset-normalizer==3.3.2 + - clean-fid==0.1.35 + - click==8.1.7 + - clip==1.0 + - cmake==3.28.1 + - colorama==0.4.6 + - coloredlogs==15.0.1 + - constantly==23.10.4 + - contourpy==1.2.0 + - crcmod==1.7 + - cryptography==41.0.7 + - cycler==0.12.1 + - decorator==5.1.1 + - decord==0.6.0 + - diffusers==0.26.3 + - docopt==0.6.2 + - easydict==1.11 + - einops==0.7.0 + - exceptiongroup==1.2.0 + - executing==2.0.1 + - fairscale==0.4.13 + - fastapi==0.109.0 + - ffmpeg==1.4 + - ffmpy==0.3.1 + - filelock==3.13.1 + - flatbuffers==24.3.25 + - fonttools==4.47.2 + - frozenlist==1.4.1 + - fsspec==2023.12.2 + - ftfy==6.1.3 + - furl==2.1.3 + - gpustat==1.1.1 + - gradio==4.14.0 + - gradio-client==0.8.0 + - greenlet==3.0.3 + - h11==0.14.0 + - httpcore==1.0.2 + - httpx==0.26.0 + - huggingface-hub==0.20.2 + - humanfriendly==10.0 + - hyperlink==21.0.0 + - idna==3.6 + - imageio==2.33.1 + - imageio-ffmpeg==0.4.9 + - importlib-metadata==7.0.1 + - importlib-resources==6.1.1 + - incremental==22.10.0 + - ipdb==0.13.13 + - ipython==8.18.1 + - jedi==0.19.1 + - jinja2==3.1.3 + - jmespath==0.10.0 + - joblib==1.3.2 + - jsonschema==4.21.0 + - jsonschema-specifications==2023.12.1 + - kiwisolver==1.4.5 + - kornia==0.7.1 + - lazy-loader==0.3 + - lightning-utilities==0.10.0 + - lit==17.0.6 + - lpips==0.1.4 + - mamba-ssm==1.1.4 + - markdown-it-py==3.0.0 + - markupsafe==2.1.3 + - matplotlib==3.8.2 + - matplotlib-inline==0.1.6 + - mdurl==0.1.2 + - motion-vector-extractor==1.0.6 + - mpmath==1.3.0 + - multidict==6.0.4 + - mypy-extensions==1.0.0 + - networkx==3.2.1 + - ninja==1.11.1.1 + - numpy==1.26.3 + - nvidia-cublas-cu11==11.10.3.66 + - nvidia-cublas-cu12==12.1.3.1 + - nvidia-cuda-cupti-cu11==11.7.101 + - nvidia-cuda-cupti-cu12==12.1.105 + - nvidia-cuda-nvrtc-cu11==11.7.99 + - nvidia-cuda-nvrtc-cu12==12.1.105 + - nvidia-cuda-runtime-cu11==11.7.99 + - nvidia-cuda-runtime-cu12==12.1.105 + - nvidia-cudnn-cu11==8.5.0.96 + - nvidia-cudnn-cu12==8.9.2.26 + - nvidia-cufft-cu11==10.9.0.58 + - nvidia-cufft-cu12==11.0.2.54 + - nvidia-curand-cu11==10.2.10.91 + - nvidia-curand-cu12==10.3.2.106 + - nvidia-cusolver-cu11==11.4.0.1 + - nvidia-cusolver-cu12==11.4.5.107 + - nvidia-cusparse-cu11==11.7.4.91 + - nvidia-cusparse-cu12==12.1.0.106 + - nvidia-ml-py==12.535.133 + - nvidia-nccl-cu11==2.14.3 + - nvidia-nccl-cu12==2.19.3 + - nvidia-nvjitlink-cu12==12.3.101 + - nvidia-nvtx-cu11==11.7.91 + - nvidia-nvtx-cu12==12.1.105 + - omegaconf==2.3.0 + - onnxruntime==1.18.0 + - open-clip-torch==2.24.0 + - opencv-python==4.5.3.56 + - opencv-python-headless==4.9.0.80 + - orderedmultidict==1.0.1 + - orjson==3.9.10 + - oss2==2.18.4 + - packaging==23.2 + - pandas==2.1.4 + - parso==0.8.3 + - pexpect==4.9.0 + - pillow==10.2.0 + - piq==0.8.0 + - pkgconfig==1.5.5 + - prompt-toolkit==3.0.43 + - protobuf==4.25.2 + - psutil==5.9.8 + - ptflops==0.7.2.2 + - ptyprocess==0.7.0 + - pure-eval==0.2.2 + - pycparser==2.21 + - pycryptodome==3.20.0 + - pydantic==2.5.3 + - pydantic-core==2.14.6 + - pydub==0.25.1 + - pygments==2.17.2 + - pynvml==11.5.0 + - pyparsing==3.1.1 + - pyre-extensions==0.0.29 + - python-dateutil==2.8.2 + - python-multipart==0.0.6 + - pytorch-lightning==2.1.3 + - pytz==2023.3.post1 + - pyyaml==6.0.1 + - redo==2.0.4 + - referencing==0.32.1 + - regex==2023.12.25 + - requests==2.31.0 + - rich==13.7.0 + - rotary-embedding-torch==0.5.3 + - rpds-py==0.17.1 + - ruff==0.2.0 + - safetensors==0.4.1 + - scikit-image==0.22.0 + - scikit-learn==1.4.0 + - scipy==1.11.4 + - semantic-version==2.10.0 + - sentencepiece==0.1.99 + - 
shellingham==1.5.4 + - simplejson==3.19.2 + - six==1.16.0 + - sk-video==1.1.10 + - sniffio==1.3.0 + - sqlalchemy==2.0.27 + - stack-data==0.6.3 + - starlette==0.35.1 + - sympy==1.12 + - thop==0.1.1-2209072238 + - threadpoolctl==3.2.0 + - tifffile==2023.12.9 + - timm==0.9.12 + - tokenizers==0.15.0 + - tomli==2.0.1 + - tomlkit==0.12.0 + - toolz==0.12.0 + - torch==2.0.1+cu118 + - torchaudio==2.0.2+cu118 + - torchdiffeq==0.2.3 + - torchmetrics==1.3.0.post0 + - torchsde==0.2.6 + - torchvision==0.15.2+cu118 + - tqdm==4.66.1 + - traitlets==5.14.1 + - trampoline==0.1.2 + - transformers==4.36.2 + - triton==2.0.0 + - twisted==23.10.0 + - typer==0.9.0 + - typing-extensions==4.9.0 + - typing-inspect==0.9.0 + - tzdata==2023.4 + - urllib3==2.1.0 + - uvicorn==0.26.0 + - wcwidth==0.2.13 + - websockets==11.0.3 + - xformers==0.0.20 + - yarl==1.9.4 + - zipp==3.17.0 + - zope-interface==6.2 + - onnxruntime-gpu==1.13.1 +prefix: /mnt/user/miniconda3/envs/dtrans diff --git a/modeldownloader.py b/modeldownloader.py new file mode 100644 index 0000000..3916220 --- /dev/null +++ b/modeldownloader.py @@ -0,0 +1,2 @@ +from modelscope.hub.snapshot_download import snapshot_download +model_dir = snapshot_download('iic/unianimate', cache_dir='checkpoints/') diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..57b8664 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,208 @@ +aiofiles==23.2.1 +aiohttp==3.9.1 +aiosignal==1.3.1 +aliyun-python-sdk-core==2.14.0 +aliyun-python-sdk-kms==2.16.2 +altair==5.2.0 +annotated-types==0.6.0 +antlr4-python3-runtime==4.9.3 +anyio==4.2.0 +asttokens==2.4.1 +async-timeout==4.0.3 +attrs==23.2.0 +Automat==22.10.0 +beartype==0.16.4 +blessed==1.20.0 +buildtools==1.0.6 +# causal-conv1d==1.1.3.post1 +certifi==2023.11.17 +cffi==1.16.0 +chardet==5.2.0 +charset-normalizer==3.3.2 +clean-fid==0.1.35 +click==8.1.7 +# clip==1.0 +cmake==3.28.1 +colorama==0.4.6 +coloredlogs==15.0.1 +constantly==23.10.4 +contourpy==1.2.0 +crcmod==1.7 +cryptography==41.0.7 +cycler==0.12.1 +decorator==5.1.1 +decord==0.6.0 +diffusers==0.26.3 +docopt==0.6.2 +easydict==1.11 +einops==0.7.0 +exceptiongroup==1.2.0 +executing==2.0.1 +fairscale==0.4.13 +fastapi==0.109.0 +ffmpeg==1.4 +ffmpy==0.3.1 +filelock==3.13.1 +flatbuffers==24.3.25 +fonttools==4.47.2 +frozenlist==1.4.1 +fsspec==2023.12.2 +ftfy==6.1.3 +furl==2.1.3 +gpustat==1.1.1 +gradio==4.14.0 +gradio_client==0.8.0 +greenlet==3.0.3 +h11==0.14.0 +httpcore==1.0.2 +httpx==0.26.0 +huggingface-hub==0.20.2 +humanfriendly==10.0 +hyperlink==21.0.0 +idna==3.6 +imageio==2.33.1 +imageio-ffmpeg==0.4.9 +importlib-metadata==7.0.1 +importlib-resources==6.1.1 +incremental==22.10.0 +ipdb==0.13.13 +ipython==8.18.1 +jedi==0.19.1 +Jinja2==3.1.3 +jmespath==0.10.0 +joblib==1.3.2 +jsonschema==4.21.0 +jsonschema-specifications==2023.12.1 +kiwisolver==1.4.5 +kornia==0.7.1 +lazy_loader==0.3 +lightning-utilities==0.10.0 +lit==17.0.6 +lpips==0.1.4 +markdown-it-py==3.0.0 +MarkupSafe==2.1.3 +matplotlib==3.8.2 +matplotlib-inline==0.1.6 +mdurl==0.1.2 +# motion-vector-extractor==1.0.6 +mpmath==1.3.0 +multidict==6.0.4 +mypy-extensions==1.0.0 +networkx==3.2.1 +ninja==1.11.1.1 +numpy==1.26.3 +nvidia-cublas-cu11==11.10.3.66 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu11==11.7.101 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu11==11.7.99 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu11==11.7.99 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu11==8.5.0.96 +nvidia-cudnn-cu12==8.9.2.26 +nvidia-cufft-cu11==10.9.0.58 +nvidia-cufft-cu12==11.0.2.54 
+nvidia-curand-cu11==10.2.10.91 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu11==11.4.0.1 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu11==11.7.4.91 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-ml-py==12.535.133 +nvidia-nccl-cu11==2.14.3 +nvidia-nccl-cu12==2.19.3 +nvidia-nvjitlink-cu12==12.3.101 +nvidia-nvtx-cu11==11.7.91 +nvidia-nvtx-cu12==12.1.105 +omegaconf==2.3.0 +onnxruntime==1.18.0 +open-clip-torch==2.24.0 +opencv-python==4.5.3.56 +opencv-python-headless==4.9.0.80 +orderedmultidict==1.0.1 +orjson==3.9.10 +oss2==2.18.4 +# packaging==23.2 +pandas==2.1.4 +parso==0.8.3 +pexpect==4.9.0 +pillow==10.2.0 +piq==0.8.0 +pkgconfig==1.5.5 +prompt-toolkit==3.0.43 +protobuf==4.25.2 +psutil==5.9.8 +ptflops==0.7.2.2 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pycparser==2.21 +pycryptodome==3.20.0 +pydantic==2.5.3 +pydantic_core==2.14.6 +pydub==0.25.1 +Pygments==2.17.2 +pynvml==11.5.0 +pyparsing==3.1.1 +pyre-extensions==0.0.29 +python-dateutil==2.8.2 +python-multipart==0.0.6 +pytorch-lightning==2.1.3 +pytz==2023.3.post1 +PyYAML==6.0.1 +redo==2.0.4 +referencing==0.32.1 +regex==2023.12.25 +requests==2.31.0 +rich==13.7.0 +rotary-embedding-torch==0.5.3 +rpds-py==0.17.1 +ruff==0.2.0 +safetensors==0.4.1 +scikit-image==0.22.0 +scikit-learn==1.4.0 +scipy==1.11.4 +semantic-version==2.10.0 +sentencepiece==0.1.99 +shellingham==1.5.4 +simplejson==3.19.2 +six==1.16.0 +sk-video==1.1.10 +sniffio==1.3.0 +SQLAlchemy==2.0.27 +stack-data==0.6.3 +starlette==0.35.1 +sympy==1.12 +thop==0.1.1.post2209072238 +threadpoolctl==3.2.0 +tifffile==2023.12.9 +timm==0.9.12 +tokenizers==0.15.0 +tomli==2.0.1 +tomlkit==0.12.0 +toolz==0.12.0 +# torch==2.0.1+cu118 +# torchaudio==2.0.2+cu118 +torchdiffeq==0.2.3 +torchmetrics==1.3.0.post0 +torchsde==0.2.6 +# torchvision==0.15.2+cu118 +tqdm==4.66.1 +traitlets==5.14.1 +trampoline==0.1.2 +transformers==4.36.2 +triton==2.0.0 +Twisted==23.10.0 +typer==0.9.0 +typing-inspect==0.9.0 +typing_extensions==4.9.0 +tzdata==2023.4 +urllib3==2.1.0 +uvicorn==0.26.0 +wcwidth==0.2.13 +websockets==11.0.3 +xformers==0.0.20 +yarl==1.9.4 +zipp==3.17.0 +zope.interface==6.2 +onnxruntime-gpu==1.13.1 diff --git a/run_align_pose.py b/run_align_pose.py new file mode 100644 index 0000000..b0d0c9a --- /dev/null +++ b/run_align_pose.py @@ -0,0 +1,709 @@ +# Openpose +# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose +# 2nd Edited by https://github.com/Hzzone/pytorch-openpose +# 3rd Edited by ControlNet +# 4th Edited by ControlNet (added face and correct hands) + +import os +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" +import cv2 +import torch +import numpy as np +import json +import copy +import torch +import random +import argparse +import shutil +import tempfile +import subprocess +import numpy as np +import math + +import torch.multiprocessing as mp +import torch.distributed as dist +# import pickle +import logging +# from io import BytesIO +# import oss2 as oss +# import os.path as osp + +import sys +from .dwpose import util +from .dwpose.wholebody import Wholebody + + +def smoothing_factor(t_e, cutoff): + r = 2 * math.pi * cutoff * t_e + return r / (r + 1) + + +def exponential_smoothing(a, x, x_prev): + return a * x + (1 - a) * x_prev + + +class OneEuroFilter: + def __init__(self, t0, x0, dx0=0.0, min_cutoff=1.0, beta=0.0, + d_cutoff=1.0): + """Initialize the one euro filter.""" + # The parameters. + self.min_cutoff = float(min_cutoff) + self.beta = float(beta) + self.d_cutoff = float(d_cutoff) + # Previous values. 
+ self.x_prev = x0 + self.dx_prev = float(dx0) + self.t_prev = float(t0) + + def __call__(self, t, x): + """Compute the filtered signal.""" + t_e = t - self.t_prev + + # The filtered derivative of the signal. + a_d = smoothing_factor(t_e, self.d_cutoff) + dx = (x - self.x_prev) / t_e + dx_hat = exponential_smoothing(a_d, dx, self.dx_prev) + + # The filtered signal. + cutoff = self.min_cutoff + self.beta * abs(dx_hat) + a = smoothing_factor(t_e, cutoff) + x_hat = exponential_smoothing(a, x, self.x_prev) + + # Memorize the previous values. + self.x_prev = x_hat + self.dx_prev = dx_hat + self.t_prev = t + + return x_hat + + +def get_logger(name="essmc2"): + logger = logging.getLogger(name) + logger.propagate = False + if len(logger.handlers) == 0: + std_handler = logging.StreamHandler(sys.stdout) + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s') + std_handler.setFormatter(formatter) + std_handler.setLevel(logging.INFO) + logger.setLevel(logging.INFO) + logger.addHandler(std_handler) + return logger + +class DWposeDetector: + def __init__(self): + + self.pose_estimation = Wholebody() + + def __call__(self, oriImg): + oriImg = oriImg.copy() + # print(f'The shape of the image should be in HWC format but it is currently: {oriImg.shape} ') + H, W, C = oriImg.shape + with torch.no_grad(): + candidate, subset = self.pose_estimation(oriImg) + candidate = candidate[0][np.newaxis, :, :] + subset = subset[0][np.newaxis, :] + nums, keys, locs = candidate.shape + candidate[..., 0] /= float(W) + candidate[..., 1] /= float(H) + body = candidate[:,:18].copy() + body = body.reshape(nums*18, locs) + score = subset[:,:18].copy() + + for i in range(len(score)): + for j in range(len(score[i])): + if score[i][j] > 0.3: + score[i][j] = int(18*i+j) + else: + score[i][j] = -1 + + un_visible = subset<0.3 + candidate[un_visible] = -1 + + bodyfoot_score = subset[:,:24].copy() + for i in range(len(bodyfoot_score)): + for j in range(len(bodyfoot_score[i])): + if bodyfoot_score[i][j] > 0.3: + bodyfoot_score[i][j] = int(18*i+j) + else: + bodyfoot_score[i][j] = -1 + if -1 not in bodyfoot_score[:,18] and -1 not in bodyfoot_score[:,19]: + bodyfoot_score[:,18] = np.array([18.]) + else: + bodyfoot_score[:,18] = np.array([-1.]) + if -1 not in bodyfoot_score[:,21] and -1 not in bodyfoot_score[:,22]: + bodyfoot_score[:,19] = np.array([19.]) + else: + bodyfoot_score[:,19] = np.array([-1.]) + bodyfoot_score = bodyfoot_score[:, :20] + + bodyfoot = candidate[:,:24].copy() + + for i in range(nums): + if -1 not in bodyfoot[i][18] and -1 not in bodyfoot[i][19]: + bodyfoot[i][18] = (bodyfoot[i][18]+bodyfoot[i][19])/2 + else: + bodyfoot[i][18] = np.array([-1., -1.]) + if -1 not in bodyfoot[i][21] and -1 not in bodyfoot[i][22]: + bodyfoot[i][19] = (bodyfoot[i][21]+bodyfoot[i][22])/2 + else: + bodyfoot[i][19] = np.array([-1., -1.]) + + bodyfoot = bodyfoot[:,:20,:] + bodyfoot = bodyfoot.reshape(nums*20, locs) + + foot = candidate[:,18:24] + + faces = candidate[:,24:92] + + hands = candidate[:,92:113] + hands = np.vstack([hands, candidate[:,113:]]) + + # bodies = dict(candidate=body, subset=score) + bodies = dict(candidate=bodyfoot, subset=bodyfoot_score) + pose = dict(bodies=bodies, hands=hands, faces=faces) + + # return draw_pose(pose, H, W) + return pose + +def draw_pose(pose, H, W): + bodies = pose['bodies'] + faces = pose['faces'] + hands = pose['hands'] + candidate = bodies['candidate'] + subset = bodies['subset'] + canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) + + canvas = 
util.draw_body_and_foot(canvas, candidate, subset) + + canvas = util.draw_handpose(canvas, hands) + + canvas_without_face = copy.deepcopy(canvas) + + canvas = util.draw_facepose(canvas, faces) + + return canvas_without_face, canvas + +def dw_func(_id, frame, dwpose_model, dwpose_woface_folder='tmp_dwpose_wo_face', dwpose_withface_folder='tmp_dwpose_with_face'): + + # frame = cv2.imread(frame_name, cv2.IMREAD_COLOR) + pose = dwpose_model(frame) + + return pose + + +def mp_main(reference_image, video): + + logger.info(f"There are {video.size(0)} frames for extracting poses") + + logger.info('LOAD: DW Pose Model') + dwpose_model = DWposeDetector() + + results_vis = [] + num_frames = video.size(0) + + + for i in range(num_frames): + logger.info(f"Processing frame {i + 1}/{num_frames}") + frame = video[i].permute(0, 1, 2).cpu().numpy()*255 # Convert to HWC format and numpy array + frame = np.flip(frame, axis=2) + pose = dw_func(i, frame, dwpose_model) + results_vis.append(pose) + + logger.info(f'All frames have been processed.') + print(len(results_vis)) + + # Process the reference image + ref_frame = reference_image.squeeze(0).cpu().numpy()*255 # Convert to HWC format and numpy array + ref_frame = np.flip(ref_frame, axis=2) + pose_ref = dw_func(-1, ref_frame, dwpose_model) + # print(f'The content of the image is currently: {pose_ref} ') + + bodies = results_vis[0]['bodies'] + faces = results_vis[0]['faces'] + hands = results_vis[0]['hands'] + candidate = bodies['candidate'] + + ref_bodies = pose_ref['bodies'] + ref_faces = pose_ref['faces'] + ref_hands = pose_ref['hands'] + ref_candidate = ref_bodies['candidate'] + + + ref_2_x = ref_candidate[2][0] + ref_2_y = ref_candidate[2][1] + ref_5_x = ref_candidate[5][0] + ref_5_y = ref_candidate[5][1] + ref_8_x = ref_candidate[8][0] + ref_8_y = ref_candidate[8][1] + ref_11_x = ref_candidate[11][0] + ref_11_y = ref_candidate[11][1] + ref_center1 = 0.5*(ref_candidate[2]+ref_candidate[5]) + ref_center2 = 0.5*(ref_candidate[8]+ref_candidate[11]) + + zero_2_x = candidate[2][0] + zero_2_y = candidate[2][1] + zero_5_x = candidate[5][0] + zero_5_y = candidate[5][1] + zero_8_x = candidate[8][0] + zero_8_y = candidate[8][1] + zero_11_x = candidate[11][0] + zero_11_y = candidate[11][1] + zero_center1 = 0.5*(candidate[2]+candidate[5]) + zero_center2 = 0.5*(candidate[8]+candidate[11]) + + x_ratio, y_ratio = 1, 1 + if (zero_5_x-zero_2_x) > 0 : + x_ratio = (ref_5_x-ref_2_x)/(zero_5_x-zero_2_x) + if (zero_center2[1]-zero_center1[1]) > 0 : + y_ratio = (ref_center2[1]-ref_center1[1])/(zero_center2[1]-zero_center1[1]) + + results_vis[0]['bodies']['candidate'][:,0] *= x_ratio + results_vis[0]['bodies']['candidate'][:,1] *= y_ratio + results_vis[0]['faces'][:,:,0] *= x_ratio + results_vis[0]['faces'][:,:,1] *= y_ratio + results_vis[0]['hands'][:,:,0] *= x_ratio + results_vis[0]['hands'][:,:,1] *= y_ratio + + ########neck######## + neck_ratio = 1 + l_neck_ref = ((ref_candidate[0][0] - ref_candidate[1][0]) ** 2 + (ref_candidate[0][1] - ref_candidate[1][1]) ** 2) ** 0.5 + l_neck_0 = ((candidate[0][0] - candidate[1][0]) ** 2 + (candidate[0][1] - candidate[1][1]) ** 2) ** 0.5 + if l_neck_0 != 0: + neck_ratio = l_neck_ref / l_neck_0 + + x_offset_neck = (candidate[1][0]-candidate[0][0])*(1.-neck_ratio) + y_offset_neck = (candidate[1][1]-candidate[0][1])*(1.-neck_ratio) + + results_vis[0]['bodies']['candidate'][0,0] += x_offset_neck + results_vis[0]['bodies']['candidate'][0,1] += y_offset_neck + results_vis[0]['bodies']['candidate'][14,0] += x_offset_neck + 
results_vis[0]['bodies']['candidate'][14,1] += y_offset_neck + results_vis[0]['bodies']['candidate'][15,0] += x_offset_neck + results_vis[0]['bodies']['candidate'][15,1] += y_offset_neck + results_vis[0]['bodies']['candidate'][16,0] += x_offset_neck + results_vis[0]['bodies']['candidate'][16,1] += y_offset_neck + results_vis[0]['bodies']['candidate'][17,0] += x_offset_neck + results_vis[0]['bodies']['candidate'][17,1] += y_offset_neck + + ########shoulder2######## + shoulder2_ratio = 1 + l_shoulder2_ref = ((ref_candidate[2][0] - ref_candidate[1][0]) ** 2 + (ref_candidate[2][1] - ref_candidate[1][1]) ** 2) ** 0.5 + l_shoulder2_0 = ((candidate[2][0] - candidate[1][0]) ** 2 + (candidate[2][1] - candidate[1][1]) ** 2) ** 0.5 + if l_shoulder2_0 != 0: + shoulder2_ratio = l_shoulder2_ref / l_shoulder2_0 + + x_offset_shoulder2 = (candidate[1][0]-candidate[2][0])*(1.-shoulder2_ratio) + y_offset_shoulder2 = (candidate[1][1]-candidate[2][1])*(1.-shoulder2_ratio) + + results_vis[0]['bodies']['candidate'][2,0] += x_offset_shoulder2 + results_vis[0]['bodies']['candidate'][2,1] += y_offset_shoulder2 + results_vis[0]['bodies']['candidate'][3,0] += x_offset_shoulder2 + results_vis[0]['bodies']['candidate'][3,1] += y_offset_shoulder2 + results_vis[0]['bodies']['candidate'][4,0] += x_offset_shoulder2 + results_vis[0]['bodies']['candidate'][4,1] += y_offset_shoulder2 + results_vis[0]['hands'][1,:,0] += x_offset_shoulder2 + results_vis[0]['hands'][1,:,1] += y_offset_shoulder2 + + ########shoulder5######## + shoulder5_ratio = 1 + l_shoulder5_ref = ((ref_candidate[5][0] - ref_candidate[1][0]) ** 2 + (ref_candidate[5][1] - ref_candidate[1][1]) ** 2) ** 0.5 + l_shoulder5_0 = ((candidate[5][0] - candidate[1][0]) ** 2 + (candidate[5][1] - candidate[1][1]) ** 2) ** 0.5 + if l_shoulder5_0 != 0: + shoulder5_ratio = l_shoulder5_ref / l_shoulder5_0 + + x_offset_shoulder5 = (candidate[1][0]-candidate[5][0])*(1.-shoulder5_ratio) + y_offset_shoulder5 = (candidate[1][1]-candidate[5][1])*(1.-shoulder5_ratio) + + results_vis[0]['bodies']['candidate'][5,0] += x_offset_shoulder5 + results_vis[0]['bodies']['candidate'][5,1] += y_offset_shoulder5 + results_vis[0]['bodies']['candidate'][6,0] += x_offset_shoulder5 + results_vis[0]['bodies']['candidate'][6,1] += y_offset_shoulder5 + results_vis[0]['bodies']['candidate'][7,0] += x_offset_shoulder5 + results_vis[0]['bodies']['candidate'][7,1] += y_offset_shoulder5 + results_vis[0]['hands'][0,:,0] += x_offset_shoulder5 + results_vis[0]['hands'][0,:,1] += y_offset_shoulder5 + + ########arm3######## + arm3_ratio = 1 + l_arm3_ref = ((ref_candidate[3][0] - ref_candidate[2][0]) ** 2 + (ref_candidate[3][1] - ref_candidate[2][1]) ** 2) ** 0.5 + l_arm3_0 = ((candidate[3][0] - candidate[2][0]) ** 2 + (candidate[3][1] - candidate[2][1]) ** 2) ** 0.5 + if l_arm3_0 != 0: + arm3_ratio = l_arm3_ref / l_arm3_0 + + x_offset_arm3 = (candidate[2][0]-candidate[3][0])*(1.-arm3_ratio) + y_offset_arm3 = (candidate[2][1]-candidate[3][1])*(1.-arm3_ratio) + + results_vis[0]['bodies']['candidate'][3,0] += x_offset_arm3 + results_vis[0]['bodies']['candidate'][3,1] += y_offset_arm3 + results_vis[0]['bodies']['candidate'][4,0] += x_offset_arm3 + results_vis[0]['bodies']['candidate'][4,1] += y_offset_arm3 + results_vis[0]['hands'][1,:,0] += x_offset_arm3 + results_vis[0]['hands'][1,:,1] += y_offset_arm3 + + ########arm4######## + arm4_ratio = 1 + l_arm4_ref = ((ref_candidate[4][0] - ref_candidate[3][0]) ** 2 + (ref_candidate[4][1] - ref_candidate[3][1]) ** 2) ** 0.5 + l_arm4_0 = ((candidate[4][0] - candidate[3][0]) 
** 2 + (candidate[4][1] - candidate[3][1]) ** 2) ** 0.5 + if l_arm4_0 != 0: + arm4_ratio = l_arm4_ref / l_arm4_0 + + x_offset_arm4 = (candidate[3][0]-candidate[4][0])*(1.-arm4_ratio) + y_offset_arm4 = (candidate[3][1]-candidate[4][1])*(1.-arm4_ratio) + + results_vis[0]['bodies']['candidate'][4,0] += x_offset_arm4 + results_vis[0]['bodies']['candidate'][4,1] += y_offset_arm4 + results_vis[0]['hands'][1,:,0] += x_offset_arm4 + results_vis[0]['hands'][1,:,1] += y_offset_arm4 + + ########arm6######## + arm6_ratio = 1 + l_arm6_ref = ((ref_candidate[6][0] - ref_candidate[5][0]) ** 2 + (ref_candidate[6][1] - ref_candidate[5][1]) ** 2) ** 0.5 + l_arm6_0 = ((candidate[6][0] - candidate[5][0]) ** 2 + (candidate[6][1] - candidate[5][1]) ** 2) ** 0.5 + if l_arm6_0 != 0: + arm6_ratio = l_arm6_ref / l_arm6_0 + + x_offset_arm6 = (candidate[5][0]-candidate[6][0])*(1.-arm6_ratio) + y_offset_arm6 = (candidate[5][1]-candidate[6][1])*(1.-arm6_ratio) + + results_vis[0]['bodies']['candidate'][6,0] += x_offset_arm6 + results_vis[0]['bodies']['candidate'][6,1] += y_offset_arm6 + results_vis[0]['bodies']['candidate'][7,0] += x_offset_arm6 + results_vis[0]['bodies']['candidate'][7,1] += y_offset_arm6 + results_vis[0]['hands'][0,:,0] += x_offset_arm6 + results_vis[0]['hands'][0,:,1] += y_offset_arm6 + + ########arm7######## + arm7_ratio = 1 + l_arm7_ref = ((ref_candidate[7][0] - ref_candidate[6][0]) ** 2 + (ref_candidate[7][1] - ref_candidate[6][1]) ** 2) ** 0.5 + l_arm7_0 = ((candidate[7][0] - candidate[6][0]) ** 2 + (candidate[7][1] - candidate[6][1]) ** 2) ** 0.5 + if l_arm7_0 != 0: + arm7_ratio = l_arm7_ref / l_arm7_0 + + x_offset_arm7 = (candidate[6][0]-candidate[7][0])*(1.-arm7_ratio) + y_offset_arm7 = (candidate[6][1]-candidate[7][1])*(1.-arm7_ratio) + + results_vis[0]['bodies']['candidate'][7,0] += x_offset_arm7 + results_vis[0]['bodies']['candidate'][7,1] += y_offset_arm7 + results_vis[0]['hands'][0,:,0] += x_offset_arm7 + results_vis[0]['hands'][0,:,1] += y_offset_arm7 + + ########head14######## + l_head14_ref = ((ref_candidate[14][0] - ref_candidate[0][0]) ** 2 + (ref_candidate[14][1] - ref_candidate[0][1]) ** 2) ** 0.5 + l_head14_0 = ((candidate[14][0] - candidate[0][0]) ** 2 + (candidate[14][1] - candidate[0][1]) ** 2) ** 0.5 + + head14_ratio = l_head14_ref / l_head14_0 + + x_offset_head14 = (candidate[0][0]-candidate[14][0])*(1.-head14_ratio) + y_offset_head14 = (candidate[0][1]-candidate[14][1])*(1.-head14_ratio) + + results_vis[0]['bodies']['candidate'][14,0] += x_offset_head14 + results_vis[0]['bodies']['candidate'][14,1] += y_offset_head14 + results_vis[0]['bodies']['candidate'][16,0] += x_offset_head14 + results_vis[0]['bodies']['candidate'][16,1] += y_offset_head14 + + ########head15######## + l_head15_ref = ((ref_candidate[15][0] - ref_candidate[0][0]) ** 2 + (ref_candidate[15][1] - ref_candidate[0][1]) ** 2) ** 0.5 + l_head15_0 = ((candidate[15][0] - candidate[0][0]) ** 2 + (candidate[15][1] - candidate[0][1]) ** 2) ** 0.5 + + head15_ratio = l_head15_ref / l_head15_0 + + x_offset_head15 = (candidate[0][0]-candidate[15][0])*(1.-head15_ratio) + y_offset_head15 = (candidate[0][1]-candidate[15][1])*(1.-head15_ratio) + + results_vis[0]['bodies']['candidate'][15,0] += x_offset_head15 + results_vis[0]['bodies']['candidate'][15,1] += y_offset_head15 + results_vis[0]['bodies']['candidate'][17,0] += x_offset_head15 + results_vis[0]['bodies']['candidate'][17,1] += y_offset_head15 + + ########head16######## + l_head16_ref = ((ref_candidate[16][0] - ref_candidate[14][0]) ** 2 + (ref_candidate[16][1] - 
ref_candidate[14][1]) ** 2) ** 0.5 + l_head16_0 = ((candidate[16][0] - candidate[14][0]) ** 2 + (candidate[16][1] - candidate[14][1]) ** 2) ** 0.5 + + head16_ratio = l_head16_ref / l_head16_0 + + x_offset_head16 = (candidate[14][0]-candidate[16][0])*(1.-head16_ratio) + y_offset_head16 = (candidate[14][1]-candidate[16][1])*(1.-head16_ratio) + + results_vis[0]['bodies']['candidate'][16,0] += x_offset_head16 + results_vis[0]['bodies']['candidate'][16,1] += y_offset_head16 + + ########head17######## + l_head17_ref = ((ref_candidate[17][0] - ref_candidate[15][0]) ** 2 + (ref_candidate[17][1] - ref_candidate[15][1]) ** 2) ** 0.5 + l_head17_0 = ((candidate[17][0] - candidate[15][0]) ** 2 + (candidate[17][1] - candidate[15][1]) ** 2) ** 0.5 + + head17_ratio = l_head17_ref / l_head17_0 + + x_offset_head17 = (candidate[15][0]-candidate[17][0])*(1.-head17_ratio) + y_offset_head17 = (candidate[15][1]-candidate[17][1])*(1.-head17_ratio) + + results_vis[0]['bodies']['candidate'][17,0] += x_offset_head17 + results_vis[0]['bodies']['candidate'][17,1] += y_offset_head17 + + ########MovingAverage######## + + ########left leg######## + ll1_ratio = 1 + l_ll1_ref = ((ref_candidate[8][0] - ref_candidate[9][0]) ** 2 + (ref_candidate[8][1] - ref_candidate[9][1]) ** 2) ** 0.5 + l_ll1_0 = ((candidate[8][0] - candidate[9][0]) ** 2 + (candidate[8][1] - candidate[9][1]) ** 2) ** 0.5 + if l_ll1_0 != 0 : + ll1_ratio = l_ll1_ref / l_ll1_0 + + x_offset_ll1 = (candidate[9][0]-candidate[8][0])*(ll1_ratio-1.) + y_offset_ll1 = (candidate[9][1]-candidate[8][1])*(ll1_ratio-1.) + + results_vis[0]['bodies']['candidate'][9,0] += x_offset_ll1 + results_vis[0]['bodies']['candidate'][9,1] += y_offset_ll1 + results_vis[0]['bodies']['candidate'][10,0] += x_offset_ll1 + results_vis[0]['bodies']['candidate'][10,1] += y_offset_ll1 + results_vis[0]['bodies']['candidate'][19,0] += x_offset_ll1 + results_vis[0]['bodies']['candidate'][19,1] += y_offset_ll1 + + l_ll2_ref = ((ref_candidate[9][0] - ref_candidate[10][0]) ** 2 + (ref_candidate[9][1] - ref_candidate[10][1]) ** 2) ** 0.5 + l_ll2_0 = ((candidate[9][0] - candidate[10][0]) ** 2 + (candidate[9][1] - candidate[10][1]) ** 2) ** 0.5 + ll2_ratio = l_ll2_ref / l_ll2_0 + + x_offset_ll2 = (candidate[10][0]-candidate[9][0])*(ll2_ratio-1.) + y_offset_ll2 = (candidate[10][1]-candidate[9][1])*(ll2_ratio-1.) + + results_vis[0]['bodies']['candidate'][10,0] += x_offset_ll2 + results_vis[0]['bodies']['candidate'][10,1] += y_offset_ll2 + results_vis[0]['bodies']['candidate'][19,0] += x_offset_ll2 + results_vis[0]['bodies']['candidate'][19,1] += y_offset_ll2 + + ########right leg######## + rl1_ratio = 1 + l_rl1_ref = ((ref_candidate[11][0] - ref_candidate[12][0]) ** 2 + (ref_candidate[11][1] - ref_candidate[12][1]) ** 2) ** 0.5 + l_rl1_0 = ((candidate[11][0] - candidate[12][0]) ** 2 + (candidate[11][1] - candidate[12][1]) ** 2) ** 0.5 + if l_rl1_0 != 0: + rl1_ratio = l_rl1_ref / l_rl1_0 + + x_offset_rl1 = (candidate[12][0]-candidate[11][0])*(rl1_ratio-1.) + y_offset_rl1 = (candidate[12][1]-candidate[11][1])*(rl1_ratio-1.) 
+ + results_vis[0]['bodies']['candidate'][12,0] += x_offset_rl1 + results_vis[0]['bodies']['candidate'][12,1] += y_offset_rl1 + results_vis[0]['bodies']['candidate'][13,0] += x_offset_rl1 + results_vis[0]['bodies']['candidate'][13,1] += y_offset_rl1 + results_vis[0]['bodies']['candidate'][18,0] += x_offset_rl1 + results_vis[0]['bodies']['candidate'][18,1] += y_offset_rl1 + + l_rl2_ref = ((ref_candidate[12][0] - ref_candidate[13][0]) ** 2 + (ref_candidate[12][1] - ref_candidate[13][1]) ** 2) ** 0.5 + l_rl2_0 = ((candidate[12][0] - candidate[13][0]) ** 2 + (candidate[12][1] - candidate[13][1]) ** 2) ** 0.5 + rl2_ratio = l_rl2_ref / l_rl2_0 + + x_offset_rl2 = (candidate[13][0]-candidate[12][0])*(rl2_ratio-1.) + y_offset_rl2 = (candidate[13][1]-candidate[12][1])*(rl2_ratio-1.) + + results_vis[0]['bodies']['candidate'][13,0] += x_offset_rl2 + results_vis[0]['bodies']['candidate'][13,1] += y_offset_rl2 + results_vis[0]['bodies']['candidate'][18,0] += x_offset_rl2 + results_vis[0]['bodies']['candidate'][18,1] += y_offset_rl2 + + offset = ref_candidate[1] - results_vis[0]['bodies']['candidate'][1] + + results_vis[0]['bodies']['candidate'] += offset[np.newaxis, :] + results_vis[0]['faces'] += offset[np.newaxis, np.newaxis, :] + results_vis[0]['hands'] += offset[np.newaxis, np.newaxis, :] + + for i in range(1, len(results_vis)): + results_vis[i]['bodies']['candidate'][:,0] *= x_ratio + results_vis[i]['bodies']['candidate'][:,1] *= y_ratio + results_vis[i]['faces'][:,:,0] *= x_ratio + results_vis[i]['faces'][:,:,1] *= y_ratio + results_vis[i]['hands'][:,:,0] *= x_ratio + results_vis[i]['hands'][:,:,1] *= y_ratio + + ########neck######## + x_offset_neck = (results_vis[i]['bodies']['candidate'][1][0]-results_vis[i]['bodies']['candidate'][0][0])*(1.-neck_ratio) + y_offset_neck = (results_vis[i]['bodies']['candidate'][1][1]-results_vis[i]['bodies']['candidate'][0][1])*(1.-neck_ratio) + + results_vis[i]['bodies']['candidate'][0,0] += x_offset_neck + results_vis[i]['bodies']['candidate'][0,1] += y_offset_neck + results_vis[i]['bodies']['candidate'][14,0] += x_offset_neck + results_vis[i]['bodies']['candidate'][14,1] += y_offset_neck + results_vis[i]['bodies']['candidate'][15,0] += x_offset_neck + results_vis[i]['bodies']['candidate'][15,1] += y_offset_neck + results_vis[i]['bodies']['candidate'][16,0] += x_offset_neck + results_vis[i]['bodies']['candidate'][16,1] += y_offset_neck + results_vis[i]['bodies']['candidate'][17,0] += x_offset_neck + results_vis[i]['bodies']['candidate'][17,1] += y_offset_neck + + ########shoulder2######## + + + x_offset_shoulder2 = (results_vis[i]['bodies']['candidate'][1][0]-results_vis[i]['bodies']['candidate'][2][0])*(1.-shoulder2_ratio) + y_offset_shoulder2 = (results_vis[i]['bodies']['candidate'][1][1]-results_vis[i]['bodies']['candidate'][2][1])*(1.-shoulder2_ratio) + + results_vis[i]['bodies']['candidate'][2,0] += x_offset_shoulder2 + results_vis[i]['bodies']['candidate'][2,1] += y_offset_shoulder2 + results_vis[i]['bodies']['candidate'][3,0] += x_offset_shoulder2 + results_vis[i]['bodies']['candidate'][3,1] += y_offset_shoulder2 + results_vis[i]['bodies']['candidate'][4,0] += x_offset_shoulder2 + results_vis[i]['bodies']['candidate'][4,1] += y_offset_shoulder2 + results_vis[i]['hands'][1,:,0] += x_offset_shoulder2 + results_vis[i]['hands'][1,:,1] += y_offset_shoulder2 + + ########shoulder5######## + + x_offset_shoulder5 = (results_vis[i]['bodies']['candidate'][1][0]-results_vis[i]['bodies']['candidate'][5][0])*(1.-shoulder5_ratio) + y_offset_shoulder5 = 
(results_vis[i]['bodies']['candidate'][1][1]-results_vis[i]['bodies']['candidate'][5][1])*(1.-shoulder5_ratio) + + results_vis[i]['bodies']['candidate'][5,0] += x_offset_shoulder5 + results_vis[i]['bodies']['candidate'][5,1] += y_offset_shoulder5 + results_vis[i]['bodies']['candidate'][6,0] += x_offset_shoulder5 + results_vis[i]['bodies']['candidate'][6,1] += y_offset_shoulder5 + results_vis[i]['bodies']['candidate'][7,0] += x_offset_shoulder5 + results_vis[i]['bodies']['candidate'][7,1] += y_offset_shoulder5 + results_vis[i]['hands'][0,:,0] += x_offset_shoulder5 + results_vis[i]['hands'][0,:,1] += y_offset_shoulder5 + + ########arm3######## + + x_offset_arm3 = (results_vis[i]['bodies']['candidate'][2][0]-results_vis[i]['bodies']['candidate'][3][0])*(1.-arm3_ratio) + y_offset_arm3 = (results_vis[i]['bodies']['candidate'][2][1]-results_vis[i]['bodies']['candidate'][3][1])*(1.-arm3_ratio) + + results_vis[i]['bodies']['candidate'][3,0] += x_offset_arm3 + results_vis[i]['bodies']['candidate'][3,1] += y_offset_arm3 + results_vis[i]['bodies']['candidate'][4,0] += x_offset_arm3 + results_vis[i]['bodies']['candidate'][4,1] += y_offset_arm3 + results_vis[i]['hands'][1,:,0] += x_offset_arm3 + results_vis[i]['hands'][1,:,1] += y_offset_arm3 + + ########arm4######## + + x_offset_arm4 = (results_vis[i]['bodies']['candidate'][3][0]-results_vis[i]['bodies']['candidate'][4][0])*(1.-arm4_ratio) + y_offset_arm4 = (results_vis[i]['bodies']['candidate'][3][1]-results_vis[i]['bodies']['candidate'][4][1])*(1.-arm4_ratio) + + results_vis[i]['bodies']['candidate'][4,0] += x_offset_arm4 + results_vis[i]['bodies']['candidate'][4,1] += y_offset_arm4 + results_vis[i]['hands'][1,:,0] += x_offset_arm4 + results_vis[i]['hands'][1,:,1] += y_offset_arm4 + + ########arm6######## + + x_offset_arm6 = (results_vis[i]['bodies']['candidate'][5][0]-results_vis[i]['bodies']['candidate'][6][0])*(1.-arm6_ratio) + y_offset_arm6 = (results_vis[i]['bodies']['candidate'][5][1]-results_vis[i]['bodies']['candidate'][6][1])*(1.-arm6_ratio) + + results_vis[i]['bodies']['candidate'][6,0] += x_offset_arm6 + results_vis[i]['bodies']['candidate'][6,1] += y_offset_arm6 + results_vis[i]['bodies']['candidate'][7,0] += x_offset_arm6 + results_vis[i]['bodies']['candidate'][7,1] += y_offset_arm6 + results_vis[i]['hands'][0,:,0] += x_offset_arm6 + results_vis[i]['hands'][0,:,1] += y_offset_arm6 + + ########arm7######## + + x_offset_arm7 = (results_vis[i]['bodies']['candidate'][6][0]-results_vis[i]['bodies']['candidate'][7][0])*(1.-arm7_ratio) + y_offset_arm7 = (results_vis[i]['bodies']['candidate'][6][1]-results_vis[i]['bodies']['candidate'][7][1])*(1.-arm7_ratio) + + results_vis[i]['bodies']['candidate'][7,0] += x_offset_arm7 + results_vis[i]['bodies']['candidate'][7,1] += y_offset_arm7 + results_vis[i]['hands'][0,:,0] += x_offset_arm7 + results_vis[i]['hands'][0,:,1] += y_offset_arm7 + + ########head14######## + + x_offset_head14 = (results_vis[i]['bodies']['candidate'][0][0]-results_vis[i]['bodies']['candidate'][14][0])*(1.-head14_ratio) + y_offset_head14 = (results_vis[i]['bodies']['candidate'][0][1]-results_vis[i]['bodies']['candidate'][14][1])*(1.-head14_ratio) + + results_vis[i]['bodies']['candidate'][14,0] += x_offset_head14 + results_vis[i]['bodies']['candidate'][14,1] += y_offset_head14 + results_vis[i]['bodies']['candidate'][16,0] += x_offset_head14 + results_vis[i]['bodies']['candidate'][16,1] += y_offset_head14 + + ########head15######## + + x_offset_head15 = 
(results_vis[i]['bodies']['candidate'][0][0]-results_vis[i]['bodies']['candidate'][15][0])*(1.-head15_ratio) + y_offset_head15 = (results_vis[i]['bodies']['candidate'][0][1]-results_vis[i]['bodies']['candidate'][15][1])*(1.-head15_ratio) + + results_vis[i]['bodies']['candidate'][15,0] += x_offset_head15 + results_vis[i]['bodies']['candidate'][15,1] += y_offset_head15 + results_vis[i]['bodies']['candidate'][17,0] += x_offset_head15 + results_vis[i]['bodies']['candidate'][17,1] += y_offset_head15 + + ########head16######## + + x_offset_head16 = (results_vis[i]['bodies']['candidate'][14][0]-results_vis[i]['bodies']['candidate'][16][0])*(1.-head16_ratio) + y_offset_head16 = (results_vis[i]['bodies']['candidate'][14][1]-results_vis[i]['bodies']['candidate'][16][1])*(1.-head16_ratio) + + results_vis[i]['bodies']['candidate'][16,0] += x_offset_head16 + results_vis[i]['bodies']['candidate'][16,1] += y_offset_head16 + + ########head17######## + x_offset_head17 = (results_vis[i]['bodies']['candidate'][15][0]-results_vis[i]['bodies']['candidate'][17][0])*(1.-head17_ratio) + y_offset_head17 = (results_vis[i]['bodies']['candidate'][15][1]-results_vis[i]['bodies']['candidate'][17][1])*(1.-head17_ratio) + + results_vis[i]['bodies']['candidate'][17,0] += x_offset_head17 + results_vis[i]['bodies']['candidate'][17,1] += y_offset_head17 + + # ########MovingAverage######## + + ########left leg######## + x_offset_ll1 = (results_vis[i]['bodies']['candidate'][9][0]-results_vis[i]['bodies']['candidate'][8][0])*(ll1_ratio-1.) + y_offset_ll1 = (results_vis[i]['bodies']['candidate'][9][1]-results_vis[i]['bodies']['candidate'][8][1])*(ll1_ratio-1.) + + results_vis[i]['bodies']['candidate'][9,0] += x_offset_ll1 + results_vis[i]['bodies']['candidate'][9,1] += y_offset_ll1 + results_vis[i]['bodies']['candidate'][10,0] += x_offset_ll1 + results_vis[i]['bodies']['candidate'][10,1] += y_offset_ll1 + results_vis[i]['bodies']['candidate'][19,0] += x_offset_ll1 + results_vis[i]['bodies']['candidate'][19,1] += y_offset_ll1 + + + + x_offset_ll2 = (results_vis[i]['bodies']['candidate'][10][0]-results_vis[i]['bodies']['candidate'][9][0])*(ll2_ratio-1.) + y_offset_ll2 = (results_vis[i]['bodies']['candidate'][10][1]-results_vis[i]['bodies']['candidate'][9][1])*(ll2_ratio-1.) + + results_vis[i]['bodies']['candidate'][10,0] += x_offset_ll2 + results_vis[i]['bodies']['candidate'][10,1] += y_offset_ll2 + results_vis[i]['bodies']['candidate'][19,0] += x_offset_ll2 + results_vis[i]['bodies']['candidate'][19,1] += y_offset_ll2 + + ########right leg######## + + x_offset_rl1 = (results_vis[i]['bodies']['candidate'][12][0]-results_vis[i]['bodies']['candidate'][11][0])*(rl1_ratio-1.) + y_offset_rl1 = (results_vis[i]['bodies']['candidate'][12][1]-results_vis[i]['bodies']['candidate'][11][1])*(rl1_ratio-1.) + + results_vis[i]['bodies']['candidate'][12,0] += x_offset_rl1 + results_vis[i]['bodies']['candidate'][12,1] += y_offset_rl1 + results_vis[i]['bodies']['candidate'][13,0] += x_offset_rl1 + results_vis[i]['bodies']['candidate'][13,1] += y_offset_rl1 + results_vis[i]['bodies']['candidate'][18,0] += x_offset_rl1 + results_vis[i]['bodies']['candidate'][18,1] += y_offset_rl1 + + + x_offset_rl2 = (results_vis[i]['bodies']['candidate'][13][0]-results_vis[i]['bodies']['candidate'][12][0])*(rl2_ratio-1.) + y_offset_rl2 = (results_vis[i]['bodies']['candidate'][13][1]-results_vis[i]['bodies']['candidate'][12][1])*(rl2_ratio-1.) 
+ + results_vis[i]['bodies']['candidate'][13,0] += x_offset_rl2 + results_vis[i]['bodies']['candidate'][13,1] += y_offset_rl2 + results_vis[i]['bodies']['candidate'][18,0] += x_offset_rl2 + results_vis[i]['bodies']['candidate'][18,1] += y_offset_rl2 + + results_vis[i]['bodies']['candidate'] += offset[np.newaxis, :] + results_vis[i]['faces'] += offset[np.newaxis, np.newaxis, :] + results_vis[i]['hands'] += offset[np.newaxis, np.newaxis, :] + + # Prepare to return the dwpose images + dwpose_images = [] + + for i in range(len(results_vis)): + dwpose_woface, _ = draw_pose(results_vis[i], H=768, W=512) + dwpose_tensor = torch.from_numpy(dwpose_woface).permute(2, 0, 1).unsqueeze(0).float() # Convert to tensor and CHW format + dwpose_tensor = dwpose_tensor.permute(0, 2, 3, 1) + dwpose_images.append(dwpose_tensor) + + dwpose_images = torch.cat(dwpose_images, dim=0) + dwpose_woface_ref, _ = draw_pose(pose_ref, H=768, W=512) + dwpose_ref_tensor = torch.from_numpy(dwpose_woface_ref).permute(2, 0, 1).unsqueeze(0).float() # Convert to tensor and CHW format + dwpose_ref_tensor = dwpose_ref_tensor.permute(0, 2, 3, 1) + + # print(f'The type of the pose from run_align_pose is currently of the form : {type(dwpose_ref_tensor)} ') + # print(f'The content of the pose from run_align_pose is currently: {dwpose_ref_tensor} ') + + return dwpose_images, dwpose_ref_tensor + + +logger = get_logger('dw pose extraction') + diff --git a/uniAnimate_Inference.py b/uniAnimate_Inference.py new file mode 100644 index 0000000..db10817 --- /dev/null +++ b/uniAnimate_Inference.py @@ -0,0 +1,83 @@ +import torch + +from .utils.config import Config +from .tools.inferences import inference_unianimate_entrance +from . import run_align_pose + +# from tools import * + + + +class UniAnimateImage: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "steps": ("INT", {"default": 30, "min": 25, "max": 50, "step": 1}), + "useFirstFrame": ("BOOLEAN", { "default": False }), + "reference_image": ("IMAGE",), # single image + "ref_pose": ("IMAGE",), # single image + "pose_sequence": ("IMAGE",), # Batch of pose images + "frame_interval": ("INT", {"default": 1, "min": 1, "max": 8, "step": 1}), + "max_frames": ("INT", {"default": 32, "min": 1, "max": 64, "step": 1}), + "resolution_x": ("INT", {"default": 512, "min": 512, "max": 768, "step": 256}), + } + } + + RETURN_TYPES = ("IMAGE", "MASK") + FUNCTION = "process" + CATEGORY = "image" + + def process(self, steps, useFirstFrame, reference_image, ref_pose, pose_sequence, frame_interval, max_frames, resolution_x): + cfg_update = Config(load=True) + resolution_y = 768 + if resolution_x == 768: + resolution_y = 1216 + resolution = [resolution_x, resolution_y] + print("Ready for inference.") + + # print(f"image is: {reference_image}") + + frames = inference_unianimate_entrance(steps, useFirstFrame, reference_image, ref_pose, pose_sequence, frame_interval, max_frames, resolution, cfg_update=cfg_update.cfg_dict) + mask_template = torch.zeros((1, resolution_y, resolution_x), dtype=torch.float32) + masks = [mask_template.clone() for _ in range(len(pose_sequence))] + masks = torch.cat(masks, dim=0) + return (frames, masks) + +class Gen_align_pose: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "reference_image": ("IMAGE",), # single image + "video": ("IMAGE",), # video + } + } + + RETURN_TYPES = ("IMAGE", "IMAGE") + FUNCTION = "process" + CATEGORY = "image" + + def process(self, reference_image, video): + if torch.cuda.is_available(): + print(f"CUDA version: 
{torch.version.cuda}") + print(f"CUDNN version: {torch.backends.cudnn.version()}") + print(f"Device name: {torch.cuda.get_device_name(0)}") + else: + print("CUDA is not available") + poses, refPose = run_align_pose.mp_main(reference_image, video) + return (refPose, poses) + + + +NODE_CLASS_MAPPINGS = { + "UniAnimateImage" : UniAnimateImage, + "Gen_align_pose" : Gen_align_pose, + +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "UniAnimateImage" :"Animate image with UniAnimate", + "Gen_align_pose" :"Align & Generate poses for UniAnimate", + +} \ No newline at end of file From b390cc6e22eeed290abc69216458772b25548931 Mon Sep 17 00:00:00 2001 From: Isi <86603298+Isi-dev@users.noreply.github.com> Date: Thu, 1 Aug 2024 21:32:38 +0100 Subject: [PATCH 2/5] Add files via upload --- __pycache__/__init__.cpython-310.pyc | Bin 0 -> 314 bytes __pycache__/run_align_pose.cpython-310.pyc | Bin 0 -> 18973 bytes .../uniAnimate_Inference.cpython-310.pyc | Bin 0 -> 2625 bytes configs/UniAnimate_infer.yaml | 101 +++ configs/UniAnimate_infer_long.yaml | 101 +++ dwpose/__init__.py | 0 dwpose/__pycache__/__init__.cpython-310.pyc | Bin 0 -> 191 bytes dwpose/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 135 bytes dwpose/__pycache__/onnxdet.cpython-310.pyc | Bin 0 -> 4220 bytes dwpose/__pycache__/onnxdet.cpython-39.pyc | Bin 0 -> 4176 bytes dwpose/__pycache__/onnxpose.cpython-310.pyc | Bin 0 -> 10311 bytes dwpose/__pycache__/onnxpose.cpython-39.pyc | Bin 0 -> 10234 bytes dwpose/__pycache__/util.cpython-310.pyc | Bin 0 -> 8930 bytes dwpose/__pycache__/util.cpython-39.pyc | Bin 0 -> 8894 bytes dwpose/__pycache__/wholebody.cpython-310.pyc | Bin 0 -> 1996 bytes dwpose/__pycache__/wholebody.cpython-39.pyc | Bin 0 -> 1718 bytes dwpose/onnxdet.py | 127 +++ dwpose/onnxpose.py | 360 +++++++++ dwpose/util.py | 336 ++++++++ dwpose/wholebody.py | 58 ++ lib/rotary_embedding_torch/__init__.py | 6 + .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 376 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 349 bytes .../rotary_embedding_torch.cpython-310.pyc | Bin 0 -> 7479 bytes .../rotary_embedding_torch.cpython-39.pyc | Bin 0 -> 7411 bytes .../rotary_embedding_torch.py | 291 +++++++ lib/simplejson/__init__.py | 562 +++++++++++++ lib/simplejson/_speedups.cp39-win_amd64.pyd | Bin 0 -> 39936 bytes lib/simplejson/compat.py | 34 + lib/simplejson/decoder.py | 416 ++++++++++ lib/simplejson/encoder.py | 740 ++++++++++++++++++ lib/simplejson/errors.py | 53 ++ lib/simplejson/ordered_dict.py | 103 +++ lib/simplejson/raw_json.py | 9 + lib/simplejson/scanner.py | 85 ++ lib/simplejson/tests/__init__.py | 91 +++ .../tests/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 2925 bytes .../__pycache__/_cibw_runner.cpython-39.pyc | Bin 0 -> 396 bytes .../test_bigint_as_string.cpython-39.pyc | Bin 0 -> 1849 bytes .../test_bitsize_int_as_string.cpython-39.pyc | Bin 0 -> 2229 bytes .../test_check_circular.cpython-39.pyc | Bin 0 -> 1669 bytes .../__pycache__/test_decimal.cpython-39.pyc | Bin 0 -> 2712 bytes .../__pycache__/test_decode.cpython-39.pyc | Bin 0 -> 5916 bytes .../__pycache__/test_default.cpython-39.pyc | Bin 0 -> 623 bytes .../__pycache__/test_dump.cpython-39.pyc | Bin 0 -> 8687 bytes ...est_encode_basestring_ascii.cpython-39.pyc | Bin 0 -> 2331 bytes .../test_encode_for_html.cpython-39.pyc | Bin 0 -> 1866 bytes .../__pycache__/test_errors.cpython-39.pyc | Bin 0 -> 2407 bytes .../__pycache__/test_fail.cpython-39.pyc | Bin 0 -> 3704 bytes .../__pycache__/test_float.cpython-39.pyc | Bin 0 -> 2106 bytes .../__pycache__/test_for_json.cpython-39.pyc | 
Bin 0 -> 4658 bytes .../__pycache__/test_indent.cpython-39.pyc | Bin 0 -> 2303 bytes .../test_item_sort_key.cpython-39.pyc | Bin 0 -> 2087 bytes .../__pycache__/test_iterable.cpython-39.pyc | Bin 0 -> 1424 bytes .../test_namedtuple.cpython-39.pyc | Bin 0 -> 6216 bytes .../__pycache__/test_pass1.cpython-39.pyc | Bin 0 -> 2046 bytes .../__pycache__/test_pass2.cpython-39.pyc | Bin 0 -> 684 bytes .../__pycache__/test_pass3.cpython-39.pyc | Bin 0 -> 779 bytes .../__pycache__/test_raw_json.cpython-39.pyc | Bin 0 -> 1433 bytes .../__pycache__/test_recursion.cpython-39.pyc | Bin 0 -> 2056 bytes .../test_scanstring.cpython-39.pyc | Bin 0 -> 5164 bytes .../test_separators.cpython-39.pyc | Bin 0 -> 1276 bytes .../__pycache__/test_speedups.cpython-39.pyc | Bin 0 -> 4956 bytes .../test_str_subclass.cpython-39.pyc | Bin 0 -> 1095 bytes .../__pycache__/test_subclass.cpython-39.pyc | Bin 0 -> 1523 bytes .../__pycache__/test_tool.cpython-39.pyc | Bin 0 -> 3193 bytes .../__pycache__/test_tuple.cpython-39.pyc | Bin 0 -> 1412 bytes .../__pycache__/test_unicode.cpython-39.pyc | Bin 0 -> 6598 bytes lib/simplejson/tests/_cibw_runner.py | 7 + lib/simplejson/tests/test_bigint_as_string.py | 67 ++ .../tests/test_bitsize_int_as_string.py | 73 ++ lib/simplejson/tests/test_check_circular.py | 30 + lib/simplejson/tests/test_decimal.py | 71 ++ lib/simplejson/tests/test_decode.py | 127 +++ lib/simplejson/tests/test_default.py | 9 + lib/simplejson/tests/test_dump.py | 249 ++++++ .../tests/test_encode_basestring_ascii.py | 47 ++ lib/simplejson/tests/test_encode_for_html.py | 38 + lib/simplejson/tests/test_errors.py | 68 ++ lib/simplejson/tests/test_fail.py | 178 +++++ lib/simplejson/tests/test_float.py | 38 + lib/simplejson/tests/test_for_json.py | 97 +++ lib/simplejson/tests/test_indent.py | 86 ++ lib/simplejson/tests/test_item_sort_key.py | 27 + lib/simplejson/tests/test_iterable.py | 31 + lib/simplejson/tests/test_namedtuple.py | 174 ++++ lib/simplejson/tests/test_pass1.py | 71 ++ lib/simplejson/tests/test_pass2.py | 14 + lib/simplejson/tests/test_pass3.py | 20 + lib/simplejson/tests/test_raw_json.py | 47 ++ lib/simplejson/tests/test_recursion.py | 67 ++ lib/simplejson/tests/test_scanstring.py | 200 +++++ lib/simplejson/tests/test_separators.py | 42 + lib/simplejson/tests/test_speedups.py | 114 +++ lib/simplejson/tests/test_str_subclass.py | 21 + lib/simplejson/tests/test_subclass.py | 37 + lib/simplejson/tests/test_tool.py | 114 +++ lib/simplejson/tests/test_tuple.py | 47 ++ lib/simplejson/tests/test_unicode.py | 154 ++++ lib/simplejson/tool.py | 42 + 100 files changed, 5880 insertions(+) create mode 100644 __pycache__/__init__.cpython-310.pyc create mode 100644 __pycache__/run_align_pose.cpython-310.pyc create mode 100644 __pycache__/uniAnimate_Inference.cpython-310.pyc create mode 100644 configs/UniAnimate_infer.yaml create mode 100644 configs/UniAnimate_infer_long.yaml create mode 100644 dwpose/__init__.py create mode 100644 dwpose/__pycache__/__init__.cpython-310.pyc create mode 100644 dwpose/__pycache__/__init__.cpython-39.pyc create mode 100644 dwpose/__pycache__/onnxdet.cpython-310.pyc create mode 100644 dwpose/__pycache__/onnxdet.cpython-39.pyc create mode 100644 dwpose/__pycache__/onnxpose.cpython-310.pyc create mode 100644 dwpose/__pycache__/onnxpose.cpython-39.pyc create mode 100644 dwpose/__pycache__/util.cpython-310.pyc create mode 100644 dwpose/__pycache__/util.cpython-39.pyc create mode 100644 dwpose/__pycache__/wholebody.cpython-310.pyc create mode 100644 dwpose/__pycache__/wholebody.cpython-39.pyc 
create mode 100644 dwpose/onnxdet.py create mode 100644 dwpose/onnxpose.py create mode 100644 dwpose/util.py create mode 100644 dwpose/wholebody.py create mode 100644 lib/rotary_embedding_torch/__init__.py create mode 100644 lib/rotary_embedding_torch/__pycache__/__init__.cpython-310.pyc create mode 100644 lib/rotary_embedding_torch/__pycache__/__init__.cpython-39.pyc create mode 100644 lib/rotary_embedding_torch/__pycache__/rotary_embedding_torch.cpython-310.pyc create mode 100644 lib/rotary_embedding_torch/__pycache__/rotary_embedding_torch.cpython-39.pyc create mode 100644 lib/rotary_embedding_torch/rotary_embedding_torch.py create mode 100644 lib/simplejson/__init__.py create mode 100644 lib/simplejson/_speedups.cp39-win_amd64.pyd create mode 100644 lib/simplejson/compat.py create mode 100644 lib/simplejson/decoder.py create mode 100644 lib/simplejson/encoder.py create mode 100644 lib/simplejson/errors.py create mode 100644 lib/simplejson/ordered_dict.py create mode 100644 lib/simplejson/raw_json.py create mode 100644 lib/simplejson/scanner.py create mode 100644 lib/simplejson/tests/__init__.py create mode 100644 lib/simplejson/tests/__pycache__/__init__.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/_cibw_runner.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_bigint_as_string.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_bitsize_int_as_string.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_check_circular.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_decimal.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_decode.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_default.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_dump.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_encode_basestring_ascii.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_encode_for_html.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_errors.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_fail.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_float.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_for_json.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_indent.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_item_sort_key.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_iterable.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_namedtuple.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_pass1.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_pass2.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_pass3.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_raw_json.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_recursion.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_scanstring.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_separators.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_speedups.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_str_subclass.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_subclass.cpython-39.pyc create mode 100644 
lib/simplejson/tests/__pycache__/test_tool.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_tuple.cpython-39.pyc create mode 100644 lib/simplejson/tests/__pycache__/test_unicode.cpython-39.pyc create mode 100644 lib/simplejson/tests/_cibw_runner.py create mode 100644 lib/simplejson/tests/test_bigint_as_string.py create mode 100644 lib/simplejson/tests/test_bitsize_int_as_string.py create mode 100644 lib/simplejson/tests/test_check_circular.py create mode 100644 lib/simplejson/tests/test_decimal.py create mode 100644 lib/simplejson/tests/test_decode.py create mode 100644 lib/simplejson/tests/test_default.py create mode 100644 lib/simplejson/tests/test_dump.py create mode 100644 lib/simplejson/tests/test_encode_basestring_ascii.py create mode 100644 lib/simplejson/tests/test_encode_for_html.py create mode 100644 lib/simplejson/tests/test_errors.py create mode 100644 lib/simplejson/tests/test_fail.py create mode 100644 lib/simplejson/tests/test_float.py create mode 100644 lib/simplejson/tests/test_for_json.py create mode 100644 lib/simplejson/tests/test_indent.py create mode 100644 lib/simplejson/tests/test_item_sort_key.py create mode 100644 lib/simplejson/tests/test_iterable.py create mode 100644 lib/simplejson/tests/test_namedtuple.py create mode 100644 lib/simplejson/tests/test_pass1.py create mode 100644 lib/simplejson/tests/test_pass2.py create mode 100644 lib/simplejson/tests/test_pass3.py create mode 100644 lib/simplejson/tests/test_raw_json.py create mode 100644 lib/simplejson/tests/test_recursion.py create mode 100644 lib/simplejson/tests/test_scanstring.py create mode 100644 lib/simplejson/tests/test_separators.py create mode 100644 lib/simplejson/tests/test_speedups.py create mode 100644 lib/simplejson/tests/test_str_subclass.py create mode 100644 lib/simplejson/tests/test_subclass.py create mode 100644 lib/simplejson/tests/test_tool.py create mode 100644 lib/simplejson/tests/test_tuple.py create mode 100644 lib/simplejson/tests/test_unicode.py create mode 100644 lib/simplejson/tool.py diff --git a/__pycache__/__init__.cpython-310.pyc b/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cb9e89a66c561ef9d4e6d11ca5fa27618430301 GIT binary patch literal 314 zcmd1j<>g`kg0EYar%eIUk3k${zy#zt0CBMlkVs)jVa#F3WsG9XWr|{AWJqC3VNPd? zVoqTRX3%7P$p}=U$#hHD&)>y0-r2`7I5^(dF(AOx&pr5-6j;#3GdRGkB?8x$%&6&$xy@!(gz}b zMLS!?gche36~|;2XJ+NcIOpf4Rfc-TmuKds! 
zVyYTX#D&u6rp7a3sv6ITscJkUrpkHLi(eU2-65g+cp3-R?B`=UDtbQe<0`tTW0c>= zNe`Ti9-^LFc}CPz;~7y;jb}tXHJ%am)ObeJQ{x#?PtMD}JFBBXao4wx*H89TKXoY4lOEFGrP8=eINc zLB;U&w#reJqo>-M#`d)GjM$#WGh%xh&xq}5JR`Qpd3j?djqUA_P#C50LCwBI!)U9F z-tl=wMcH(WNN;Cgt9-?nrc}Db2;~CMn9W7DwLmGYCA)(r;qHlWHDiWqs&u03A zii+tHawTE3g!EK@5PFD~Y3(thWg5?jmT5dATBh;zXqjGKe*0vcg8M{8v3Bb1^d;%3 z_6a>ix3sz#(JhTwVqK!5TY7oYQ${V$W&Ee2PCNB>wvY5w`-C2%QCeM$ zXq3j&qfxq^MBK0@V@});DoV6dZ)ZJoqeog@jOdZZ6L`w#kx?TRyU}YwaT@d_yt(z` zIJ$(2c<5Bp<9Mcei0^21G2%NK&xr46JR`oN@r?M6#xvqO8c&b!==D^w7M&{50xtcy zi!LE!I7X{;<1t!YjChR3GvYBC&xprpJS`q$)JR1`^xDgaOD6Ks&e7Z1KXW4{T3w8Y ziN-S`CK^vf0MbZ{#xo)&8qbKBXgob)q9dT97dlm?Zs8mj&#hTitvbjc$#_tJv3A!Y)San#7*TalX%S0tZ>v;x)dP(q*T>DIy;+WN;DZYs{oWjwhDgRj%;4mmoPfgLz*>Ge0&fRM_-krMfVCo+7Q=iff*NxG8 zK@XeQLlWc4Q~c29AMCfqp7?LEDV0vq(#G|Tqu1JPRm4*BsXz6|_q$U#k3|-Ib~?BM z>AWq%+q5xvn@$0358xs<-~Y!IY-FbW%;G|O^Fg39OLC9DNQ)iM+jNRm(P_IZSDJRV z=?I^q|4D|mZ2gdGLA%I439g%QC@ zT)uA|`MiPAsF5Hx=;K&PLnzUi5|QMmOd5kEk)B3r;M~Z9H{(51iFJ{+D&ORBD~t0>e17?FVPGnVE)5Ckc2VDg_IiM#ZRd#0c!#I@QC&AKSmT;yqmYPe&8e^%=J}otTVW(k`bi)_CJi`}q3}0x}nnCAWzr`Xa zTFleg784=UVxA@a^O`<#ss4G?q>mVh>iXw-`4=^PR&ST zFX{T18uhlf|m&z&JpZ!un z(^O}YvKZm`S0v&P4xf#~G`rTLvp0$DrIg6;hP|LMs*dMy0baBsp&Fi$${Z4-Ct6a$ z5u;AC)%_=A!hgZDY9Edu(eeHoj-4H_kMP6WII@icZj(5MhU41>I>L=ZOO80WB=STP zUpoFq$KkSHqPflxj=oLWQ}!`_(Cs)5yB)WSg;{$R2f|%@%9*kcIfrnFPXE4I;x+rV zUG!hg&FAD(sV?~(qRyiBu9L0#cTseHvGl6e&pNQIVe)DmcFG=djyr_z1dygfh2nd; za5)#@%O#yoExTOYotBl!rWIaF-{dJe$*Huemst!7;z@iUq{em zhnz8=9J1h8K9)4285nB105YT!z z*~1UXp6*~Rg=L(214)KWg40<^MujYEg;CIYgw-HR2C^9(kOc>4=B(fh*)EImFHS&? zh|K9V_1}N^KRgisNrcW0^M9Ga{|SRwcJien9na&4@wI7Ys_38Vb7#dr?)ZQf*WnFZ m+%gT{SAA(Jz*JG9>A)?xrJ-W&^`SxcPHoI}t)Gi;z4ZTDnTsj_ literal 0 HcmV?d00001 diff --git a/__pycache__/uniAnimate_Inference.cpython-310.pyc b/__pycache__/uniAnimate_Inference.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a67a795f4a10e8a4d21629fe28afc84266fef325 GIT binary patch literal 2625 zcmZ`*&5s+m6(>0}nvqu0ZfwU%?YL2k1c{Ph6C^zuu3=UC0ipKn!n=F$V1vnINXfIF z51AaTvx0Q8i?%?0=(#T7JsSNddg>p+YfpR5rRc%wdz7?V8z}`oJd#7c$M@dvM-`1m zjX-K;-q}@0JE5tveAlp<|x<%ek2B6(Y z^1hHFk3~4jlPFKJs1%{dOBul&e9)AmJdDz0kcY!Ui5Snh*cSKr@e7z0B%quGOjw+9 z_KZBGfz7SJ;dW5rPTM2zk&8tP5xUFa$Rfk$$S+uXrp zUf6wxSLd~ag0^CQ3|1;|@x?mh4bYnx#;bglukm;0`!Bul5)98}UwcwK9wkz6ZT0;9 zvtNS2vtOY28iZsZ+O~GN=trZp)K-?{Afg|%tx7S}?(cSX?rwH`*?Pp zfv=b|-R$FrFv&|HkE2xAz@^a4DP5OB73rus}{gZ=U_$0}B@kD_^Sw`QdVzF=-k5pM?VP0^d7L(yJhlfx>^K^#2MOm#6 zPxMO9-#yq5_aE$T?$yD(x(b3>YqUn0`K?j>+5fSB8edlj@PhE6=ObdTAdy+?~Llz6E-Pd(YXQ$sb9dVj({BO1JbT zaLQA<0kuDEz&F)N1F`s=O&gQ?q|s-HMt>iWWv4~N<$K^<@=u$P@s;}`;wPI#SfnUHTQQ~D?JlAMyivd3WAaI*Jca+(3cwl^>#Lq4*3$a1jqDy0%1zuFMhR zU=ZdA`4abG1mq_m@F-KYHEL7ObdP>uU!~V*6UO8B{!cjkS5@tWVU8dTbuA3Df{#+v z>tXnK6s7Z(l{k%*%0&6F;PPWIrt9}N_Ydy-Gtg*n>%i~s_jZ7%x}E*a?VbA%j9*JM zFHBrAd zf#HK$BtC_`G2Ud+98jX0t=cw}1J|!1s zEwjtz_UZYzxefbgT;aW}B3qoYLpHUhHgJYz5MJ@e&F;ZQr*#bU3Iu$!HC~0W?_Zc` zY>49|7A-7+n=J#9@w;YIqFQ-TwxZ)GNs)j7G`|7=wH=Rmr0a+ND;oxbGHE}UYlwlXZ)I!F2)?|HCYC8*=6wqm5N8())teV8u0j1^ElbqLH|9@E z%dmaxNm4#+Em1eV0oH{C?rNDE7#dwJw^vLAgE~BisalWC-yg$7!CF}qFm+D3YDP-L zL~3bYU-ox4HpA}S&fZ>lud}<`^SAf(H8Z)<+uOa{c@X-Ydz;C~^W{d#< literal 0 HcmV?d00001 diff --git a/configs/UniAnimate_infer.yaml b/configs/UniAnimate_infer.yaml new file mode 100644 index 0000000..d44fc89 --- /dev/null +++ b/configs/UniAnimate_infer.yaml @@ -0,0 +1,101 @@ +# manual setting +max_frames: 1 +resolution: [512, 768] # or resolution: [768, 1216] +# resolution: [768, 1216] +round: 1 +ddim_timesteps: 30 # among 25-50 +seed: 11 # 7 +test_list_path: [ + # Format: 
[frame_interval, reference image, driving pose sequence] + # [2, "data/images/WOMEN-Blouses_Shirts-id_00004955-01_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00004955-01_4_full"], + # [2, "data/images/musk.jpg", "data/saved_pose/musk"], + # [2, "data/images/WOMEN-Blouses_Shirts-id_00005125-03_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00005125-03_4_full"], + # [2, "data/images/IMG_20240514_104337.jpg", "data/saved_pose/IMG_20240514_104337"] + [1, "data/images/i.png", "data/saved_pose/dancePoses"] + +] +partial_keys: [ + # ['image','local_image', "dwpose"], # reference image as the first frame of the generated video (optional) + ['image', 'randomref', "dwpose"], + ] + + + + +# default settings +TASK_TYPE: inference_unianimate_entrance +use_fp16: True +guide_scale: 2.5 +vit_resolution: [224, 224] +use_fp16: True +batch_size: 1 +latent_random_ref: True +chunk_size: 2 +decoder_bs: 2 +scale: 8 +use_fps_condition: False +test_model: checkpoints/unianimate_16f_32f_non_ema_223000.pth +embedder: { + 'type': 'FrozenOpenCLIPTextVisualEmbedder', + 'layer': 'penultimate', + 'pretrained': 'checkpoints/open_clip_pytorch_model.bin' +} + + +auto_encoder: { + 'type': 'AutoencoderKL', + 'ddconfig': { + 'double_z': True, + 'z_channels': 4, + 'resolution': 256, + 'in_channels': 3, + 'out_ch': 3, + 'ch': 128, + 'ch_mult': [1, 2, 4, 4], + 'num_res_blocks': 2, + 'attn_resolutions': [], + 'dropout': 0.0, + 'video_kernel_size': [3, 1, 1] + }, + 'embed_dim': 4, + 'pretrained': 'checkpoints/v2-1_512-ema-pruned.ckpt' +} + +UNet: { + 'type': 'UNetSD_UniAnimate', + 'config': None, + 'in_dim': 4, + 'dim': 320, + 'y_dim': 1024, + 'context_dim': 1024, + 'out_dim': 4, + 'dim_mult': [1, 2, 4, 4], + 'num_heads': 8, + 'head_dim': 64, + 'num_res_blocks': 2, + 'dropout': 0.1, + 'temporal_attention': True, + 'num_tokens': 4, + 'temporal_attn_times': 1, + 'use_checkpoint': True, + 'use_fps_condition': False, + 'use_sim_mask': False +} +video_compositions: ['image', 'local_image', 'dwpose', 'randomref', 'randomref_pose'] +Diffusion: { + 'type': 'DiffusionDDIM', + 'schedule': 'linear_sd', + 'schedule_param': { + 'num_timesteps': 1000, + "init_beta": 0.00085, + "last_beta": 0.0120, + 'zero_terminal_snr': True, + }, + 'mean_type': 'v', + 'loss_type': 'mse', + 'var_type': 'fixed_small', # 'fixed_large', + 'rescale_timesteps': False, + 'noise_strength': 0.1 +} +use_DiffusionDPM: False +CPU_CLIP_VAE: True \ No newline at end of file diff --git a/configs/UniAnimate_infer_long.yaml b/configs/UniAnimate_infer_long.yaml new file mode 100644 index 0000000..9b5f368 --- /dev/null +++ b/configs/UniAnimate_infer_long.yaml @@ -0,0 +1,101 @@ +# manual setting +# resolution: [512, 768] # or [768, 1216] +resolution: [768, 1216] +round: 1 +ddim_timesteps: 30 # among 25-50 +context_size: 32 +context_stride: 1 +context_overlap: 8 +seed: 7 +max_frames: "None" # 64, 96, "None" mean the length of original pose sequence +test_list_path: [ + # Format: [frame_interval, reference image, driving pose sequence] + [2, "data/images/WOMEN-Blouses_Shirts-id_00004955-01_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00004955-01_4_full"], + [2, "data/images/musk.jpg", "data/saved_pose/musk"], + [2, "data/images/WOMEN-Blouses_Shirts-id_00005125-03_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00005125-03_4_full"], + [2, "data/images/IMG_20240514_104337.jpg", "data/saved_pose/IMG_20240514_104337"], + [2, "data/images/IMG_20240514_104337.jpg", "data/saved_pose/IMG_20240514_104337_dance"], + [2, 
"data/images/WOMEN-Blouses_Shirts-id_00005125-03_4_full.jpg", "data/saved_pose/WOMEN-Blouses_Shirts-id_00005125-03_4_full_dance"] +] + + +# default settings +TASK_TYPE: inference_unianimate_long_entrance +use_fp16: True +guide_scale: 2.5 +vit_resolution: [224, 224] +use_fp16: True +batch_size: 1 +latent_random_ref: True +chunk_size: 2 +decoder_bs: 2 +scale: 8 +use_fps_condition: False +test_model: checkpoints/unianimate_16f_32f_non_ema_223000.pth +partial_keys: [ + ['image', 'randomref', "dwpose"], + ] +embedder: { + 'type': 'FrozenOpenCLIPTextVisualEmbedder', + 'layer': 'penultimate', + 'pretrained': 'checkpoints/open_clip_pytorch_model.bin' +} + + +auto_encoder: { + 'type': 'AutoencoderKL', + 'ddconfig': { + 'double_z': True, + 'z_channels': 4, + 'resolution': 256, + 'in_channels': 3, + 'out_ch': 3, + 'ch': 128, + 'ch_mult': [1, 2, 4, 4], + 'num_res_blocks': 2, + 'attn_resolutions': [], + 'dropout': 0.0, + 'video_kernel_size': [3, 1, 1] + }, + 'embed_dim': 4, + 'pretrained': 'checkpoints/v2-1_512-ema-pruned.ckpt' +} + +UNet: { + 'type': 'UNetSD_UniAnimate', + 'config': None, + 'in_dim': 4, + 'dim': 320, + 'y_dim': 1024, + 'context_dim': 1024, + 'out_dim': 4, + 'dim_mult': [1, 2, 4, 4], + 'num_heads': 8, + 'head_dim': 64, + 'num_res_blocks': 2, + 'dropout': 0.1, + 'temporal_attention': True, + 'num_tokens': 4, + 'temporal_attn_times': 1, + 'use_checkpoint': True, + 'use_fps_condition': False, + 'use_sim_mask': False +} +video_compositions: ['image', 'local_image', 'dwpose', 'randomref', 'randomref_pose'] +Diffusion: { + 'type': 'DiffusionDDIMLong', + 'schedule': 'linear_sd', + 'schedule_param': { + 'num_timesteps': 1000, + "init_beta": 0.00085, + "last_beta": 0.0120, + 'zero_terminal_snr': True, + }, + 'mean_type': 'v', + 'loss_type': 'mse', + 'var_type': 'fixed_small', + 'rescale_timesteps': False, + 'noise_strength': 0.1 +} +CPU_CLIP_VAE: True +context_batch_size: 1 diff --git a/dwpose/__init__.py b/dwpose/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dwpose/__pycache__/__init__.cpython-310.pyc b/dwpose/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d0d87428c3de4f1669b295c96c12d76b7616a0b GIT binary patch literal 191 zcmd1j<>g`kf|lOtX(0MBh(HF6K#l_t7qb9~6oz01O-8?!3`HPe1o11$*(xTqIJKxa zCbKv*D?i3LKR2y1)HA+3GcP5-yg0rfzo;ZJDJK;s5tCe6T#}y~pO>GKS_~7656#PT w%*)J8EJ=+?DKE$`PK}9=&&g`kf|lOtX(0MBh(HF6K#l_t7qb9~6oz01O-8?!3`HPe1o2D9*(xTqIJKxa zCbKv*D?cVQFVitEGdHm$H72FJAip>@CO$qhFS8^*Uaz3?7Kcr4eoARhsvXFb&p^xo E07mp4nE(I) literal 0 HcmV?d00001 diff --git a/dwpose/__pycache__/onnxdet.cpython-310.pyc b/dwpose/__pycache__/onnxdet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..486538970c3885adf07739d7dd14739e2f894872 GIT binary patch literal 4220 zcmai1%a0sK8SkpuMRae#T zdvw-lgbvD+XTR#ES znDN5$?u>UW@6CA6^1k-8IO|Ap#?iiZOKQm(S2R0t4y|tKKNZj%^qg^32OCWCO;>%X z14-+Jx_ZEMc*vMjhyzy(9iDSg2U5^iopa7PhkRVKW3C-@Um2@o%)?jEJnxmUh8^?m znE%RHlXS|8Zs=P1)mUe4(%01WS>L>#)AQKTf?m)StEqBQ6ZXAu);aHt7ctk@i<)n; zL+lZIJEa4CicW@hKizr0s;ju0VDI9+B<=O%o(sxWSz?8U~9hQm0OhNnYQ8ALnDU^D=c zq?YiuAH^zeEgA1&wiD;Z&pVkyTS*6byLhAw-}$cLyLfl5puJ+eNX1cZyiCbh8Fw>| zhlV9a>_ATJLJ-syqOKUWZP+8@CD?>ABHMy^lFIyrqrQP}{YP~B+WP%GR{468CmY%N z?QGE9y?>{@4WVp1Zx6AD=wUye4y<=Zd65m;X(r=*+S$IJCbzJasEF6)_Atxi^(;+y zWL&Haca6X3H&l!T7gy_&-*fCtSBgW>EqCvOPe=VC`9Hv0zD>aVPPC27U*5u%C0V+n z2!bUfTUF!bkD_60JQbzA*ob^IuxDsQKTcH@lE%$a0ISx?(hhDb#j67Jx0>L;?~+Jh+HH1!ib&`@#~ z>FP8s6-?CR#-%kZFw5#48g33|htmd>2o(7xR0N+opcJ6yEa24E*=gp|`*kSz=XKC^ 
[GIT binary patch payloads omitted here: compiled caches dwpose/__pycache__/onnxdet, onnxpose, util, and wholebody (.cpython-310.pyc and .cpython-39.pyc). The affected files are listed in the diffstat above.]
z2QrxE6l%(pfFwmA$CLA{3C{>xnFT# F{vUBAu=4-_ literal 0 HcmV?d00001 diff --git a/dwpose/__pycache__/wholebody.cpython-310.pyc b/dwpose/__pycache__/wholebody.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ce07696bda9d29d060f57c8dd6a46c71d800e24 GIT binary patch literal 1996 zcmZWq&2Aev5MFY*f08UauG2VeTDbpJ+qgiJLs0~20=FpAgApJ#3J3y-^^%fTUM*E} zWn0GTDIDE9N)|`K1o!f_^oem zS8sHvT^Yy40au#uPYY>mG=3R}=`<}>#xk_qK%_d4N74rIzNs`J=NJs<)*EX$+Oi=; zU|ppPyY^tgf=ALAxSVa8dV#F4t%ZPO=meG-nf0wN7^zes@{$o^{0Ce_;H#Gb!_+;;*9q* zxiIvSi7C{G=SoPkQ1ZPzy`QHeNOeyfqx;_Rf_>ZLsf8S*d0KFeF`G>QMndXeFhVaj z+#94t+O$p9ybW@?eo5Qy#ij&lzZh7(P`5RI88EwWN?>Y9&M0se7&UZD3XfU6Hthm8}QRSG;y-&~j~ z-}+M&KSuZr;S+?<5xzv|Abf@J1p&K>;j-_?6@cc9fq8FWpzJ~jKrPBwafo(uxa4pRwD{OQqr*(s{ zzDX2lZTnIM9PTU2a%duUJ(tJ)HA`bOLhAEWoYYsFq@mkp?F&XJZ9vKRsB>Q}mD=HM$3ySffuP|u zI3Fo7shUrV^WP?MR$Dawf9Sgi*aVgUy-{Z1hN5|^3gZe%O|7p zRBxlAiETws@**8c{T67t?Ya$A%@guleZcPOk`i-Me@r}WDS|4XI!$-#=CAWFbu`kE**wmf&xV_fG|L)BJC)VM30h_ zIIvGmk2&w!?zX>YH(qzran~K{y(ib1yNv*Uj_-8=KA4se( z4~R#wi|+vxQ8XoDc~Ze@+7ld3mik#B0*W#(ZDgScDfy8ouKY`){Fh8L6@5gPs@NqFp7Ioj6MUj4 z*Zk?or22(X)0l$xYWZ;Cu=imXPXH91(;1jJVYa8d5zV>cBRXfwpL;5p^O79XnKxsq zam8jtvYO`Cw5JHT3lAKRCn|DYJf4W{{bz?ioa=Z}CMJJo&1s^v6*u<}pFI8#`(h_P z((%dIBzakUIW?&{m+79#^Ye>4>k@T#FHPknHnEZSzkX0_^qJTgN9B<;R=${)x^T_$ zJ8wmwJI~lsY&~5i`nfI&h{v^y3vG)*QY6P_aF{2L^CXK(Jy2(;I5=BXVQ)Njp_ECUlu|Y^`#k_c zISUpc9BNyV8PESqoWEO=*iDP<1{_-etc7D&)OpEpKCwST4a5Y7&q%~@mlW)@-;*Hm)|uV;<9 z53e^sqXAK%+*jRQ!U%pHe%X*g2W}qp+1^35fCYoZqw}Pwu-7Y?Yf|9^RG7pHdw+Dl zUcY}=-2ke;X+MI)et=?3Y=s%Fu=W+!sp=qf5w;Mp@YN=Oy@{}m@FDW>ZvnSB_Xu6d%T z%VBgu5gXu&?llX#cW4C9VPHis*Jtuyl2(>s)5}Dii;fxVd_~pMvII;d4gv(2N{LA? z9sm&9p*v@!E zDP2p-%&19<_l}f5Poi`wv2_9Og5R~!u(B58v)Cor3PpXPVXNl!TkwB7{mS|IBpXld z7oZ^8_)oBtyi77}@rmdzC%^tfL1OF5iOn@8CS_SgwElnoUHEVr_P5x$MLqE}WFh3J G!~O=)Tdbb| literal 0 HcmV?d00001 diff --git a/dwpose/onnxdet.py b/dwpose/onnxdet.py new file mode 100644 index 0000000..4fab0c0 --- /dev/null +++ b/dwpose/onnxdet.py @@ -0,0 +1,127 @@ +import cv2 +import numpy as np + +import onnxruntime + +def nms(boxes, scores, nms_thr): + """Single class NMS implemented in Numpy.""" + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= nms_thr)[0] + order = order[inds + 1] + + return keep + +def multiclass_nms(boxes, scores, nms_thr, score_thr): + """Multiclass NMS implemented in Numpy. 
Class-aware version.""" + final_dets = [] + num_classes = scores.shape[1] + for cls_ind in range(num_classes): + cls_scores = scores[:, cls_ind] + valid_score_mask = cls_scores > score_thr + if valid_score_mask.sum() == 0: + continue + else: + valid_scores = cls_scores[valid_score_mask] + valid_boxes = boxes[valid_score_mask] + keep = nms(valid_boxes, valid_scores, nms_thr) + if len(keep) > 0: + cls_inds = np.ones((len(keep), 1)) * cls_ind + dets = np.concatenate( + [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 + ) + final_dets.append(dets) + if len(final_dets) == 0: + return None + return np.concatenate(final_dets, 0) + +def demo_postprocess(outputs, img_size, p6=False): + grids = [] + expanded_strides = [] + strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] + + hsizes = [img_size[0] // stride for stride in strides] + wsizes = [img_size[1] // stride for stride in strides] + + for hsize, wsize, stride in zip(hsizes, wsizes, strides): + xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) + grid = np.stack((xv, yv), 2).reshape(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + expanded_strides.append(np.full((*shape, 1), stride)) + + grids = np.concatenate(grids, 1) + expanded_strides = np.concatenate(expanded_strides, 1) + outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides + outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides + + return outputs + +def preprocess(img, input_size, swap=(2, 0, 1)): + if len(img.shape) == 3: + padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 + else: + padded_img = np.ones(input_size, dtype=np.uint8) * 114 + + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.uint8) + padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img + + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + +def inference_detector(session, oriImg): + input_shape = (640,640) + img, ratio = preprocess(oriImg, input_shape) + + ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]} + + output = session.run(None, ort_inputs) + + predictions = demo_postprocess(output[0], input_shape)[0] + + boxes = predictions[:, :4] + scores = predictions[:, 4:5] * predictions[:, 5:] + + boxes_xyxy = np.ones_like(boxes) + boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. + boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. + boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. + boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. + boxes_xyxy /= ratio + dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) + if dets is not None: + final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] + isscore = final_scores>0.3 + iscat = final_cls_inds == 0 + isbbox = [ i and j for (i, j) in zip(isscore, iscat)] + final_boxes = final_boxes[isbbox] + else: + final_boxes = np.array([]) + + return final_boxes diff --git a/dwpose/onnxpose.py b/dwpose/onnxpose.py new file mode 100644 index 0000000..72c1b00 --- /dev/null +++ b/dwpose/onnxpose.py @@ -0,0 +1,360 @@ +from typing import List, Tuple + +import cv2 +import numpy as np +import onnxruntime as ort + +def preprocess( + img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256) +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Do preprocessing for RTMPose model inference. 
+ + Args: + img (np.ndarray): Input image in shape. + input_size (tuple): Input image size in shape (w, h). + + Returns: + tuple: + - resized_img (np.ndarray): Preprocessed image. + - center (np.ndarray): Center of image. + - scale (np.ndarray): Scale of image. + """ + # get shape of image + img_shape = img.shape[:2] + out_img, out_center, out_scale = [], [], [] + if len(out_bbox) == 0: + out_bbox = [[0, 0, img_shape[1], img_shape[0]]] + for i in range(len(out_bbox)): + x0 = out_bbox[i][0] + y0 = out_bbox[i][1] + x1 = out_bbox[i][2] + y1 = out_bbox[i][3] + bbox = np.array([x0, y0, x1, y1]) + + # get center and scale + center, scale = bbox_xyxy2cs(bbox, padding=1.25) + + # do affine transformation + resized_img, scale = top_down_affine(input_size, scale, center, img) + + # normalize image + mean = np.array([123.675, 116.28, 103.53]) + std = np.array([58.395, 57.12, 57.375]) + resized_img = (resized_img - mean) / std + + out_img.append(resized_img) + out_center.append(center) + out_scale.append(scale) + + return out_img, out_center, out_scale + + +def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray: + """Inference RTMPose model. + + Args: + sess (ort.InferenceSession): ONNXRuntime session. + img (np.ndarray): Input image in shape. + + Returns: + outputs (np.ndarray): Output of RTMPose model. + """ + all_out = [] + # build input + for i in range(len(img)): + input = [img[i].transpose(2, 0, 1)] + + # build output + sess_input = {sess.get_inputs()[0].name: input} + sess_output = [] + for out in sess.get_outputs(): + sess_output.append(out.name) + + # run model + outputs = sess.run(sess_output, sess_input) + all_out.append(outputs) + + return all_out + + +def postprocess(outputs: List[np.ndarray], + model_input_size: Tuple[int, int], + center: Tuple[int, int], + scale: Tuple[int, int], + simcc_split_ratio: float = 2.0 + ) -> Tuple[np.ndarray, np.ndarray]: + """Postprocess for RTMPose model output. + + Args: + outputs (np.ndarray): Output of RTMPose model. + model_input_size (tuple): RTMPose model Input image size. + center (tuple): Center of bbox in shape (x, y). + scale (tuple): Scale of bbox in shape (w, h). + simcc_split_ratio (float): Split ratio of simcc. + + Returns: + tuple: + - keypoints (np.ndarray): Rescaled keypoints. + - scores (np.ndarray): Model predict scores. + """ + all_key = [] + all_score = [] + for i in range(len(outputs)): + # use simcc to decode + simcc_x, simcc_y = outputs[i] + keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio) + + # rescale keypoints + keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2 + all_key.append(keypoints[0]) + all_score.append(scores[0]) + + return np.array(all_key), np.array(all_score) + + +def bbox_xyxy2cs(bbox: np.ndarray, + padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]: + """Transform the bbox format from (x,y,w,h) into (center, scale) + + Args: + bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted + as (left, top, right, bottom) + padding (float): BBox padding factor that will be multilied to scale. + Default: 1.0 + + Returns: + tuple: A tuple containing center and scale. 
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or + (n, 2) + - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or + (n, 2) + """ + # convert single bbox from (4, ) to (1, 4) + dim = bbox.ndim + if dim == 1: + bbox = bbox[None, :] + + # get bbox center and scale + x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3]) + center = np.hstack([x1 + x2, y1 + y2]) * 0.5 + scale = np.hstack([x2 - x1, y2 - y1]) * padding + + if dim == 1: + center = center[0] + scale = scale[0] + + return center, scale + + +def _fix_aspect_ratio(bbox_scale: np.ndarray, + aspect_ratio: float) -> np.ndarray: + """Extend the scale to match the given aspect ratio. + + Args: + scale (np.ndarray): The image scale (w, h) in shape (2, ) + aspect_ratio (float): The ratio of ``w/h`` + + Returns: + np.ndarray: The reshaped image scale in (2, ) + """ + w, h = np.hsplit(bbox_scale, [1]) + bbox_scale = np.where(w > h * aspect_ratio, + np.hstack([w, w / aspect_ratio]), + np.hstack([h * aspect_ratio, h])) + return bbox_scale + + +def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray: + """Rotate a point by an angle. + + Args: + pt (np.ndarray): 2D point coordinates (x, y) in shape (2, ) + angle_rad (float): rotation angle in radian + + Returns: + np.ndarray: Rotated point in shape (2, ) + """ + sn, cs = np.sin(angle_rad), np.cos(angle_rad) + rot_mat = np.array([[cs, -sn], [sn, cs]]) + return rot_mat @ pt + + +def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: + """To calculate the affine matrix, three pairs of points are required. This + function is used to get the 3rd point, given 2D points a & b. + + The 3rd point is defined by rotating vector `a - b` by 90 degrees + anticlockwise, using b as the rotation center. + + Args: + a (np.ndarray): The 1st point (x,y) in shape (2, ) + b (np.ndarray): The 2nd point (x,y) in shape (2, ) + + Returns: + np.ndarray: The 3rd point. + """ + direction = a - b + c = b + np.r_[-direction[1], direction[0]] + return c + + +def get_warp_matrix(center: np.ndarray, + scale: np.ndarray, + rot: float, + output_size: Tuple[int, int], + shift: Tuple[float, float] = (0., 0.), + inv: bool = False) -> np.ndarray: + """Calculate the affine transformation matrix that can warp the bbox area + in the input image to the output size. + + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + rot (float): Rotation angle (degree). + output_size (np.ndarray[2, ] | list(2,)): Size of the + destination heatmaps. + shift (0-100%): Shift translation ratio wrt the width/height. + Default (0., 0.). + inv (bool): Option to inverse the affine transform direction. 
+ (inv=False: src->dst or inv=True: dst->src) + + Returns: + np.ndarray: A 2x3 transformation matrix + """ + shift = np.array(shift) + src_w = scale[0] + dst_w = output_size[0] + dst_h = output_size[1] + + # compute transformation matrix + rot_rad = np.deg2rad(rot) + src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad) + dst_dir = np.array([0., dst_w * -0.5]) + + # get four corners of the src rectangle in the original image + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale * shift + src[1, :] = center + src_dir + scale * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + # get four corners of the dst rectangle in the input image + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return warp_mat + + +def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict, + img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get the bbox image as the model input by affine transform. + + Args: + input_size (dict): The input size of the model. + bbox_scale (dict): The bbox scale of the img. + bbox_center (dict): The bbox center of the img. + img (np.ndarray): The original image. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: img after affine transform. + - np.ndarray[float32]: bbox scale after affine transform. + """ + w, h = input_size + warp_size = (int(w), int(h)) + + # reshape bbox to fixed aspect ratio + bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h) + + # get the affine matrix + center = bbox_center + scale = bbox_scale + rot = 0 + warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) + + # do affine transform + img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) + + return img, bbox_scale + + +def get_simcc_maximum(simcc_x: np.ndarray, + simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get maximum response location and value from simcc representations. + + Note: + instance number: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) + simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) + + Returns: + tuple: + - locs (np.ndarray): locations of maximum heatmap responses in shape + (K, 2) or (N, K, 2) + - vals (np.ndarray): values of maximum heatmap responses in shape + (K,) or (N, K) + """ + N, K, Wx = simcc_x.shape + simcc_x = simcc_x.reshape(N * K, -1) + simcc_y = simcc_y.reshape(N * K, -1) + + # get maximum value locations + x_locs = np.argmax(simcc_x, axis=1) + y_locs = np.argmax(simcc_y, axis=1) + locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) + max_val_x = np.amax(simcc_x, axis=1) + max_val_y = np.amax(simcc_y, axis=1) + + # get maximum value across x and y axis + mask = max_val_x > max_val_y + max_val_x[mask] = max_val_y[mask] + vals = max_val_x + locs[vals <= 0.] = -1 + + # reshape + locs = locs.reshape(N, K, 2) + vals = vals.reshape(N, K) + + return locs, vals + + +def decode(simcc_x: np.ndarray, simcc_y: np.ndarray, + simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]: + """Modulate simcc distribution with Gaussian. + + Args: + simcc_x (np.ndarray[K, Wx]): model predicted simcc in x. 
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y. + simcc_split_ratio (int): The split ratio of simcc. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2) + - np.ndarray[float32]: scores in shape (K,) or (n, K) + """ + keypoints, scores = get_simcc_maximum(simcc_x, simcc_y) + keypoints /= simcc_split_ratio + + return keypoints, scores + + +def inference_pose(session, out_bbox, oriImg): + h, w = session.get_inputs()[0].shape[2:] + model_input_size = (w, h) + resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size) + outputs = inference(session, resized_img) + keypoints, scores = postprocess(outputs, model_input_size, center, scale) + + return keypoints, scores \ No newline at end of file diff --git a/dwpose/util.py b/dwpose/util.py new file mode 100644 index 0000000..94d9d9d --- /dev/null +++ b/dwpose/util.py @@ -0,0 +1,336 @@ +import math +import numpy as np +import matplotlib +import cv2 + + +eps = 0.01 + + +def smart_resize(x, s): + Ht, Wt = s + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) + + +def smart_resize_k(x, fx, fy): + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + Ht, Wt = Ho * fy, Wo * fx + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) + + +def padRightDownCorner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] + return transfered_model_weights + + +def draw_bodypose(canvas, candidate, subset): + H, W, C = canvas.shape + candidate = np.array(candidate) + subset = np.array(subset) + + stickwidth = 4 + + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + for 
i in range(17): + for n in range(len(subset)): + index = subset[n][np.array(limbSeq[i]) - 1] + if -1 in index: + continue + Y = candidate[index.astype(int), 0] * float(W) + X = candidate[index.astype(int), 1] * float(H) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(canvas, polygon, colors[i]) + + canvas = (canvas * 0.6).astype(np.uint8) + + for i in range(18): + for n in range(len(subset)): + index = int(subset[n][i]) + if index == -1: + continue + x, y = candidate[index][0:2] + x = int(x * W) + y = int(y * H) + cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) + + return canvas + + +def draw_body_and_foot(canvas, candidate, subset): + H, W, C = canvas.shape + candidate = np.array(candidate) + subset = np.array(subset) + + stickwidth = 4 + + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [14,19], [11, 20]] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85], [170, 255, 255], [255, 255, 0]] + + for i in range(19): + for n in range(len(subset)): + index = subset[n][np.array(limbSeq[i]) - 1] + if -1 in index: + continue + Y = candidate[index.astype(int), 0] * float(W) + X = candidate[index.astype(int), 1] * float(H) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(canvas, polygon, colors[i]) + + canvas = (canvas * 0.6).astype(np.uint8) + + for i in range(20): + for n in range(len(subset)): + index = int(subset[n][i]) + if index == -1: + continue + x, y = candidate[index][0:2] + x = int(x * W) + y = int(y * H) + cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) + + return canvas + + +def draw_handpose(canvas, all_hand_peaks): + H, W, C = canvas.shape + + edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ + [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] + + for peaks in all_hand_peaks: + peaks = np.array(peaks) + + for ie, e in enumerate(edges): + x1, y1 = peaks[e[0]] + x2, y2 = peaks[e[1]] + x1 = int(x1 * W) + y1 = int(y1 * H) + x2 = int(x2 * W) + y2 = int(y2 * H) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2) + + for i, keyponit in enumerate(peaks): + x, y = keyponit + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + return canvas + + +def draw_facepose(canvas, all_lmks): + H, W, C = canvas.shape + for lmks in all_lmks: + lmks = np.array(lmks) + for lmk in lmks: + x, y = lmk + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) + return canvas + + +# detect hand according to body pose keypoints +# please 
refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp +def handDetect(candidate, subset, oriImg): + # right hand: wrist 4, elbow 3, shoulder 2 + # left hand: wrist 7, elbow 6, shoulder 5 + ratioWristElbow = 0.33 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + # if any of three not detected + has_left = np.sum(person[[5, 6, 7]] == -1) == 0 + has_right = np.sum(person[[2, 3, 4]] == -1) == 0 + if not (has_left or has_right): + continue + hands = [] + #left hand + if has_left: + left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] + x1, y1 = candidate[left_shoulder_index][:2] + x2, y2 = candidate[left_elbow_index][:2] + x3, y3 = candidate[left_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, True]) + # right hand + if has_right: + right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]] + x1, y1 = candidate[right_shoulder_index][:2] + x2, y2 = candidate[right_elbow_index][:2] + x3, y3 = candidate[right_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, False]) + + for x1, y1, x2, y2, x3, y3, is_left in hands: + + x = x3 + ratioWristElbow * (x3 - x2) + y = y3 + ratioWristElbow * (y3 - y2) + distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) + distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) + # x-y refers to the center --> offset to topLeft point + # handRectangle.x -= handRectangle.width / 2.f; + # handRectangle.y -= handRectangle.height / 2.f; + x -= width / 2 + y -= width / 2 # width = height + # overflow the image + if x < 0: x = 0 + if y < 0: y = 0 + width1 = width + width2 = width + if x + width > image_width: width1 = image_width - x + if y + width > image_height: width2 = image_height - y + width = min(width1, width2) + # the max hand box value is 20 pixels + if width >= 20: + detect_result.append([int(x), int(y), int(width), is_left]) + + ''' + return value: [[x, y, w, True if left hand else False]]. + width=height since the network require squared input. 
+ x, y is the coordinate of top left + ''' + return detect_result + + +# Written by Lvmin +def faceDetect(candidate, subset, oriImg): + # left right eye ear 14 15 16 17 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + has_head = person[0] > -1 + if not has_head: + continue + + has_left_eye = person[14] > -1 + has_right_eye = person[15] > -1 + has_left_ear = person[16] > -1 + has_right_ear = person[17] > -1 + + if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear): + continue + + head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]] + + width = 0.0 + x0, y0 = candidate[head][:2] + + if has_left_eye: + x1, y1 = candidate[left_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_right_eye: + x1, y1 = candidate[right_eye][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if has_left_ear: + x1, y1 = candidate[left_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + if has_right_ear: + x1, y1 = candidate[right_ear][:2] + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + x, y = x0, y0 + + x -= width + y -= width + + if x < 0: + x = 0 + + if y < 0: + y = 0 + + width1 = width * 2 + width2 = width * 2 + + if x + width > image_width: + width1 = image_width - x + + if y + width > image_height: + width2 = image_height - y + + width = min(width1, width2) + + if width >= 20: + detect_result.append([int(x), int(y), int(width)]) + + return detect_result + + +# get max index of 2d array +def npmax(array): + arrayindex = array.argmax(1) + arrayvalue = array.max(1) + i = arrayvalue.argmax() + j = arrayindex[i] + return i, j diff --git a/dwpose/wholebody.py b/dwpose/wholebody.py new file mode 100644 index 0000000..ddaec2b --- /dev/null +++ b/dwpose/wholebody.py @@ -0,0 +1,58 @@ +import os +import cv2 +import numpy as np + +import onnxruntime as ort +from ..dwpose.onnxdet import inference_detector +from ..dwpose.onnxpose import inference_pose + + +class Wholebody: + def __init__(self): + device = 'cuda' # 'cpu' # + providers = ['CPUExecutionProvider' + ] if device == 'cpu' else ['CUDAExecutionProvider'] + + current_directory = os.path.dirname(os.path.abspath(__file__)) + print("This file is located at "+os.path.dirname(os.path.abspath(__file__))) + parent_directory = os.path.dirname(current_directory) + + onnx_det = os.path.join(parent_directory, 'checkpoints/yolox_l.onnx') + onnx_pose = os.path.join(parent_directory, 'checkpoints/dw-ll_ucoco_384.onnx') + + # onnx_det = "../../checkpoints/yolox_l.onnx" + # onnx_pose = "../../checkpoints/dw-ll_ucoco_384.onnx" + + self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) + self.session_pose = ort.InferenceSession(path_or_bytes= onnx_pose, providers=providers) + + def __call__(self, oriImg): + det_result = inference_detector(self.session_det, oriImg) + keypoints, scores = inference_pose(self.session_pose, det_result, oriImg) + + keypoints_info = np.concatenate( + (keypoints, scores[..., None]), axis=-1) + # compute neck joint + neck = np.mean(keypoints_info[:, [5, 6]], axis=1) + # neck score when visualizing pred + neck[:, 2:4] = np.logical_and( + keypoints_info[:, 5, 2:4] > 0.3, + keypoints_info[:, 6, 2:4] > 0.3).astype(int) + new_keypoints_info = np.insert( + keypoints_info, 17, neck, axis=1) + mmpose_idx = [ + 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 + ] + openpose_idx = [ + 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 
16, 17 + ] + new_keypoints_info[:, openpose_idx] = \ + new_keypoints_info[:, mmpose_idx] + keypoints_info = new_keypoints_info + + keypoints, scores = keypoints_info[ + ..., :2], keypoints_info[..., 2] + + return keypoints, scores + + diff --git a/lib/rotary_embedding_torch/__init__.py b/lib/rotary_embedding_torch/__init__.py new file mode 100644 index 0000000..3a2cfef --- /dev/null +++ b/lib/rotary_embedding_torch/__init__.py @@ -0,0 +1,6 @@ +from ..rotary_embedding_torch.rotary_embedding_torch import ( + apply_rotary_emb, + RotaryEmbedding, + apply_learned_rotations, + broadcat +) diff --git a/lib/rotary_embedding_torch/__pycache__/__init__.cpython-310.pyc b/lib/rotary_embedding_torch/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e53268779dbbcf819fcaca925cfa34dc68e6132b GIT binary patch literal 376 zcmZvW&q@O^5XQ5ADn%%I^U@dCgM9!gq7)BaJSclv0z-C_)<}|BlGL&f;xmXZKz^B-d^7oQGC3tNem~!rZ}I+0#sA1-aTAZ5MsUJ2MRe9=I&X4aG=(mk zG9x^HB31F7v{f0-S>LMxEP7@K5V{R#kGuR%w}NwNR^ehVQG!_`_)h928W+Z`MaDbk zAF5Gv`L|*}2Oe$rJUiaD>3drwlnJAV4`NSJp78U!xT3Wa)=}x?3(^JZMHI--*Hh>$1(oNjES>>0yQBtB=#)Gw_Q{aq}D``M$d>l8#O>rf ugXab{1jVng;U0sU_Rv1TQhC1pQW% zzco7NpX0haTpTU=OOUI|dt&%)Ra15Ro>!MtQ!P9Z{<3PRb7)zKJhiBnj)ZzeEvuC$ z;!yaj>Q%L>E@0G!cMY|sEk7mA~H}q!}i-;Ur6XA!MtNaa=gjAWkQv;@m@>go-+XEQ-@ar-kzY z3s|bEqfqNG-i!)IN8>1@l`7vJWRVVsJwbb!en^)1-lijWA>=sdiMVS#6C2NH!{!ge zp>Cmrix*FBzj^;&8tL@@ARTNc_irbo{@%Uc3LXt&l{`v=aiX*E;V_zY-0w}&EExrH zqM~%x7~G2o--rjJFpKUF2M_OSoQU2FqS3=hsX@FMWQp!={Xfm?<2@Y0W1LR9fNiC; zdspZs{PakkeMmleGkSpx?{+A(lN)!99B-&;dbga}mHHfNg;1yOsu!uHqN08{8D>=B zXSIYr5Bd40_?&s5iB_^+KFVbX^uRwg4rC^_WiE4}#F2z8$c<+>Ry7!PO-;T~*!;SM z*@K$QqB2#dUPYh3bfV@+%@jSqMRz(TTfvN!w&04U(3c>b(KT(*#+Ntv76jU$5=tub z$UK0&EtI9~BXM6I2y9z!o4KsvtooAM z=y*e+H64PcO_d|jIUC25x|Xh^vv4NyLzs`gKovX9!rl(fK|C(x{$z*FEVGll2zh!hb0iOHdF`p2+s80LmS6I_dWZy+IUbopi7tb+V*Wo_J>vcf#SYlWj$vagq)g#X9S(;%>FDRkk;d z3Mhj67#!m;Eb zIgAO(WcQ(V3nN5M`k|;$*?6%4ues6b&-|iadWt@6xemGeVv(FHo{hR5!{s!YzJ3)2 zFzie8V>1ZqDz%eo>jo9{28bP1Rw0EQ{Y7e~8GelpKj1)-&L7oEc~nD0k{H-?GW{!< zm=bQeq7Hv+Nv)k^U6Tk>Lwv__8XhvT9bXs$F&PSE;)AbE;K!U3GtC z^aaP-hHpkz8Ep|;ZHl&tyN>!5bzUta##Yp8YDKL=a@7rWL9Icms@K&;bqSJ(XSC72 zu6{}V@)O&ysn4k^m|u??%2A&`64U-KsH^JQ6B|!x{s_SBFF+%26bs+uVD!!`7Ht&P z5P|f0>l!ehUi&$P%sbC1w&?2+KPc+c7{`|UqESZSPTtu_;;5){YtT>hi!0Lr{DW(> z;p9>6Q1&GNl9gFcZGaKx9Gd_k{H`&eJ#~&zqg5^% znDYwe*qHMU##b=UR_H57ADfxmrZF{)aWLk6jH%|%v7j{$nt1~|@BmIVRZ+&V#6C#| zsF9r2w;Opguc+z~93+5*a)C#9Sf{0|W5ILkXg`Sy&rpc&XWp}!i?*?1s0I4P`z-(_jo^GaCtu)wEoUnPslf`(?zT2c z-@1j9XBBHwd114*3U{U}=soOh!8J+*YfNhlTq&4S1HLsG=qL>~b*O-jlhJ7J`Trsa z=`OAOWfYr#`#1FY_gkBN{`}9a?kccvGLCfN20*q!76keZ4JU$AhDawYsh!;)CSg{% zLomW|*o!nld||=r(&8NNCK&X=$kKFxKv-0>(KrAZ(~%~kS5)7}I->}jP~W6_edg-) zK7~`#kYb4t_+bS5MZk?h#jv|DB}Ks$o{PmnjCIDzkm7Pb?BOYeJ4vJ9F&^vLv@Z7CLmEhDWtTY<-negau(reNSgBEQ?Sy;HE@G&28QH>{~^04*q(x?wn z_zU!e;Pe6dKDD-aZh*t5Ff!ed_G#8VD(Hn-0U!$1Am|OlG{r$P&n2d(-$!v9-;^NU zYqf1hy3!G@xPk8_e4E0Nby*iJ>BvRPv6`Z8Il{3lprR#S#n`%NTKlVWaI!vA>KSH6 z_C;SjamK>8!Nz7BRqF4cvv5@O!$B_s+&@5d4zG0|ZGVDqNr>hC^8Ay_**!9n*@xtl zVAe(j#HI{TV1-8}B}VWEbK?%o*@I1{VBMevs4eJdlPWrR2(}PuaM=f#Uzj`5UST6= zm_+j%{66~s3}4=0TkKz$-{7fz(pj8gYP89djS_#za6A-A0Yl;l(6+*y>>_PfBAMdT zVA6n%5wFzWrGjh&Ve$dVzeWWeTFEV(q~A}YpCb&l%_w6lhn4*W2J`0E#Qw$k&A&A6 z%+4ufIQX(^{|Ey30>E!ufE~~;5)S}2xM6w@tP#wPIN~-ydYgEtlh2r-ec*s?8en;q z+|C_>^&<PtjsBNM%dYtKyk%8zqGfBPIImz`>DCk5BdItvt?Xhl1)*C2Z)Z>mK?h|)zb`&G> zGs)hof0f49I3d$Npz(wWB~+d+`Z~1@Q4|&I>rtpxNspKrd8IWlC!*DO3+3XcE6{#G 
[remaining base85 data of the GIT binary patch for lib/rotary_embedding_torch/__pycache__/__init__.cpython-310.pyc omitted]

diff --git a/lib/rotary_embedding_torch/__pycache__/rotary_embedding_torch.cpython-39.pyc b/lib/rotary_embedding_torch/__pycache__/rotary_embedding_torch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..37902f15d94fc6173a9f62b30d8a8ee38d8e4c6a
GIT binary patch
literal 7411
[base85 data of the compiled rotary_embedding_torch.cpython-39.pyc omitted]
zg#~EdNy}MT*_fl8dc+1XtEE8K@x63x;gsCfgOB%zqt*bfs0B-HeXybnGQA7Y^?+I- z2oocW{P zZ%re-1Ol2$LBEd!g7jE6l-F+}`Lr6h6fvH-^>RxwB0rXFz4{euU(r{o={G1LOemo8 ze9-5q?oW~AC9LbCP^*FzeHm44uh0f&LNppLp&0yp1X>ShYqVPc2gAe^EAqvegBxl- zx4%WFgjyK@iAA|A=K&QiXwt!Rv&!?u168wt~OEjMq1Ny;oA$M5_he$e@yewb_n$Ej#Ve2!ArZZ0vdHvG zAlRQE5Oj#O$~uHzp^XFNv4JGO+$kE7&d4zml~{u$Xd$9aXbVsYX!K{mAXizK$0ao4 zTsf^kV+8aDgcJJWaarQ4vw(GRT_y8QsWOX5Yn(yx`_^W(O`u#PAR2H%%HlG9WvaIg`ExRB){(!Q)aj ziuFUPV`yQaU#B7&p(eJ{E?saAEM-CmP}t4ksILm+<<5u=h97}PUzRm1)W4!f@*EPw z@l24z&%qjhE~WdqH0OUN?RjZ|C@$i{igH!?i>zo0TbCH4hXF}+;He_{F)}H*8pI?w zj6*X8g*XylGu}VI{T)mf$1M38>uyTo!Yv;T8NjbB7@iY1P)lY_H8xC{#X^S8B0wy+ zAs{4-i0BWH4YRfG8fke$WbThZ!M-Qnfh)<(6Av_`d>!B}jNK;8*~FEniLt9NXQCYh z@b0SqDpp&F>35K8OCJ|~(N0o;y(q}QHnpMq!gdYeP#>%}#|6)5fXiuc&U^HgMfEo2 z*b?8O9Gh{BT<*Tv@Ar1VS91rq7q}bZ3sx5G*yZOIBALXu-x&6Xv0uh*);N$e=?0Rn zTyqnok8Rp_>2rvyX1^XsA8{_=WzC2@pCfUhVK^S`zcJ1j)m=md4Vd#HnKq2uMTrQ( z!L3=No*6?wjX}Wr$$&4}F7W{yKepY!M+PD^hBuLlD&B?Vav3an+hiO5CfhLhrv7GO zsc3CM@O)p1I+QNjFopKGB!`{CI$gTUBy8XEPDN(pe#?fKp+;=2D8X9)2G&|tuvY4M z(PCNb0*lnYLCJ4YGQGIczl&nKRs?4p+_ACVqek?puvi2@xDwthV3sep*psuj<{K_H zTIY+5Ru5?u4*O<|lyt_yx)K`3aABe_PnfCceg+qO!BiJvxlKfU7NmBtG%K?)vusya zHHD1&w<)+J%WC@uZM8e1`)MpjqAs`aZ=+4V=FsG`7z2El|7o%k$C_+GlTrBfqOLPi z;Yv8zDfZ~nXI2|RPj-<${jY2Y@L}{P9m*YCN8>+Fnh19hNpN&Q|3l&HYreDKdrscQ s1pwfqqEpJ5eqn#u-kyP2#8@L*{2J^34uIbvI<2txFZF=7S14AF6zW@LL literal 0 HcmV?d00001 diff --git a/lib/rotary_embedding_torch/rotary_embedding_torch.py b/lib/rotary_embedding_torch/rotary_embedding_torch.py new file mode 100644 index 0000000..8937875 --- /dev/null +++ b/lib/rotary_embedding_torch/rotary_embedding_torch.py @@ -0,0 +1,291 @@ +from __future__ import annotations +from math import pi, log + +import torch +from torch.nn import Module, ModuleList +from torch.cuda.amp import autocast +from torch import nn, einsum, broadcast_tensors, Tensor + +from einops import rearrange, repeat + +from typing import Literal + +# helper functions + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +# broadcat, as tortoise-tts was using it + +def broadcat(tensors, dim = -1): + broadcasted_tensors = broadcast_tensors(*tensors) + return torch.cat(broadcasted_tensors, dim = dim) + +# rotary embedding helper functions + +def rotate_half(x): + x = rearrange(x, '... (d r) -> ... d r', r = 2) + x1, x2 = x.unbind(dim = -1) + x = torch.stack((-x2, x1), dim = -1) + return rearrange(x, '... d r -> ... (d r)') + +@autocast(enabled = False) +def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2): + dtype = t.dtype + + if t.ndim == 3: + seq_len = t.shape[seq_dim] + freqs = freqs[-seq_len:] + + rot_dim = freqs.shape[-1] + end_index = start_index + rot_dim + + assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' + + t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + out = torch.cat((t_left, t, t_right), dim = -1) + + return out.type(dtype) + +# learned rotation helpers + +def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None): + if exists(freq_ranges): + rotations = einsum('..., f -> ... f', rotations, freq_ranges) + rotations = rearrange(rotations, '... r f -> ... (r f)') + + rotations = repeat(rotations, '... n -> ... 
(n r)', r = 2) + return apply_rotary_emb(rotations, t, start_index = start_index) + +# classes + +class RotaryEmbedding(Module): + def __init__( + self, + dim, + custom_freqs: Tensor | None = None, + freqs_for: Literal['lang', 'pixel', 'constant'] = 'lang', + theta = 10000, + max_freq = 10, + num_freqs = 1, + learned_freq = False, + use_xpos = False, + xpos_scale_base = 512, + interpolate_factor = 1., + theta_rescale_factor = 1., + seq_before_head_dim = False, + cache_if_possible = True + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + self.freqs_for = freqs_for + + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + + self.cache_if_possible = cache_if_possible + + self.tmp_store('cached_freqs', None) + self.tmp_store('cached_scales', None) + + self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) + + self.learned_freq = learned_freq + + # dummy for device + + self.tmp_store('dummy', torch.tensor(0)) + + # default sequence dimension + + self.seq_before_head_dim = seq_before_head_dim + self.default_seq_dim = -3 if seq_before_head_dim else -2 + + # interpolation factors + + assert interpolate_factor >= 1. + self.interpolate_factor = interpolate_factor + + # xpos + + self.use_xpos = use_xpos + if not use_xpos: + self.tmp_store('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.scale_base = xpos_scale_base + self.tmp_store('scale', scale) + + # add apply_rotary_emb as static method + + self.apply_rotary_emb = staticmethod(apply_rotary_emb) + + @property + def device(self): + return self.dummy.device + + def tmp_store(self, key, value): + self.register_buffer(key, value, persistent = False) + + def get_seq_pos(self, seq_len, device, dtype, offset = 0): + return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor + + def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, scale = None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert not self.use_xpos or exists(scale), 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings' + + device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] + + seq = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset) + + freqs = self.forward(seq, seq_len = seq_len, offset = offset) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + + return apply_rotary_emb(freqs, t, scale = default(scale, 1.), seq_dim = seq_dim) + + def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0): + dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim) + + q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] + assert q_len <= k_len + + q_scale = k_scale = 1. 
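+        # The branch below applies the xpos (length-extrapolatable) scaling:
+        # queries and keys receive reciprocal per-position scales, and the query
+        # positions are taken from the tail of the key positions so that keys
+        # cached from earlier steps stay aligned with the freshly rotated queries.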
+ + if self.use_xpos: + seq = self.get_seq_pos(k_len, dtype = dtype, device = device) + + q_scale = self.get_scale(seq[-q_len:]).type(dtype) + k_scale = self.get_scale(seq).type(dtype) + + rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, scale = q_scale, offset = k_len - q_len + offset) + rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim, scale = k_scale ** -1) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def rotate_queries_and_keys(self, q, k, seq_dim = None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert self.use_xpos + device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] + + seq = self.get_seq_pos(seq_len, dtype = dtype, device = device) + + freqs = self.forward(seq, seq_len = seq_len) + scale = self.get_scale(seq, seq_len = seq_len).to(dtype) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + scale = rearrange(scale, 'n d -> n 1 d') + + rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim) + rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def get_scale( + self, + t: Tensor, + seq_len: int | None = None, + offset = 0 + ): + assert self.use_xpos + + should_cache = ( + self.cache_if_possible and + exists(seq_len) + ) + + if ( + should_cache and \ + exists(self.cached_scales) and \ + (seq_len + offset) <= self.cached_scales.shape[0] + ): + return self.cached_scales[offset:(offset + seq_len)] + + scale = 1. + if self.use_xpos: + power = (t - len(t) // 2) / self.scale_base + scale = self.scale ** rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim = -1) + + if should_cache: + self.tmp_store('cached_scales', scale) + + return scale + + def get_axial_freqs(self, *dims): + Colon = slice(None) + all_freqs = [] + + for ind, dim in enumerate(dims): + if self.freqs_for == 'pixel': + pos = torch.linspace(-1, 1, steps = dim, device = self.device) + else: + pos = torch.arange(dim, device = self.device) + + freqs = self.forward(pos, seq_len = dim) + + all_axis = [None] * len(dims) + all_axis[ind] = Colon + + new_axis_slice = (Ellipsis, *all_axis, Colon) + all_freqs.append(freqs[new_axis_slice]) + + all_freqs = broadcast_tensors(*all_freqs) + return torch.cat(all_freqs, dim = -1) + + @autocast(enabled = False) + def forward( + self, + t: Tensor, + seq_len = None, + offset = 0 + ): + should_cache = ( + self.cache_if_possible and \ + not self.learned_freq and \ + exists(seq_len) and \ + self.freqs_for != 'pixel' + ) + + if ( + should_cache and \ + exists(self.cached_freqs) and \ + (offset + seq_len) <= self.cached_freqs.shape[0] + ): + return self.cached_freqs[offset:(offset + seq_len)].detach() + + freqs = self.freqs + + freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs) + freqs = repeat(freqs, '... n -> ... (n r)', r = 2) + + if should_cache: + self.tmp_store('cached_freqs', freqs.detach()) + + return freqs diff --git a/lib/simplejson/__init__.py b/lib/simplejson/__init__.py new file mode 100644 index 0000000..807b9da --- /dev/null +++ b/lib/simplejson/__init__.py @@ -0,0 +1,562 @@ +r"""JSON (JavaScript Object Notation) is a subset of +JavaScript syntax (ECMA-262 3rd edition) used as a lightweight data +interchange format. + +:mod:`simplejson` exposes an API familiar to users of the standard library +:mod:`marshal` and :mod:`pickle` modules. 
It is the externally maintained +version of the :mod:`json` library contained in Python 2.6, but maintains +compatibility back to Python 2.5 and (currently) has significant performance +advantages, even without using the optional C extension for speedups. + +Encoding basic Python object hierarchies:: + + >>> import simplejson as json + >>> json.dumps(['foo', {'bar': ('baz', None, 1.0, 2)}]) + '["foo", {"bar": ["baz", null, 1.0, 2]}]' + >>> print(json.dumps("\"foo\bar")) + "\"foo\bar" + >>> print(json.dumps(u'\u1234')) + "\u1234" + >>> print(json.dumps('\\')) + "\\" + >>> print(json.dumps({"c": 0, "b": 0, "a": 0}, sort_keys=True)) + {"a": 0, "b": 0, "c": 0} + >>> from simplejson.compat import StringIO + >>> io = StringIO() + >>> json.dump(['streaming API'], io) + >>> io.getvalue() + '["streaming API"]' + +Compact encoding:: + + >>> import simplejson as json + >>> obj = [1,2,3,{'4': 5, '6': 7}] + >>> json.dumps(obj, separators=(',',':'), sort_keys=True) + '[1,2,3,{"4":5,"6":7}]' + +Pretty printing:: + + >>> import simplejson as json + >>> print(json.dumps({'4': 5, '6': 7}, sort_keys=True, indent=' ')) + { + "4": 5, + "6": 7 + } + +Decoding JSON:: + + >>> import simplejson as json + >>> obj = [u'foo', {u'bar': [u'baz', None, 1.0, 2]}] + >>> json.loads('["foo", {"bar":["baz", null, 1.0, 2]}]') == obj + True + >>> json.loads('"\\"foo\\bar"') == u'"foo\x08ar' + True + >>> from simplejson.compat import StringIO + >>> io = StringIO('["streaming API"]') + >>> json.load(io)[0] == 'streaming API' + True + +Specializing JSON object decoding:: + + >>> import simplejson as json + >>> def as_complex(dct): + ... if '__complex__' in dct: + ... return complex(dct['real'], dct['imag']) + ... return dct + ... + >>> json.loads('{"__complex__": true, "real": 1, "imag": 2}', + ... object_hook=as_complex) + (1+2j) + >>> from decimal import Decimal + >>> json.loads('1.1', parse_float=Decimal) == Decimal('1.1') + True + +Specializing JSON object encoding:: + + >>> import simplejson as json + >>> def encode_complex(obj): + ... if isinstance(obj, complex): + ... return [obj.real, obj.imag] + ... raise TypeError('Object of type %s is not JSON serializable' % + ... obj.__class__.__name__) + ... + >>> json.dumps(2 + 1j, default=encode_complex) + '[2.0, 1.0]' + >>> json.JSONEncoder(default=encode_complex).encode(2 + 1j) + '[2.0, 1.0]' + >>> ''.join(json.JSONEncoder(default=encode_complex).iterencode(2 + 1j)) + '[2.0, 1.0]' + +Using simplejson.tool from the shell to validate and pretty-print:: + + $ echo '{"json":"obj"}' | python -m simplejson.tool + { + "json": "obj" + } + $ echo '{ 1.2:3.4}' | python -m simplejson.tool + Expecting property name: line 1 column 3 (char 2) + +Parsing multiple documents serialized as JSON lines (newline-delimited JSON):: + + >>> import simplejson as json + >>> def loads_lines(docs): + ... for doc in docs.splitlines(): + ... yield json.loads(doc) + ... + >>> sum(doc["count"] for doc in loads_lines('{"count":1}\n{"count":2}\n{"count":3}\n')) + 6 + +Serializing multiple objects to JSON lines (newline-delimited JSON):: + + >>> import simplejson as json + >>> def dumps_lines(objs): + ... for obj in objs: + ... yield json.dumps(obj, separators=(',',':')) + '\n' + ... 
+ >>> ''.join(dumps_lines([{'count': 1}, {'count': 2}, {'count': 3}])) + '{"count":1}\n{"count":2}\n{"count":3}\n' + +""" +from __future__ import absolute_import +__version__ = '3.19.2' +__all__ = [ + 'dump', 'dumps', 'load', 'loads', + 'JSONDecoder', 'JSONDecodeError', 'JSONEncoder', + 'OrderedDict', 'simple_first', 'RawJSON' +] + +__author__ = 'Bob Ippolito ' + +from decimal import Decimal + +from .errors import JSONDecodeError +from .raw_json import RawJSON +from .decoder import JSONDecoder +from .encoder import JSONEncoder, JSONEncoderForHTML +def _import_OrderedDict(): + import collections + try: + return collections.OrderedDict + except AttributeError: + from . import ordered_dict + return ordered_dict.OrderedDict +OrderedDict = _import_OrderedDict() + +def _import_c_make_encoder(): + try: + from ._speedups import make_encoder + return make_encoder + except ImportError: + return None + +_default_encoder = JSONEncoder() + +def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True, + allow_nan=False, cls=None, indent=None, separators=None, + encoding='utf-8', default=None, use_decimal=True, + namedtuple_as_object=True, tuple_as_array=True, + bigint_as_string=False, sort_keys=False, item_sort_key=None, + for_json=False, ignore_nan=False, int_as_string_bitcount=None, + iterable_as_array=False, **kw): + """Serialize ``obj`` as a JSON formatted stream to ``fp`` (a + ``.write()``-supporting file-like object). + + If *skipkeys* is true then ``dict`` keys that are not basic types + (``str``, ``int``, ``long``, ``float``, ``bool``, ``None``) + will be skipped instead of raising a ``TypeError``. + + If *ensure_ascii* is false (default: ``True``), then the output may + contain non-ASCII characters, so long as they do not need to be escaped + by JSON. When it is true, all non-ASCII characters are escaped. + + If *allow_nan* is true (default: ``False``), then out of range ``float`` + values (``nan``, ``inf``, ``-inf``) will be serialized to + their JavaScript equivalents (``NaN``, ``Infinity``, ``-Infinity``) + instead of raising a ValueError. See + *ignore_nan* for ECMA-262 compliant behavior. + + If *indent* is a string, then JSON array elements and object members + will be pretty-printed with a newline followed by that string repeated + for each level of nesting. ``None`` (the default) selects the most compact + representation without any newlines. + + If specified, *separators* should be an + ``(item_separator, key_separator)`` tuple. The default is ``(', ', ': ')`` + if *indent* is ``None`` and ``(',', ': ')`` otherwise. To get the most + compact JSON representation, you should specify ``(',', ':')`` to eliminate + whitespace. + + *encoding* is the character encoding for str instances, default is UTF-8. + + *default(obj)* is a function that should return a serializable version + of obj or raise ``TypeError``. The default simply raises ``TypeError``. + + If *use_decimal* is true (default: ``True``) then decimal.Decimal + will be natively serialized to JSON with full precision. + + If *namedtuple_as_object* is true (default: ``True``), + :class:`tuple` subclasses with ``_asdict()`` methods will be encoded + as JSON objects. + + If *tuple_as_array* is true (default: ``True``), + :class:`tuple` (and subclasses) will be encoded as JSON arrays. + + If *iterable_as_array* is true (default: ``False``), + any object not in the above table that implements ``__iter__()`` + will be encoded as a JSON array. 
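+    (For instance, with this flag a generator such as ``(i for i in range(3))``
+    would be written out as the JSON array ``[0, 1, 2]``.)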
+ + If *bigint_as_string* is true (default: ``False``), ints 2**53 and higher + or lower than -2**53 will be encoded as strings. This is to avoid the + rounding that happens in Javascript otherwise. Note that this is still a + lossy operation that will not round-trip correctly and should be used + sparingly. + + If *int_as_string_bitcount* is a positive number (n), then int of size + greater than or equal to 2**n or lower than or equal to -2**n will be + encoded as strings. + + If specified, *item_sort_key* is a callable used to sort the items in + each dictionary. This is useful if you want to sort items other than + in alphabetical order by key. This option takes precedence over + *sort_keys*. + + If *sort_keys* is true (default: ``False``), the output of dictionaries + will be sorted by item. + + If *for_json* is true (default: ``False``), objects with a ``for_json()`` + method will use the return value of that method for encoding as JSON + instead of the object. + + If *ignore_nan* is true (default: ``False``), then out of range + :class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized as + ``null`` in compliance with the ECMA-262 specification. If true, this will + override *allow_nan*. + + To use a custom ``JSONEncoder`` subclass (e.g. one that overrides the + ``.default()`` method to serialize additional types), specify it with + the ``cls`` kwarg. NOTE: You should use *default* or *for_json* instead + of subclassing whenever possible. + + """ + # cached encoder + if (not skipkeys and ensure_ascii and + check_circular and not allow_nan and + cls is None and indent is None and separators is None and + encoding == 'utf-8' and default is None and use_decimal + and namedtuple_as_object and tuple_as_array and not iterable_as_array + and not bigint_as_string and not sort_keys + and not item_sort_key and not for_json + and not ignore_nan and int_as_string_bitcount is None + and not kw + ): + iterable = _default_encoder.iterencode(obj) + else: + if cls is None: + cls = JSONEncoder + iterable = cls(skipkeys=skipkeys, ensure_ascii=ensure_ascii, + check_circular=check_circular, allow_nan=allow_nan, indent=indent, + separators=separators, encoding=encoding, + default=default, use_decimal=use_decimal, + namedtuple_as_object=namedtuple_as_object, + tuple_as_array=tuple_as_array, + iterable_as_array=iterable_as_array, + bigint_as_string=bigint_as_string, + sort_keys=sort_keys, + item_sort_key=item_sort_key, + for_json=for_json, + ignore_nan=ignore_nan, + int_as_string_bitcount=int_as_string_bitcount, + **kw).iterencode(obj) + # could accelerate with writelines in some versions of Python, at + # a debuggability cost + for chunk in iterable: + fp.write(chunk) + + +def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True, + allow_nan=False, cls=None, indent=None, separators=None, + encoding='utf-8', default=None, use_decimal=True, + namedtuple_as_object=True, tuple_as_array=True, + bigint_as_string=False, sort_keys=False, item_sort_key=None, + for_json=False, ignore_nan=False, int_as_string_bitcount=None, + iterable_as_array=False, **kw): + """Serialize ``obj`` to a JSON formatted ``str``. + + If ``skipkeys`` is true then ``dict`` keys that are not basic types + (``str``, ``int``, ``long``, ``float``, ``bool``, ``None``) + will be skipped instead of raising a ``TypeError``. + + If *ensure_ascii* is false (default: ``True``), then the output may + contain non-ASCII characters, so long as they do not need to be escaped + by JSON. 
When it is true, all non-ASCII characters are escaped. + + If ``check_circular`` is false, then the circular reference check + for container types will be skipped and a circular reference will + result in an ``OverflowError`` (or worse). + + If *allow_nan* is true (default: ``False``), then out of range ``float`` + values (``nan``, ``inf``, ``-inf``) will be serialized to + their JavaScript equivalents (``NaN``, ``Infinity``, ``-Infinity``) + instead of raising a ValueError. See + *ignore_nan* for ECMA-262 compliant behavior. + + If ``indent`` is a string, then JSON array elements and object members + will be pretty-printed with a newline followed by that string repeated + for each level of nesting. ``None`` (the default) selects the most compact + representation without any newlines. For backwards compatibility with + versions of simplejson earlier than 2.1.0, an integer is also accepted + and is converted to a string with that many spaces. + + If specified, ``separators`` should be an + ``(item_separator, key_separator)`` tuple. The default is ``(', ', ': ')`` + if *indent* is ``None`` and ``(',', ': ')`` otherwise. To get the most + compact JSON representation, you should specify ``(',', ':')`` to eliminate + whitespace. + + ``encoding`` is the character encoding for bytes instances, default is + UTF-8. + + ``default(obj)`` is a function that should return a serializable version + of obj or raise TypeError. The default simply raises TypeError. + + If *use_decimal* is true (default: ``True``) then decimal.Decimal + will be natively serialized to JSON with full precision. + + If *namedtuple_as_object* is true (default: ``True``), + :class:`tuple` subclasses with ``_asdict()`` methods will be encoded + as JSON objects. + + If *tuple_as_array* is true (default: ``True``), + :class:`tuple` (and subclasses) will be encoded as JSON arrays. + + If *iterable_as_array* is true (default: ``False``), + any object not in the above table that implements ``__iter__()`` + will be encoded as a JSON array. + + If *bigint_as_string* is true (not the default), ints 2**53 and higher + or lower than -2**53 will be encoded as strings. This is to avoid the + rounding that happens in Javascript otherwise. + + If *int_as_string_bitcount* is a positive number (n), then int of size + greater than or equal to 2**n or lower than or equal to -2**n will be + encoded as strings. + + If specified, *item_sort_key* is a callable used to sort the items in + each dictionary. This is useful if you want to sort items other than + in alphabetical order by key. This option takes precedence over + *sort_keys*. + + If *sort_keys* is true (default: ``False``), the output of dictionaries + will be sorted by item. + + If *for_json* is true (default: ``False``), objects with a ``for_json()`` + method will use the return value of that method for encoding as JSON + instead of the object. + + If *ignore_nan* is true (default: ``False``), then out of range + :class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized as + ``null`` in compliance with the ECMA-262 specification. If true, this will + override *allow_nan*. + + To use a custom ``JSONEncoder`` subclass (e.g. one that overrides the + ``.default()`` method to serialize additional types), specify it with + the ``cls`` kwarg. NOTE: You should use *default* instead of subclassing + whenever possible. 
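+
+    For example, the most compact encoding (mirroring the module-level
+    example above) is obtained with explicit separators::
+
+        >>> import simplejson as json
+        >>> json.dumps([1, 2, 3, {'4': 5, '6': 7}], separators=(',', ':'), sort_keys=True)
+        '[1,2,3,{"4":5,"6":7}]'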
+ + """ + # cached encoder + if (not skipkeys and ensure_ascii and + check_circular and not allow_nan and + cls is None and indent is None and separators is None and + encoding == 'utf-8' and default is None and use_decimal + and namedtuple_as_object and tuple_as_array and not iterable_as_array + and not bigint_as_string and not sort_keys + and not item_sort_key and not for_json + and not ignore_nan and int_as_string_bitcount is None + and not kw + ): + return _default_encoder.encode(obj) + if cls is None: + cls = JSONEncoder + return cls( + skipkeys=skipkeys, ensure_ascii=ensure_ascii, + check_circular=check_circular, allow_nan=allow_nan, indent=indent, + separators=separators, encoding=encoding, default=default, + use_decimal=use_decimal, + namedtuple_as_object=namedtuple_as_object, + tuple_as_array=tuple_as_array, + iterable_as_array=iterable_as_array, + bigint_as_string=bigint_as_string, + sort_keys=sort_keys, + item_sort_key=item_sort_key, + for_json=for_json, + ignore_nan=ignore_nan, + int_as_string_bitcount=int_as_string_bitcount, + **kw).encode(obj) + + +_default_decoder = JSONDecoder() + + +def load(fp, encoding=None, cls=None, object_hook=None, parse_float=None, + parse_int=None, parse_constant=None, object_pairs_hook=None, + use_decimal=False, allow_nan=False, **kw): + """Deserialize ``fp`` (a ``.read()``-supporting file-like object containing + a JSON document as `str` or `bytes`) to a Python object. + + *encoding* determines the encoding used to interpret any + `bytes` objects decoded by this instance (``'utf-8'`` by + default). It has no effect when decoding `str` objects. + + *object_hook*, if specified, will be called with the result of every + JSON object decoded and its return value will be used in place of the + given :class:`dict`. This can be used to provide custom + deserializations (e.g. to support JSON-RPC class hinting). + + *object_pairs_hook* is an optional function that will be called with + the result of any object literal decode with an ordered list of pairs. + The return value of *object_pairs_hook* will be used instead of the + :class:`dict`. This feature can be used to implement custom decoders + that rely on the order that the key and value pairs are decoded (for + example, :func:`collections.OrderedDict` will remember the order of + insertion). If *object_hook* is also defined, the *object_pairs_hook* + takes priority. + + *parse_float*, if specified, will be called with the string of every + JSON float to be decoded. By default, this is equivalent to + ``float(num_str)``. This can be used to use another datatype or parser + for JSON floats (e.g. :class:`decimal.Decimal`). + + *parse_int*, if specified, will be called with the string of every + JSON int to be decoded. By default, this is equivalent to + ``int(num_str)``. This can be used to use another datatype or parser + for JSON integers (e.g. :class:`float`). + + *allow_nan*, if True (default false), will allow the parser to + accept the non-standard floats ``NaN``, ``Infinity``, and ``-Infinity`` + and enable the use of the deprecated *parse_constant*. + + If *use_decimal* is true (default: ``False``) then it implies + parse_float=decimal.Decimal for parity with ``dump``. + + *parse_constant*, if specified, will be + called with one of the following strings: ``'-Infinity'``, + ``'Infinity'``, ``'NaN'``. It is not recommended to use this feature, + as it is rare to parse non-compliant JSON containing these values. + + To use a custom ``JSONDecoder`` subclass, specify it with the ``cls`` + kwarg. 
NOTE: You should use *object_hook* or *object_pairs_hook* instead + of subclassing whenever possible. + + """ + return loads(fp.read(), + encoding=encoding, cls=cls, object_hook=object_hook, + parse_float=parse_float, parse_int=parse_int, + parse_constant=parse_constant, object_pairs_hook=object_pairs_hook, + use_decimal=use_decimal, allow_nan=allow_nan, **kw) + + +def loads(s, encoding=None, cls=None, object_hook=None, parse_float=None, + parse_int=None, parse_constant=None, object_pairs_hook=None, + use_decimal=False, allow_nan=False, **kw): + """Deserialize ``s`` (a ``str`` or ``unicode`` instance containing a JSON + document) to a Python object. + + *encoding* determines the encoding used to interpret any + :class:`bytes` objects decoded by this instance (``'utf-8'`` by + default). It has no effect when decoding :class:`unicode` objects. + + *object_hook*, if specified, will be called with the result of every + JSON object decoded and its return value will be used in place of the + given :class:`dict`. This can be used to provide custom + deserializations (e.g. to support JSON-RPC class hinting). + + *object_pairs_hook* is an optional function that will be called with + the result of any object literal decode with an ordered list of pairs. + The return value of *object_pairs_hook* will be used instead of the + :class:`dict`. This feature can be used to implement custom decoders + that rely on the order that the key and value pairs are decoded (for + example, :func:`collections.OrderedDict` will remember the order of + insertion). If *object_hook* is also defined, the *object_pairs_hook* + takes priority. + + *parse_float*, if specified, will be called with the string of every + JSON float to be decoded. By default, this is equivalent to + ``float(num_str)``. This can be used to use another datatype or parser + for JSON floats (e.g. :class:`decimal.Decimal`). + + *parse_int*, if specified, will be called with the string of every + JSON int to be decoded. By default, this is equivalent to + ``int(num_str)``. This can be used to use another datatype or parser + for JSON integers (e.g. :class:`float`). + + *allow_nan*, if True (default false), will allow the parser to + accept the non-standard floats ``NaN``, ``Infinity``, and ``-Infinity`` + and enable the use of the deprecated *parse_constant*. + + If *use_decimal* is true (default: ``False``) then it implies + parse_float=decimal.Decimal for parity with ``dump``. + + *parse_constant*, if specified, will be + called with one of the following strings: ``'-Infinity'``, + ``'Infinity'``, ``'NaN'``. It is not recommended to use this feature, + as it is rare to parse non-compliant JSON containing these values. + + To use a custom ``JSONDecoder`` subclass, specify it with the ``cls`` + kwarg. NOTE: You should use *object_hook* or *object_pairs_hook* instead + of subclassing whenever possible. 
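+
+    For example (as in the module-level docstring)::
+
+        >>> import simplejson as json
+        >>> json.loads('["foo", {"bar":["baz", null, 1.0, 2]}]') == ['foo', {'bar': ['baz', None, 1.0, 2]}]
+        True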
+ + """ + if (cls is None and encoding is None and object_hook is None and + parse_int is None and parse_float is None and + parse_constant is None and object_pairs_hook is None + and not use_decimal and not allow_nan and not kw): + return _default_decoder.decode(s) + if cls is None: + cls = JSONDecoder + if object_hook is not None: + kw['object_hook'] = object_hook + if object_pairs_hook is not None: + kw['object_pairs_hook'] = object_pairs_hook + if parse_float is not None: + kw['parse_float'] = parse_float + if parse_int is not None: + kw['parse_int'] = parse_int + if parse_constant is not None: + kw['parse_constant'] = parse_constant + if use_decimal: + if parse_float is not None: + raise TypeError("use_decimal=True implies parse_float=Decimal") + kw['parse_float'] = Decimal + if allow_nan: + kw['allow_nan'] = True + return cls(encoding=encoding, **kw).decode(s) + + +def _toggle_speedups(enabled): + from . import decoder as dec + from . import encoder as enc + from . import scanner as scan + c_make_encoder = _import_c_make_encoder() + if enabled: + dec.scanstring = dec.c_scanstring or dec.py_scanstring + enc.c_make_encoder = c_make_encoder + enc.encode_basestring_ascii = (enc.c_encode_basestring_ascii or + enc.py_encode_basestring_ascii) + scan.make_scanner = scan.c_make_scanner or scan.py_make_scanner + else: + dec.scanstring = dec.py_scanstring + enc.c_make_encoder = None + enc.encode_basestring_ascii = enc.py_encode_basestring_ascii + scan.make_scanner = scan.py_make_scanner + dec.make_scanner = scan.make_scanner + global _default_decoder + _default_decoder = JSONDecoder() + global _default_encoder + _default_encoder = JSONEncoder() + +def simple_first(kv): + """Helper function to pass to item_sort_key to sort simple + elements to the top, then container elements. 
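+
+    For example (a sketch of the intended usage, combining it with the
+    ``item_sort_key`` option documented above)::
+
+        >>> from simplejson import dumps, simple_first
+        >>> dumps({'b': [1], 'a': 1}, item_sort_key=simple_first)
+        '{"a": 1, "b": [1]}'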
+ """ + return (isinstance(kv[1], (list, dict, tuple)), kv[0]) diff --git a/lib/simplejson/_speedups.cp39-win_amd64.pyd b/lib/simplejson/_speedups.cp39-win_amd64.pyd new file mode 100644 index 0000000000000000000000000000000000000000..f6e8451da56388ea519cb7f6dad3bac866cc9816 GIT binary patch literal 39936 zcmeHw3w%_?_5a-@n}m?%M$Kvjlm&xC1q=i=7|~slz+Kotgdix7BqSRW$;;il5)?JK zi4d;~sjt$9H9S=GXDiiK5v@%CA)rV~eW27DAFUgsHntW6KJNc}X6Ej0mWciR|G(eo z|LW#*cV^BxbLPyMGiRQ6Q*cA86e~$m93EYlq@4)qV`tw#@MV&uVWV~plYTSg(CD2e z*P+oxl~rC_y{B%4r?kdaR$5zIr`ndgZ605(t*X|RH+R0RrmoyQ(PBwT7p&hrcJsUU z?n>KWJU`p5ZFmsr3H#37*oyG&-8`JI@4St_VDOO}f6BrU8-IlG(!=Mn@ZMeLZul1I z+51yA-p9f;5q`XZr-!rjysENFD$A%@VZJ1l-xV(%o;AvC!1|;ywxKb@E|8W1V&~At zbpt@k#3x270V&U91gEAs#J*mVN|2`wVFP7um!y6QcSusB z;73WROOpB!lr&1xNJ@Xt$94xqXMzfXdgP&o{+`c7)!m>1r!2)+3LdI^WVuM%z+$4O zyi_ffq_nXBfTwgJ9@0(3$Bv>Wax%##1JQ;j2@mQP`LUyx6YDu4tK$w}Nul>fQn4nm zPxN>_WdJjsv?BuDOPlaSeC$-N$6Z|q!ajm`;UT)8QX>4ZOVadk>i_SU(y}ByaiQY> z)UIf6D=kIo*}AnI?Wt(<6BVuht2?)oei9`~ovvu}vRzv0Sq05U)XVa-4k!UoJ9i2r zYgDvX9QqN!U0S=M{ZY{l>(adidq(z6ZHp#w=@ z6(6NVO-~{DY>MW9L?;xjK1FFhE;l_yto&0YszIhNCF_80omm1ZX%O%U)B-bdJl>PG+psuV)6|^9dzp-gd~|eb$k>}g^>Qew&FdC zCPP}?nn=ZfUh5Y*9yW4l`v;ICXI4AUYt+X+pV#GWMRO+V*54Jg;)n90P|9MIlAEZk zZAfXm%Ou^F$nz8Q6=Jlwr}sNuH<%$b_`RcBzhJeGRs8+B?`++A5J@yrbGv+hht|O| z?_rI_qlSL9A&NfsLC|slvj?T6Iz5FnpQf~|Nw+z!BbhaaO$)lTn^HPHi}C*@VY^LD z-ERMVri3v=!IzxZuf(3{jZ?JMO4f1h&Ah;jn7}eq0h(&V?|^qcFI~YsN=~oORUwaR z0ub4rO};YV=)P*%D!K8*5_NRri7~2K(f0I?q~w`Mu8MccTMu{=c9Li4HSeSj?TT$z z^xe*!H`n~Wx6H9CnHY^~#JA-CHpUk(KbO3)H&#pjF*y7C!P*aAm-xZ@LSlU(B}XDD zSf8=>gd~|HR)AB^-F-?SPwav>m7H3qZjD2OxwKAFYts>vBx!aVG?B91uErMx9GSgI z_*Q!33N(k}*e8u))!za-SHSwCPj$W3|E&q~J)3%U?=T^1>MDqu`dnZ-6ZI_r$tCJA z|H(1xkm>7FV|x}N@AQ3fu}cEe`xSkUzrSeJwGR2&rWlu1k{sGScV;4V#vat)XXF`z ze0k6LB399pDq^)gNh0RyL3J5hwdXi$w)P~|U%vBh#Hldf+TAo8&Fa zSz3Eg`V;Ja?ExA{l#ZSx1yb zU=TNbUtX+viQ?ZA&m_@yQyIvk9UQ>W#1N;J!@zN;lmUKJc0^Heu~4U??{nrHm+z@S zr=sGV=LxAM+rve+^`g=<5^>`D^&Y>D(W_qpc?=E$@1Nr1uO(hpKn;IxwdN@Cfe}aC(JG z^8t0DZas4`T9M3m3TyywK1B)4wi#}ls!rM96DAX#X8mvpf}GuQQ#vC@7w7;4`J2xo zK+btT1)bdVB_u%NcfiR_A5%gzC$g+Z<)*h-dUoPhYuRW5MX4pgH!^&uYG-3WXR!VR zxm)tma?_7l7L;}ti2-Me!AN-9NtBH|tsyLPj-eDpfA0o3f;&`69gb08k&<&zeOI?8 zAzO314+E3A4m|yLR7$E<(Mix^R@D$Y7kS`-bu^|x7(wm?vjWYYrf5TS>t~>ZOAZw_ zu;`v)lPL?0ePNQG=K&10BDG)If`tw?q-CMH?+CK@37xGVAk$e;=LG z`l?YhbDZ4tGs@te|&L#zivOY6dZw8faT}-sFubHr%9GDv8 z``SW7B=yn)jj0#a|<_d=NbrK(bi0s0&|8R7-1?F_bG{<@N;=KNEHBUCY#s`4FtB zgDM{u?r9`haVAkZ%)mEDHT>9!dY|C}3p5XkEH>7!xg|k8@aZZ|34Xp#ohk zu5?=-6{;YXxZt{ljv~hb$91d&7Ajhg6I05-EEM?(5oBulwmhxRr7cKtbbKD;@c%Wz z)pCFOCbSQq@Vpc;;ptOi|49=bpOT;T5p`^Yx+~|1&#MHQ(p#xbTkheZW?6h6O#;kQ zSnuXF9!c?ldG&Fs@gY%F?K-L(Gqt8K%oMlSkd;+;6sZBtFE_u%Fqo{S1JrYXQs`Cd zA4dcN$oI4%l64GP9r=;2JGE~Maz4538hMmma`?Yv6OfTt_sWQ~$P7yqm#AmwL(fC}UyqXSO2hE4==+^HC)d18ctDrP&A$ZHne(l>vY=U4uPSK%$d^i< ztFwq4MK|lazb6vyEvNRXZaoS(^{)RQ7_(^fblqBon!2?8PVEimQEpj4N#67C1*J>l zdPgM|eh8-!s53zVxoI?3rOupBDXvrz086dOXZ`w=V94rLv4 z`d`L0?_FqC4qOh_1v!W0z|FuafoYF^pzE08<_Q54Fz-km)_Vq1I1d&WxNNYXUvC~DSx65{1&wL1d~;+?R8z$ZvE{Ul(*7&j@R_&8R4+y;L?8gq|9 zmI$)ENv7Vjn3#zWGjC2zuX(zbw;|5o?~|K<0!pX9BL;oNm2*%I+-2na-ugK@x>3P- zsNhLdFh*4HoCaRON~41LLwENMH7b}lv>i2M^>b+LqJn?I5HNi09w|fD9f3wEUDrE4 zOd5I=5KLTqV+*w2fu=ZZUf2|Ao@1C|0pMhcHiCss@z3+2arCBT2n`z?29l5sYEQC$ zoToLVx1rWOZQqbZ8lC(_(b>_h(ZyS+2M5es{vPhu=){XR64aUjq;7o_V(?DR{Mj{x zBY$=^%8&5xGbxE8X*gB(V@9s)R82WT@^p}(1IkUWvKoy~hjZwlD0RF9AbXxcO40Jt zGntPOB8NohBL4z?4LT|ge3blScvL~!UZpC4ichg98=RrzHHLf z*@n1IxdjZ6v9oRqebj0TP@riXPiTGXTnWp6fax4=fL@r6nmZ|S@4{LXEdHm@JzEo@W>^_fJ7u>usA0+0Tmdu|GErHZK9(UJ%GV3RnAnisL+JFLAT0tDI?9a)lp(n 
z`yzNE?VA)a1Ir{bngxzy`&FY^E-(;_qI%d&33jVm7GcfOue8*|(ij#+6A8y^md6Xg zCSbh*z9_W5v49tww}AnhL-GA37h7w-=rxRY+Dg zb@9nrI<3$Mi$w+7z6x&s47G;X;g#G9mLZZ3tVeon$EQP~7FM}*2#M;Wdy4(9#l%Ik zeClGtKZFhx621`Z_@pyD0gPNPzpYyzpgL=_Q(ysD4P%e%R!o-p3gI1&N;42?2BL!_ zG7Ln9f!NLwnFa!N<*YVxM7Dvz{vk))%n^12VK)#j7a^4MAoC!s-};Tm}fMTnKUjAgop}(zg_CAFl)qm2b8xdN=og zFfRi}7CE3jsUM>z(Jh%mwxRKwkj_Fa`PFQOIJG7J%n__!lq1xVFLT7HE%_uzoZ6E2 zam1-Dc?(CdIuB^cQij0p1Gin#@kAT4^$YFGBMa6q^Raahv3|+ZmfKi+?sgpCYg5n0 zxBUfLq3lxdCFi}P#2%#;$_vD8_8spSobND|>P+ESI-bPKIV_7r{ z6G-gwVg`_%ul36h?52fNCJLY8(1NN=bNr9Mn1=c6UMv+yIcbU%WD|Zj?9QQm>CE{~ z-uPoGd{iU8;ivd|BRQuRD_6eWNWPUeXJ7E2u&Lz+J}v?uNBFoG{QX1Jywl>7Xy2am zA2h#3ZVr+Rw1D|e-cXOR0<)7La1%7$U^5nDBBWV%HG?u>T}Q@)3SvoWelG}LgE;&g zpW^A(PpSUQ$El;>hMFl8&D@%xiinw8fRgBzn1f(F?&RxnCvBW~fr*~P6%#cPMlRfE zG2p-j|K)&DBZ?-AoEt1%yoc$21S;w9f5R4eq!h&2A}^%W+vKKyU*f_1PJ0bYA~k2U!2w62Qt`}=Kb6*G;~kPK00{18c!dchBoSbSrgdOOLrp%SWBu$m26YSE>GXZ#Jb zxX*h2i5)GbWJ46zA_Jwp(zjnQ`9H04$&r;4ZI1Q&;oyBksbzWJKND5-uSuVRHfl%P zsa0a~Qk-wJzTEUVR!o>bGt-ZlM3)8Z|@WWz(LG@B(4dOn0YZO=q6ow`Q<^yu`uPLG(?%B%15b-GHL-q$BBjUh7+%tla z1{``}9k8cW)@^e0FGyJsh3(vFW@j!XXJslmo$58rnPa;+_zIRBhTV^%BE#FUn^=Z> z_;Nik)1=zyzC>L43wxw=1|lg2h`n^N1WCionCAeWqZ1+D_tdHcqgK2OD&< zOi6CK0+}$7*-*uv$%seK_Xh5A+O_Gn0gnJH= z%XN}(V+rItO>LISO<$WNiSHM>0(`aU()REjh~UNXA+gp&9=6eb*EIwi${gZBAV zu(qM%uqOo~kwEK7AQIp`>_!;p<6^orNQ>Ci8%`tT5cR6lV3XAIPlL5!F~{2K_Qsz1 zkmWR_P{v;%_Gu{K)t;n^@y6fSa(e3j&PpX3m10&n9Th{XVf0Ug-J$VOD^hIAEZ$=xsgsFbto(8NFty+)RTl3U+Gy zoGnYGo;#6@>i_J%4EljnapF!-j-~NxD!AWQgRb%VEl?f0iQVWJi^*vP%wN9CJBD4H z$k@<-*1(Q*>&a;#2zQJ7hyitrd%-oLTl@@3>IXOlK_|$>{sXG|i(}x54UQ6=!=@;) zI2a{#3npwtL1I&9Oj~dryJC{b9MuR~8Gn6DZ=9C*lQ=HPDyVr+%uC>gJ`PzX zB8K}3c{y*={gT0|GEkZ~nKb>#e!3d;yVdm5;`8NZ+5#gBXG*zk#%akXsJ7GMFYyiI zeD0WJoS&dU|D_7O!5A~G=|r^cJJ5 zRjmS3bhSlBSL-9Mjq`zb;1OW~;b0@&|N9g_*&O6m~8|{@+w?cLEXm0+6qa%n5A}1aG?r^u-716D_ zl-ReZTg}6c!CY*u9Fy0_VE}UGye$X5Hj}aR^Y$1{eWjoGl^dw9{M6r$bGGZ{buWPo z*PXWP$Bc`ex1o-7Q)wGQc2g-2hpb9qzR9J%Mk|TyaaaZi)AJqj!F&E4WV-dLsccT) z{4rye%W1JHJQVx{+q@vxttUW2`}zAK^TmmLPo0wQ101umd>ta+29|GjtTX3;92kNG zs~Z{6aMtZmxda%{4XoS!lple85En@nLw076#gU!;G|}H*EjPa~fGjpbRkKJ{cQc{~ zL4=DPA7mrIHIzUu)PCOUlxm5mdVNm#fRL|h0LQ`Q^T^H92apXeA5QW{3?Lg!KDXRV z=c|x!*AfHW z0J6A2)QoCe3L;eFf{>bVz3@D1Sk3DQADA7(su44Qs)6ARAnR6>{^`0c3+~unLoj z0c3HJicVyTPVx7n&|WrZ8TBoO1$OhLQ>iabt7LEE2u01IE^#Npq zYY%GqHIGA#MqoCN>n@&+|`#+gk z_BOyV(|}I9aEF37aKG8R97EcJ^`!DVn#1vB)p_Jg1LkM{h@2SGFlZcR!x@IQtKlhh z>x)?+2+tlD5Ca<3il!hEUQ{hblFuIPboCNLn~L#~%@QkGRdm+hUxj<982aMz>__?< zjFvjWhatd5gjS3i(vZ-{6pEc1i6KyIBs?yhFpFqB3ln%MF+9Px(})W3PS7oOT6khp zt6gDxp>%x+$B<*(q}U@E$x4opR$OfTc|T+%vJ%vuU>K1P>ellovr(?;7=)+u!4f)X z8*sk(C@}0e3OJjo-|07lm45JsScWk~IkPrKQ0pK$>OYB;*QM6EBZA zlzRbLja*QpxUZgL=_Ok2ZQof?)GL=!X*egu-A$sWtN!6bR`nd+x)f6sni*YyGkQ9y z7l(?LzYsiNok61#-OpSAuBL6D#BXzNv-7~0%Zas7< zW#$vh&sm#w22Wt2h%DIb4_)qCM#{h~N$l-(GAhpFR>q@+c(mlEu_7MlOmyo%0By;& z(Z1&c5m|~@mg{&H9CWp-w@}N|T+_)=IUJQqD9lnhY6e5uI4YY^c?K$nq0Ag*C)7m- zDwCl;naIS$fmgdajOL|w-axp87hyqbNKb@htpl79hPxJv;e6lOvN+M%;!VK;ecA$N zVC7L|?WZM7v($*kO2H^lKn>|^(F)N4(SBFM;sESKs{*J1NZl+@a1<#!r2++Dr)BeK z7owRwsvwFxMMQuyH5-riBbvseePksKKnE#Ok4QU3DiP_XNC_g16oG~d)>EVqk)0G# z5ZOTyJ0jaDl8s0kMR57CU=u}f%2d!w5$+F$L6>AmdoeK>#Qf;q9hM)hIQP`C?x37X z3c^i1362UaTiRGMA0_45knF-L<~A_)#nbt5@0q~{y1ZoA>@3DJe>=Mgji4nrL(%eg z@Fc{~Md!)i$&;v7F81Z~8+j7-&MiR!jk?u6fvmBTtnmb<6rgG6OH;aEo!<^ZEYnFE z>q;^gO-ZC%*-5%pWi_VN)@++@{RJTdmUZtzz?=?kpBkg>Z^^Aj8E|DLltBe-LIGbL z9ntRZj%mrQ2d-PU{$)JJtpUz1Kd^tut%L~b)>naO$yF&I&71)2!@XMuY#`ul4(!*h zTNrQ+0War3^xpsh{f+2g1gBkja7Y~BpG@lQy7jl0-~w5HpiuXpyn2;W>s0>uSA{MX 
zp9h(h7C-I8!9DM9hf(RuAG_Id@&I@P*Gjz?+?(2&b9sSiC0|+3ubIhPVu>nXZGso_iRP`4>c#(QG?cQ1#tQndmFwowX* zIvn?WQ3uG;tf3H;xEm$*jAZ-^*6UV3$F3g^h>A-9=6RJM!>|E+x4Z;e-C9Ddg|qty zn9xfF0rRhb1H1pAfL76ezZel(H@BE&u<*fvv36d<@|wS00c1-)ITN|*De$1RD3Y4x z1(Xk)z<5EXZ+?=JE&etlxfu64h$&AZtBM&6G6k&DjY4af^u7`>ThEObeKpt4crz1K zkb4uD0n*6{Izg}qt(~G<*MWPBe>>34?F}Opyh4P2YpsI&gJ8?4XljtRS;az^k?%&4 zS?#XC&(q0yIkm+pPR*O<)NVyr|Ik>j(|xA-Z1W36k$c;=6tpy{CHF?xn#@~J44+S= z9K>t`mzv<)=$;c@XkxR7qZg52YI%Ovaeisvj`cHOG!tWPd6o-)9=@xHE-;BZPTDI$ zdLI<|E+>Bux&vymak=C*RwwdYk5Zg7+GA6+ z+pu?l)4H_;X+Gw&(%pgW8Kf!N;qaLQ4SMg5&|Z24!QRlVWvs})5VC;1lCqz?APyn^ z!mD_8#4V_ck>U5$pL)kA+M6^~TkEI*+*HS2Cq^Q4zj)B}#yYg`*ugrTo0EvNT?3vy zq#}sk(2UUnLgR!Q6Q#@8pmizZw;vSPXT@xT*BBK!iz@QoIBJqU1MbrP5kWO@K*MYi z2S1i?AsJm_GLOgXkm{IB;qizHu!{Up4v z@^?AKgFbsg&e#XPekT31K8Sn6z&laLbaR7cF!(To)9yyc04`WFpA*aqx^qq^gjuHU# z&||0>o%bEv3y;l=R&HYR2l8Pr5fam5csp_C@^hmWxM3GjM}A+4UE&#dC}(r!n1kR_ z05igMYV8CwWx45^c$ij>IUgvwX);e#g16i>76~T?PK~`p#qQuGIQ@4NVq9mp=9W49 zcT|JoP4&#I1HrtbBsRL!@t%`r$~e~+Okq`_E~b9$((I874Ies?Wz$&HDq!7=K{#}$ z?i&cw^AH>b%1D#_Z=ju_)qQ|syzvcVr&0BkPwVTQjKbmKgWLCnRdy4bUO;8DiHNC; zZI`0ARuL0=leYjgPVL9*IqxGxnA0UUDQ1ADu4@HYZpuNzu@hgQM8^)Klim47A{*iJ z0(KLF`xqQEaZr+12v~B^%(`_f;N%(5lGI?}&^wkW{(&yv^L89XttSYAd+klXLRz=J zgCq@yZAfJub4-1BI|{<aYmcc&I}_|0@}E%_TK`g(_yP%j+f=4&Oo&?^ zWFU-XhdAL)c>j)pP@Tx@%)bKHe0rUX&624NTew^gU>Mh}cd*=iHp1j`h2;8_=APU| z^L8i$E8RMQ#LNc|$M5hZ_1?=Qyp&5Zl_jPl!7rQnMnTlxF(ffW9m9!+>Q;K60FBa% zE7$+zd?IB&2|6b8Mp}Q7%*>brmiN(|y7d)gCSCp-#k;f_5uL-Po)cAxkD+g^zn8pM zIf)Ulz6>6rKGTVkA_l6pNW%W{-CKc1RaHz7`CExJ{{@)FZf3pXh#5O-SFhspOe{%% zI+`ean0Dlc-V2xSDS`+X#e#prloC^}@atK;y>KvqX(Cq10}i7#>&wfSDWY7)6)!*r ztdqDo>sA|JJ!>$c!q^+32K5{!)wxt1CY5hF-X{YQjbjkiIdx`=dMl%uz@Rj=t9QcJ zV0?h&X`lfhKWD#s4pj9#)h=K?!UlRY+~br79zkw)p-ed8o~sc?sru*edOdWfjkq%N z#tZll+n7prqa&n)8{Pv0D{8!0PKa67o53-^C#Gyd!3Oro7w0|)y>8u!@18DLLsXdx zZdy>^*~A14y_NSo6{4%>(<~v9*=E2|77lfrm}MQ4n{Gvfs{ztp#H3YI*~9`y57$>~ zWD)hJd{lo51C~OJ%)0eTmJR#Y05So?oqL^`ckZu1(j&97-lR@Hw;!`1OdVefHnTk- zsyuF9(Sb2rwY-V&XZDT_d1P@WRX`;?RfPK zuh6-sj-9oHOzS8c&Az}*(0uJv^-RF;1Y8}0$Yw-*#i-w>%q1Uz@JM`!gwkS81`^{c zhy9Vm6b5^OU{f#aUCii~046sries>GEHRrBR8czoVCr3kN&@1;CWiGA&`|Al>$T@o z)zx8+7pbuv+TNbCS%$C@)(ln9rCNTc>v>u?5&ScjGMldfKj+jLDQX#GZwDje+Uske zfcNGGKaMF^9n0=zY-4#Xdq|dm^-pYm0IPonnS0iYTK9@>um7sMtfn z!;pl`n$9^O#!h5~7#ZqN+!v!Osp}l;AuNs@oBzrrdozD( zMYVbzW-~V*CKEo|Ls!_?wV2m}|A6BYqgT^EOvsmG-UK0hO%dB_2~ZfhM!7kUC}`OK z5;Q%RBL&fI!S8l4zNn9Sb@YsA0$kQz5F3rR5mSMXcmIT{8r~Y+ilkHf6PRO6A-`}D zW+GqGE5|KFh}tw2@#3&KPIr26jy!@Z<0>Xp0k#w1C7uz$aSRx(p?^Y9(h$2`8|8oK z6(DFjlWTgV;PW0+VKkp#3O*gc(j91SL2uCA+DtN|0vrKQIh!eWcz^OG(u0mIf0`%T zSSOf`a)d=SKL?KiAgMvTwh%xD{2mKsLqA+hx1cYeZi4G?7#Em`zI)GPaSTtfV4d3hvuaKhrBdvaIxS`MAFUHR*952w`cTK&_y^z|7 zJI52`7B5>22)V_}E=Gj8#rH2DoxMP^Vr_^7+~Ncvb?ajs#^x0*=FtpxA7IFxj0EWZ zavizF1oR)kB*5gRlVnI}%cd`wMZJ=x!@r)3t5tN1_FcfxFsV@Z6Ua&mU;aEsFm{f#ya)AtWURt8Fzc?&G~3% z_??Bq?<^JN(5xnJ(*#3_|f%zWMfyb3(p!i)8niHY#k7NS*(y{4^B{-*Gn5qVdz(^dHDN4 zVd(oDnlX^}6%NH96v^SY9E!7xNa({H$~L#cd|NowPPCYh@i(1TDVo+sh)5vLJhe^H zzybi7-;AV$y=k-(9gzJ3>}oF}NE$m!#oCN+o`&9M`o4FZLX$5<_|X27BMh zb`IIiAn!XRa0F8}=5M#!41@dpY?h;tWFthOfp~)>n92zede`kmj$m3RM5d8rD@QQJ z6C&F{+|3b8_k;+&(smO^lo*AV7&#Vkgxx^c4a7AJ5o(^Zl$IA*Q;F_l<7lIy2+iUG zZQ^L7$q3Ej0{wxNbP;MiLbHYx=sgV0n~|!44#eNgdx9aNU)y5)eF2MEAs?{+Eb?-z z6U4R}+&aA%q+6e2@P;4&g&5Lw>ptSuVlKWG0sQG{ZtW-tRdbO%Y;QAQ zZDaZ2^q)GDfRWp13z+oN&S%o=>N0TO!nntsmiu>e!`zn}+^;=_`&T4y%d{?{ZpM$K zF){hZB)TOD$|U4+qqnk37~sfWSrMK;=U6)I7{F7Mj*gUyuR3(yXWFB597yST5EF3p zU+n+l_wuItKQXBp^mR6FsH#aj;T*X>g12LZ+iMQaMIb-h&LqG)nH2WOAS%Gn_W=ET zTSgD<7lr*VpK7q5X2ZLjF6|4r_1 zz2EFWf%E?g8`%E@iOFX#CAgbxlS 
zrk@t@pIfv~*6)5Lcq9QwaZS4KnqXQY7{W6A*%DsZ!v^6fpipQUwuN?K%7U2zE&|57 zG|QjSWDflgx|#Y9(@}3ssC`6h(Dh>aF|OyqL<+u&@772^nQ6cKgl2wuCJB6pMeI5w z3+@~1oh{}kfe0=}HJN5?{%Kc|(7r2zb*Nb|iRy2HCj4W^04|{#Ls3B15&PYTgJTgp zessh!|A(gFO`ygx&=xMlOki4DUsbfdxEEn_wb;{I3gIi1MEu$eHIUt@^=T*3M9oKh z@>FXEWVi!Pq33fH5ty#yO6972tnUiRU@>J#GTzPWO~t;{87S!}t-5F~_cU@qYxKh^ zuLavM?GB@U9Eq~P_&Cb1n|zmp>11TfIzrj#cSdOc>w%6S?X#ocw@?Q8onZzI?&ko6 ziHq5<$NKD{5+_EKNcRBD5AmG7A-G0Oj;1%&pvC(XA<28=vyKs#=2uH_UPM#rWyQ;A z0ur`oI%KfnnmT?FOZ$SZahrqm<574Sh<;ED210%;GWaq9SaT4__cD>{(g5kV95v656kMR+7nSi183G6=1p^M;iE{C%WJ`N=P(nL1g z7kiHp7oRY?wI?F|OAw&?@n}h;R^@X_JL0zjv)_CzAXMyr^K8VZr|mZ{5V4);4wAC6 zom^)D<=h`={*$g0dt}TTI5E;j(Kti3G&%ZYtinPzajudFeECKtk z=&p+POh1@r>GakwPX0>RRM|=x)OD)jFHg^g1ay^wUp&HjCtYB|PA!1-l*q`!TX#!n zocJsHx4z%6$E^7^HC{1}|3ur_7AUlYqo3zA@5}-77bb6cZ^BP5K`dvFPgVTxG?Y@1 zhI3!=P_+Csr6nIns-Gf-t2mP=PTqsAdOM}*if{qEYPU_vLOd%!Mak+=utuQ| zI9-55R=6ld&g+2oFC!6PDT?w{7n0{w5-pFPG=m=|qe@Zy(>@*s#n1{}PZC}RG&uEw zS}ho#arkpNB!F!OEZ@Qp^?V1fic4-L;P85f2*EFi_dB5f#PlHE=%7f9;6u2lkO_(& z?N+q7v6Jc3xCvmZg4Tu zaGUpYaNOkf-h^NPi5(_ZAmxGcp9V97^M8QzGkDzrx~6s?WjZtH;faJ`6(YT#M&u7J z1SB{QUjiuF7shV<7O3c;C&-zhkM;y#HB%LRXZyc4`OaXjL(vZ5{7bi<&3ax@`XU@k z6pEf3`awmsM{&W>2J3lZ2=wJfwD*&{#bq+om-A|zmW`|u zuBNT2Eru97Bw)$cw6z)E86Q%+l&aXgRAwt}{PE{mkE~|3kjR z{KX8!Tg)dBP?AC(+!6r;B8k922 zdPFez=2H-RaP3t3{c2zF&Y^H16;mB)z9QEpyXEJu#b)#FG$r;A@^g-OC6L_c3MAu) z)>6}*Ie$>ov{}inoDNSC?i$8+xN_P(aE7}xdf&rYI|eASlvMvM>C$Z!OM+d-BaC5Y znzo1Q=k+D~bMqP;CV^<4?9DdMzO4qfy-7)e1K zY3h(WP|)Yt*+rxGLVziW^(=@L?V%vZl)wW=fyCIreyBty3`D=m@PdJEhGj(b>#p1r zqO%8rgdkr?JSk$Mf#%J4V%FSANO;RarG@>_013i?j-7k~LSseqdXlr923?T@gLw&t z_(DwK?Feam5Tnp8@Wf1juGfMeW^GC!zfZm&8o@0ND)c5Q6s6hmJ5`9n{F0&i#DEeT ze2IMqCrWGORJrlDnsy|#{H6}ZFkqFnZ-t+YE zo=W|bw}X-QQl2*>DsSfZYif&OGfEtuO0HR1=5+Oa>RW=t!IkCD{T)9NW0Zr*X?8%(cktI`$x^fdjUw(`~v zZ)%=q4&-98>F+eKu<`8ua>&4na31j*BA0uyJ5A@BNdh&L&J zXb1AnK$iXH&MADD&ug)~jHtZp<5b!IBVyEI*#W={XfdhfyvI3igm*P4MtZ$qZ|YJy zKCmf6yXaRwY&7ApjoEYM9^XREd@HE^or(L+HzE>C49t~?IganKsoDNc8|E2EqOpOP zi`Y(3b>&@fG%x3<4_gl4l-_++x6k6YWCCk%IqyOQ)W{h$FIF+1^6yP4=D!M!0X?gI zs(B{RYHeeRYRv4|wm< z$Btu3gLOn0UQg-W5#UfX0za1eiG2Z#3cC8_ZoG>HdtKU74sY>l1lSgw2S_VJ?g=9} z@iX+!G$Y20@)JcfX)smvacThE!BW^Sx8>Js3oth-C<%Vq&pBnHPazk}k#%5gHWbf) zTpMt-!T*&X^*%0k<-I(-Q-lwS@M#ftiSQ#4_KPr8u^p6yXLDHi)o7gmXlgEy4>$I821c@8SGE5Mh@HUlQRHB79JUH;DQu zBAg<^3q^Rg2;)R}T-4*QB0M6(og(ZK`r9VbABc8q5an^{B&p~|sgXtbg>%k*BL_1sj1ctE=^@r39=g^{jMzyi&QlqSRNdO00ah zN2;nVcN2c4dv&qb4KAf>od+dXsqUIcfVa-07K6yk#gx1&tLh_?EPsu=rcR(mevt{l zn$m{i>blw$#onr0+?1}W9grR~cEK3Rn>lH6)|AUGzao2T>GCoNOECISS+*LVm;Ij` zC~Q3b(TzIZ#uPB#w!E&c+E(YW&8e$(kGIi(!x=kqQf4NErFw{r9=GcA)Y>LbP<{2) zZs4d+#5?1bxpObPh+>ko0P`gegez6u&Jt{?@Yji z?1|tvZhQos8fDyVnjFN^I(CbVTstJmg@= zmQWrj-7Ba`*oXT|;qO8M8sueY7Se_!QC|}-R;r@%;Gb5wRq8_I<0P~Q<5yhf^LWsc zi@Ad?u6Ez-rWPbSDkkM}bxF0pYSyr-2S$%iMQOE{BIE;0$s2i@i<9PZhvg})UE$_l z)5bc67d``e0yXOgWpyyGDmG#eH({r4Pliys@S)%9R*e1?GVcNX3;bXYS}raDqO8rB5pW8 zDh!{n9wUkvkKv}AzPK|U&oY&%(~4lstoRIYVMG@tC6SXu&o;cW&1-`%GW<6$ZK5ro z56?tK9jAdrLLwc7%B?KFprUsGIMn@_dhXdK`kG$k*~U%x(&wp z2!iOr8f1{dQ(6U6vRy}`5o@|6D5Vq~y{2BZsdan=a&t$*SsHb4RtP=NL}Lc^XA-H_ zR?U2Bbk@Olv%@KE^QBu|Kc zRHU(Sa5*c5d<`OO6ybUiwu+G09Txvv(SkVr=F&%vc*Fpe23SmEOyefN#v{HLp$*}v zGD-Rwo*j5fF>tKHGqX~XuEgL}jYq>Wq!rIQv>Zvn{Cut?l@wz`CK>Wyb5(}Ct`agx~wAm=3jU<}i zPRz_Kt5|`5Lx^-DrE}-u$RpQ5G;IT6*u}?x5E%e-GP)fDVT^9)09YZTYaa+>be#iW zMU1WsFjQ5@QpD)G0c(R^jQ`z;HkHbWhW#(c^qcFd)f!5F1o#rPkAUBD%Pr;0Q9a4- z7xnX~W$8`_D@PdsyMdKc7zr!lbc+VSI9&-~-Ek6*57St3_(ZD3xI%e_EOt+*aCJDCocCwGx+R4n&c7x zi&APA^3eaBOwXJ|Bqe}tLdgCPWjek2{;z#ZxKfZbB}q#1#MiTUd4iNQIlg|hIYu&1 ziI>{YSiA95LnospDE|}~WRIXRVP89ba!jUac# 
ztPGPSUb3KU%j6+_3H9-Xqm$#Lyf3Z$?5@=IGwaVNNhvfX#z=`z$JLJ>87GY_mr=iA(vC5hTV8-; z{3fa&$^l$D1y1#jrbV+v<|Tcp@scfFlCH+H!_e0$F!+xM8j~C^$+s&~D~lPGyq=yOW4WSKXl zo8eOvrPT7FQfd+9M|@sNJ!_+TV?-NUq|9q1=@6cxW}}UxVenrOG%WwTm`r9PYNFI_ zm!vDu>+U!5p8`XFil8x#7$S{OGo=ybmr5h1Oq52L#!D%MIG;$0m6Eo^cgJNy4$^0l zEH%!=Z>cDf^eOSS#7P#^FR?s@!JL5ihssQff}zhu=n4EAUp}t&`8_wqXp20WOHKV_^v&Gmw;aA^vhiTJ*cZGgyD-b)&o}P0~Skeks}* zY5IIwf^h}uzl!uVh|h|^Q+jp;e#mh2-6^s(+;et!YW-Oy#McBqYeBPOxg?!LnnGb0 zg(kcOFoe>HaZ;iuu0FN|{(#yJebO>HzHkV&#TfJh!RKK(@rx>9V>LW9#h9cRvWHZ% zT|Dp~snAF2WY{iFwyjfkNmkQNDV}Ip|3RBhF-!89=s#3Je2kFdw1G&WXFZ1#j-E=S ztMSz1p|;qFFk+=Lv8+C8>V#F;a4W4TzkJHX`qkx9;c6#VEydy0OxoR)hks_!V15(# z`vz>+eTdb=H)AX?K1uh(HUVpofV~RXcEGwKV4niE6R^QP|2`MtTG=l4!~cCQ!u9d* za}lnG|J8Gm!?A^{7x3-OVkdU2Jhd~h$2nhYY-Ym|-{7UKNNUjTN zz1YQcmy_`HNI!uTdDw?4&c!BdX;rOP+80qC2{PA%vm5NFo*Q3`3eqkgH0q->94Bdy z&KHX3xK~Nr4Nj1Uv}EZ0@erI&LtcuMCL~3hLBrK6y&fS0@$NHI*}0YOvXw>)&U1U| zZ09I)iq;YZE(mLZ^l#8hZ73P+JrEXq&1#qmjZSb$x(vL+76Nc6?^NW?ZzwAc?{HvL z7bVZA^VHD6)DuP>m}%)uVz#3p)HNSGX3!ayu)R)06O4Wl=@K?s?DWp}EmxU!{xL!e zv+Am_zC6yj=DBJAqfGjZkN}do>&dDMO4YJTI=<)!Z*FGAj`DJ0#?r5i5*=RF_Uj_} zJG|Vgc1OUd^^MwXr@Uls)Ev^+Vg9*nUm9zd8;lyvLwCX+;Y|ZH>`+yYVG&OYxpS-C zr5qcQu)? zB+aLm5}mOC`=CYj%o#{SLFw?UC@!QkP1boFwdJ$jt5?-|aJV)J<&&L}C9sV&>GnuD zX|GV)5QadZm<5vFSuZT4?k>GSB^z!~w8b=(O69Xt3f7Gz>09J0&L=lH&|h81`Z*OP z3|E>LE|FE@`hn>=b*PY6Y7Rr^VXIV1I;EzP^yi57gPD7|Yo3c*^Pa&RpYYf9Y;QGd z>cn!|_5WwwJheKvv>wNrZgN%BhrcD53#Z4vdZVYJJ|5U!oEezwsywQ%w0c1;ri1XV ze>E?t#o0-DwYxmjVP;fSv!1`r4CDD;^yC;cd@etuDD>2o;n+pmYc4X*P7RQ>J=W^W(%rE%0j|Pc?k(>5)QJ`JP^k^l z=DC;qR^VXAkaBkZygB)vBoV;=d%3XVNm8Y(jPKU8* z8$f~Ks{xAvm7h?Jk&4|77_Bk%4?vua@`w)mg!NSuYP__fm{8_XC&1HJPN<$VVbTOq zpD4TsZHnfIB)N*TQC1&1H5qJW@c$O<&{Erxir01BxV3MrXESRG1|r*L{RY zQThHbR^IHBB0_`Cz>kcg!`2h~7<2|cBZ@8~iq61iN6}?Q(HZ!{C^}5W*vBZ>z?VeP z*`w$Te0>z15=Cd=1^5Spwp+RWuxHgsG8x)#UBlBm1s;203~$Jd14Wkpp#mDCJMHk)P27X5rU37aJ`1U9|dsICP zd{-1*Nfe!d?~bC2whIH_7e%)zil2d(e7t>)eziV|&cLTc(X~d=8F&Hy!N6$8A0q4& zcG)28)VM^}xcOhR%P&`Px^F~il-DTq6K!wNcCbUxMYmIB1Lte#yFsMai?B_EJ49&c zJ4Nv86!4=W3>tJIG~~>Omh4gI8Ozf!o_&*L7c)$JZ*^ABm6BMI*V&ZctXTe&XlChR_v_i8{tzp)0l#I2yi@IcO+cjb{hq6t2ay7jg7i zY4uD=dL8k|vpYKfGtT-bJdP*%E*{bgvFV6YI1f)Y;w1>Jv-p{t4Po(YobM1E;W_gV zB5p%?#X{H*;&z1JTn~LA4*w}#jECfh50id`3vr}R_%UfV&WOlPBG2vUZ11LJphsQ` z594_iaSG4EnH|L`yjaBT2$$gr0#5fqcHp72K??Vkz}^t=MCd9-`RjOi+j8g@@kWHd z!$b14Bm5E%$w}e9GD-RkXgU#2Er*>TeigzXo=(Im9O~v}Wh0y~;zbApBHoH{uZUCl zwut`;VO#~zOX0a9J`&*?Jl}vng-;4Nh3z6v;gA(v&SZo?7V%98Gb_>7>v&j%IMs#1 z+eEw-;jt>L;nw3G0K&8zx&A4{`wFZr79otoo{kMPcy~d%7f&YQcz=Q2mB~h&?%B}U zX9mimkj_0RP9dFLQk+6Me|(bUL8##Q67he3{!h|?3A;D219XLfQsKN9fXqSQ_>lui z0j{JSeR*Iemb&Vih8o;W#48@y^PVU{LFwHV*Gd;HG+|yQb7Tp{A+tIeO?Pyzfo3uG`bH?V(&4rsA zH?QBkX>;4=?VERQ#+-9Y;+BjpnOh3CEZWkrrEyE!mhD?Qw{&gk+tR-!ZL4joeXGLC zZ{50S>&~t1Tf4VrZ?kVJ+}5zIaa-%QP21YGZQr(I+s= (3, 4): + from importlib import reload as reload_module + else: + from imp import reload as reload_module + def b(s): + return bytes(s, 'latin1') + from io import StringIO, BytesIO + text_type = str + binary_type = bytes + string_types = (str,) + integer_types = (int,) + unichr = chr + +long_type = integer_types[-1] diff --git a/lib/simplejson/decoder.py b/lib/simplejson/decoder.py new file mode 100644 index 0000000..7ed5ea8 --- /dev/null +++ b/lib/simplejson/decoder.py @@ -0,0 +1,416 @@ +"""Implementation of JSONDecoder +""" +from __future__ import absolute_import +import re +import sys +import struct +from .compat import PY3, unichr +from .scanner import make_scanner, JSONDecodeError + +def _import_c_scanstring(): + try: + from ._speedups import scanstring + return 
scanstring + except ImportError: + return None +c_scanstring = _import_c_scanstring() + +# NOTE (3.1.0): JSONDecodeError may still be imported from this module for +# compatibility, but it was never in the __all__ +__all__ = ['JSONDecoder'] + +FLAGS = re.VERBOSE | re.MULTILINE | re.DOTALL + +def _floatconstants(): + if sys.version_info < (2, 6): + _BYTES = '7FF80000000000007FF0000000000000'.decode('hex') + nan, inf = struct.unpack('>dd', _BYTES) + else: + nan = float('nan') + inf = float('inf') + return nan, inf, -inf + +NaN, PosInf, NegInf = _floatconstants() + +_CONSTANTS = { + '-Infinity': NegInf, + 'Infinity': PosInf, + 'NaN': NaN, +} + +STRINGCHUNK = re.compile(r'(.*?)(["\\\x00-\x1f])', FLAGS) +BACKSLASH = { + '"': u'"', '\\': u'\\', '/': u'/', + 'b': u'\b', 'f': u'\f', 'n': u'\n', 'r': u'\r', 't': u'\t', +} + +DEFAULT_ENCODING = "utf-8" + +if hasattr(sys, 'get_int_max_str_digits'): + bounded_int = int +else: + def bounded_int(s, INT_MAX_STR_DIGITS=4300): + """Backport of the integer string length conversion limitation + + https://docs.python.org/3/library/stdtypes.html#int-max-str-digits + """ + if len(s) > INT_MAX_STR_DIGITS: + raise ValueError("Exceeds the limit (%s) for integer string conversion: value has %s digits" % (INT_MAX_STR_DIGITS, len(s))) + return int(s) + + +def scan_four_digit_hex(s, end, _m=re.compile(r'^[0-9a-fA-F]{4}$').match): + """Scan a four digit hex number from s[end:end + 4] + """ + msg = "Invalid \\uXXXX escape sequence" + esc = s[end:end + 4] + if not _m(esc): + raise JSONDecodeError(msg, s, end - 2) + try: + return int(esc, 16), end + 4 + except ValueError: + raise JSONDecodeError(msg, s, end - 2) + +def py_scanstring(s, end, encoding=None, strict=True, + _b=BACKSLASH, _m=STRINGCHUNK.match, _join=u''.join, + _PY3=PY3, _maxunicode=sys.maxunicode, + _scan_four_digit_hex=scan_four_digit_hex): + """Scan the string s for a JSON string. End is the index of the + character in s after the quote that started the JSON string. + Unescapes all valid JSON string escape sequences and raises ValueError + on attempt to decode an invalid string. If strict is False then literal + control characters are allowed in the string. 
+ + Returns a tuple of the decoded string and the index of the character in s + after the end quote.""" + if encoding is None: + encoding = DEFAULT_ENCODING + chunks = [] + _append = chunks.append + begin = end - 1 + while 1: + chunk = _m(s, end) + if chunk is None: + raise JSONDecodeError( + "Unterminated string starting at", s, begin) + prev_end = end + end = chunk.end() + content, terminator = chunk.groups() + # Content is contains zero or more unescaped string characters + if content: + if not _PY3 and not isinstance(content, unicode): + content = unicode(content, encoding) + _append(content) + # Terminator is the end of string, a literal control character, + # or a backslash denoting that an escape sequence follows + if terminator == '"': + break + elif terminator != '\\': + if strict: + msg = "Invalid control character %r at" + raise JSONDecodeError(msg, s, prev_end) + else: + _append(terminator) + continue + try: + esc = s[end] + except IndexError: + raise JSONDecodeError( + "Unterminated string starting at", s, begin) + # If not a unicode escape sequence, must be in the lookup table + if esc != 'u': + try: + char = _b[esc] + except KeyError: + msg = "Invalid \\X escape sequence %r" + raise JSONDecodeError(msg, s, end) + end += 1 + else: + # Unicode escape sequence + uni, end = _scan_four_digit_hex(s, end + 1) + # Check for surrogate pair on UCS-4 systems + # Note that this will join high/low surrogate pairs + # but will also pass unpaired surrogates through + if (_maxunicode > 65535 and + uni & 0xfc00 == 0xd800 and + s[end:end + 2] == '\\u'): + uni2, end2 = _scan_four_digit_hex(s, end + 2) + if uni2 & 0xfc00 == 0xdc00: + uni = 0x10000 + (((uni - 0xd800) << 10) | + (uni2 - 0xdc00)) + end = end2 + char = unichr(uni) + # Append the unescaped character + _append(char) + return _join(chunks), end + + +# Use speedup if available +scanstring = c_scanstring or py_scanstring + +WHITESPACE = re.compile(r'[ \t\n\r]*', FLAGS) +WHITESPACE_STR = ' \t\n\r' + +def JSONObject(state, encoding, strict, scan_once, object_hook, + object_pairs_hook, memo=None, + _w=WHITESPACE.match, _ws=WHITESPACE_STR): + (s, end) = state + # Backwards compatibility + if memo is None: + memo = {} + memo_get = memo.setdefault + pairs = [] + # Use a slice to prevent IndexError from being raised, the following + # check will raise a more specific ValueError if the string is empty + nextchar = s[end:end + 1] + # Normally we expect nextchar == '"' + if nextchar != '"': + if nextchar in _ws: + end = _w(s, end).end() + nextchar = s[end:end + 1] + # Trivial empty object + if nextchar == '}': + if object_pairs_hook is not None: + result = object_pairs_hook(pairs) + return result, end + 1 + pairs = {} + if object_hook is not None: + pairs = object_hook(pairs) + return pairs, end + 1 + elif nextchar != '"': + raise JSONDecodeError( + "Expecting property name enclosed in double quotes or '}'", + s, end) + end += 1 + while True: + key, end = scanstring(s, end, encoding, strict) + key = memo_get(key, key) + + # To skip some function call overhead we optimize the fast paths where + # the JSON key separator is ": " or just ":". 
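For reference, the pure-Python scanner defined above can be exercised directly; a minimal sketch, assuming the vendored package is importable as `simplejson`:

```python
# Minimal sketch of the string scanner defined above; assumes the vendored
# package is importable as `simplejson`.
from simplejson.decoder import py_scanstring

# `end` points just past the opening quote; the scanner returns the decoded
# text together with the index one past the closing quote.
text, end = py_scanstring('"caf\\u00e9"', 1)
assert (text, end) == (u'caf\xe9', 11)
```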
+ if s[end:end + 1] != ':': + end = _w(s, end).end() + if s[end:end + 1] != ':': + raise JSONDecodeError("Expecting ':' delimiter", s, end) + + end += 1 + + try: + if s[end] in _ws: + end += 1 + if s[end] in _ws: + end = _w(s, end + 1).end() + except IndexError: + pass + + value, end = scan_once(s, end) + pairs.append((key, value)) + + try: + nextchar = s[end] + if nextchar in _ws: + end = _w(s, end + 1).end() + nextchar = s[end] + except IndexError: + nextchar = '' + end += 1 + + if nextchar == '}': + break + elif nextchar != ',': + raise JSONDecodeError("Expecting ',' delimiter or '}'", s, end - 1) + + try: + nextchar = s[end] + if nextchar in _ws: + end += 1 + nextchar = s[end] + if nextchar in _ws: + end = _w(s, end + 1).end() + nextchar = s[end] + except IndexError: + nextchar = '' + + end += 1 + if nextchar != '"': + raise JSONDecodeError( + "Expecting property name enclosed in double quotes", + s, end - 1) + + if object_pairs_hook is not None: + result = object_pairs_hook(pairs) + return result, end + pairs = dict(pairs) + if object_hook is not None: + pairs = object_hook(pairs) + return pairs, end + +def JSONArray(state, scan_once, _w=WHITESPACE.match, _ws=WHITESPACE_STR): + (s, end) = state + values = [] + nextchar = s[end:end + 1] + if nextchar in _ws: + end = _w(s, end + 1).end() + nextchar = s[end:end + 1] + # Look-ahead for trivial empty array + if nextchar == ']': + return values, end + 1 + elif nextchar == '': + raise JSONDecodeError("Expecting value or ']'", s, end) + _append = values.append + while True: + value, end = scan_once(s, end) + _append(value) + nextchar = s[end:end + 1] + if nextchar in _ws: + end = _w(s, end + 1).end() + nextchar = s[end:end + 1] + end += 1 + if nextchar == ']': + break + elif nextchar != ',': + raise JSONDecodeError("Expecting ',' delimiter or ']'", s, end - 1) + + try: + if s[end] in _ws: + end += 1 + if s[end] in _ws: + end = _w(s, end + 1).end() + except IndexError: + pass + + return values, end + +class JSONDecoder(object): + """Simple JSON decoder + + Performs the following translations in decoding by default: + + +---------------+-------------------+ + | JSON | Python | + +===============+===================+ + | object | dict | + +---------------+-------------------+ + | array | list | + +---------------+-------------------+ + | string | str, unicode | + +---------------+-------------------+ + | number (int) | int, long | + +---------------+-------------------+ + | number (real) | float | + +---------------+-------------------+ + | true | True | + +---------------+-------------------+ + | false | False | + +---------------+-------------------+ + | null | None | + +---------------+-------------------+ + + When allow_nan=True, it also understands + ``NaN``, ``Infinity``, and ``-Infinity`` as + their corresponding ``float`` values, which is outside the JSON spec. + + """ + + def __init__(self, encoding=None, object_hook=None, parse_float=None, + parse_int=None, parse_constant=None, strict=True, + object_pairs_hook=None, allow_nan=False): + """ + *encoding* determines the encoding used to interpret any + :class:`str` objects decoded by this instance (``'utf-8'`` by + default). It has no effect when decoding :class:`unicode` objects. + + Note that currently only encodings that are a superset of ASCII work, + strings of other encodings should be passed in as :class:`unicode`. + + *object_hook*, if specified, will be called with the result of every + JSON object decoded and its return value will be used in place of the + given :class:`dict`. 
This can be used to provide custom + deserializations (e.g. to support JSON-RPC class hinting). + + *object_pairs_hook* is an optional function that will be called with + the result of any object literal decode with an ordered list of pairs. + The return value of *object_pairs_hook* will be used instead of the + :class:`dict`. This feature can be used to implement custom decoders + that rely on the order that the key and value pairs are decoded (for + example, :func:`collections.OrderedDict` will remember the order of + insertion). If *object_hook* is also defined, the *object_pairs_hook* + takes priority. + + *parse_float*, if specified, will be called with the string of every + JSON float to be decoded. By default, this is equivalent to + ``float(num_str)``. This can be used to use another datatype or parser + for JSON floats (e.g. :class:`decimal.Decimal`). + + *parse_int*, if specified, will be called with the string of every + JSON int to be decoded. By default, this is equivalent to + ``int(num_str)``. This can be used to use another datatype or parser + for JSON integers (e.g. :class:`float`). + + *allow_nan*, if True (default false), will allow the parser to + accept the non-standard floats ``NaN``, ``Infinity``, and ``-Infinity``. + + *parse_constant*, if specified, will be + called with one of the following strings: ``'-Infinity'``, + ``'Infinity'``, ``'NaN'``. It is not recommended to use this feature, + as it is rare to parse non-compliant JSON containing these values. + + *strict* controls the parser's behavior when it encounters an + invalid control character in a string. The default setting of + ``True`` means that unescaped control characters are parse errors, if + ``False`` then control characters will be allowed in strings. + + """ + if encoding is None: + encoding = DEFAULT_ENCODING + self.encoding = encoding + self.object_hook = object_hook + self.object_pairs_hook = object_pairs_hook + self.parse_float = parse_float or float + self.parse_int = parse_int or bounded_int + self.parse_constant = parse_constant or (allow_nan and _CONSTANTS.__getitem__ or None) + self.strict = strict + self.parse_object = JSONObject + self.parse_array = JSONArray + self.parse_string = scanstring + self.memo = {} + self.scan_once = make_scanner(self) + + def decode(self, s, _w=WHITESPACE.match, _PY3=PY3): + """Return the Python representation of ``s`` (a ``str`` or ``unicode`` + instance containing a JSON document) + + """ + if _PY3 and isinstance(s, bytes): + s = str(s, self.encoding) + obj, end = self.raw_decode(s) + end = _w(s, end).end() + if end != len(s): + raise JSONDecodeError("Extra data", s, end, len(s)) + return obj + + def raw_decode(self, s, idx=0, _w=WHITESPACE.match, _PY3=PY3): + """Decode a JSON document from ``s`` (a ``str`` or ``unicode`` + beginning with a JSON document) and return a 2-tuple of the Python + representation and the index in ``s`` where the document ended. + Optionally, ``idx`` can be used to specify an offset in ``s`` where + the JSON document begins. + + This can be used to decode a JSON document from a string that may + have extraneous data at the end. + + """ + if idx < 0: + # Ensure that raw_decode bails on negative indexes, the regex + # would otherwise mask this behavior. 
#98 + raise JSONDecodeError('Expecting value', s, idx) + if _PY3 and not isinstance(s, str): + raise TypeError("Input string must be text, not bytes") + # strip UTF-8 bom + if len(s) > idx: + ord0 = ord(s[idx]) + if ord0 == 0xfeff: + idx += 1 + elif ord0 == 0xef and s[idx:idx + 3] == '\xef\xbb\xbf': + idx += 3 + return self.scan_once(s, idx=_w(s, idx).end()) diff --git a/lib/simplejson/encoder.py b/lib/simplejson/encoder.py new file mode 100644 index 0000000..ed3f281 --- /dev/null +++ b/lib/simplejson/encoder.py @@ -0,0 +1,740 @@ +"""Implementation of JSONEncoder +""" +from __future__ import absolute_import +import re +from operator import itemgetter +# Do not import Decimal directly to avoid reload issues +import decimal +from .compat import binary_type, text_type, string_types, integer_types, PY3 +def _import_speedups(): + try: + from . import _speedups + return _speedups.encode_basestring_ascii, _speedups.make_encoder + except ImportError: + return None, None +c_encode_basestring_ascii, c_make_encoder = _import_speedups() + +from .decoder import PosInf +from .raw_json import RawJSON + +ESCAPE = re.compile(r'[\x00-\x1f\\"]') +ESCAPE_ASCII = re.compile(r'([\\"]|[^\ -~])') +HAS_UTF8 = re.compile(r'[\x80-\xff]') +ESCAPE_DCT = { + '\\': '\\\\', + '"': '\\"', + '\b': '\\b', + '\f': '\\f', + '\n': '\\n', + '\r': '\\r', + '\t': '\\t', +} +for i in range(0x20): + #ESCAPE_DCT.setdefault(chr(i), '\\u{0:04x}'.format(i)) + ESCAPE_DCT.setdefault(chr(i), '\\u%04x' % (i,)) +del i + +FLOAT_REPR = repr + +def encode_basestring(s, _PY3=PY3, _q=u'"'): + """Return a JSON representation of a Python string + + """ + if _PY3: + if isinstance(s, bytes): + s = str(s, 'utf-8') + elif type(s) is not str: + # convert an str subclass instance to exact str + # raise a TypeError otherwise + s = str.__str__(s) + else: + if isinstance(s, str) and HAS_UTF8.search(s) is not None: + s = unicode(s, 'utf-8') + elif type(s) not in (str, unicode): + # convert an str subclass instance to exact str + # convert a unicode subclass instance to exact unicode + # raise a TypeError otherwise + if isinstance(s, str): + s = str.__str__(s) + else: + s = unicode.__getnewargs__(s)[0] + def replace(match): + return ESCAPE_DCT[match.group(0)] + return _q + ESCAPE.sub(replace, s) + _q + + +def py_encode_basestring_ascii(s, _PY3=PY3): + """Return an ASCII-only JSON representation of a Python string + + """ + if _PY3: + if isinstance(s, bytes): + s = str(s, 'utf-8') + elif type(s) is not str: + # convert an str subclass instance to exact str + # raise a TypeError otherwise + s = str.__str__(s) + else: + if isinstance(s, str) and HAS_UTF8.search(s) is not None: + s = unicode(s, 'utf-8') + elif type(s) not in (str, unicode): + # convert an str subclass instance to exact str + # convert a unicode subclass instance to exact unicode + # raise a TypeError otherwise + if isinstance(s, str): + s = str.__str__(s) + else: + s = unicode.__getnewargs__(s)[0] + def replace(match): + s = match.group(0) + try: + return ESCAPE_DCT[s] + except KeyError: + n = ord(s) + if n < 0x10000: + #return '\\u{0:04x}'.format(n) + return '\\u%04x' % (n,) + else: + # surrogate pair + n -= 0x10000 + s1 = 0xd800 | ((n >> 10) & 0x3ff) + s2 = 0xdc00 | (n & 0x3ff) + #return '\\u{0:04x}\\u{1:04x}'.format(s1, s2) + return '\\u%04x\\u%04x' % (s1, s2) + return '"' + str(ESCAPE_ASCII.sub(replace, s)) + '"' + + +encode_basestring_ascii = ( + c_encode_basestring_ascii or py_encode_basestring_ascii) + +class JSONEncoder(object): + """Extensible JSON encoder for Python data structures. 
+ + Supports the following objects and types by default: + + +-------------------+---------------+ + | Python | JSON | + +===================+===============+ + | dict, namedtuple | object | + +-------------------+---------------+ + | list, tuple | array | + +-------------------+---------------+ + | str, unicode | string | + +-------------------+---------------+ + | int, long, float | number | + +-------------------+---------------+ + | True | true | + +-------------------+---------------+ + | False | false | + +-------------------+---------------+ + | None | null | + +-------------------+---------------+ + + To extend this to recognize other objects, subclass and implement a + ``.default()`` method with another method that returns a serializable + object for ``o`` if possible, otherwise it should call the superclass + implementation (to raise ``TypeError``). + + """ + item_separator = ', ' + key_separator = ': ' + + def __init__(self, skipkeys=False, ensure_ascii=True, + check_circular=True, allow_nan=False, sort_keys=False, + indent=None, separators=None, encoding='utf-8', default=None, + use_decimal=True, namedtuple_as_object=True, + tuple_as_array=True, bigint_as_string=False, + item_sort_key=None, for_json=False, ignore_nan=False, + int_as_string_bitcount=None, iterable_as_array=False): + """Constructor for JSONEncoder, with sensible defaults. + + If skipkeys is false, then it is a TypeError to attempt + encoding of keys that are not str, int, long, float or None. If + skipkeys is True, such items are simply skipped. + + If ensure_ascii is true, the output is guaranteed to be str + objects with all incoming unicode characters escaped. If + ensure_ascii is false, the output will be unicode object. + + If check_circular is true, then lists, dicts, and custom encoded + objects will be checked for circular references during encoding to + prevent an infinite recursion (which would cause an OverflowError). + Otherwise, no such check takes place. + + If allow_nan is true (default: False), then out of range float + values (nan, inf, -inf) will be serialized to + their JavaScript equivalents (NaN, Infinity, -Infinity) + instead of raising a ValueError. See + ignore_nan for ECMA-262 compliant behavior. + + If sort_keys is true, then the output of dictionaries will be + sorted by key; this is useful for regression tests to ensure + that JSON serializations can be compared on a day-to-day basis. + + If indent is a string, then JSON array elements and object members + will be pretty-printed with a newline followed by that string repeated + for each level of nesting. ``None`` (the default) selects the most compact + representation without any newlines. For backwards compatibility with + versions of simplejson earlier than 2.1.0, an integer is also accepted + and is converted to a string with that many spaces. + + If specified, separators should be an (item_separator, key_separator) + tuple. The default is (', ', ': ') if *indent* is ``None`` and + (',', ': ') otherwise. To get the most compact JSON representation, + you should specify (',', ':') to eliminate whitespace. + + If specified, default is a function that gets called for objects + that can't otherwise be serialized. It should return a JSON encodable + version of the object or raise a ``TypeError``. + + If encoding is not None, then all input strings will be + transformed into unicode using that encoding prior to JSON-encoding. + The default is UTF-8. 
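The options above combine in the obvious way; a minimal sketch, assuming the vendored package is importable as `simplejson`:

```python
# Minimal sketch of the constructor options described above; assumes the
# vendored package is importable as `simplejson`.
import simplejson

enc = simplejson.JSONEncoder(sort_keys=True, indent='  ', separators=(',', ': '))
print(enc.encode({'b': 1, 'a': [True, None]}))
# {
#   "a": [
#     true,
#     null
#   ],
#   "b": 1
# }
```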
+ + If use_decimal is true (default: ``True``), ``decimal.Decimal`` will + be supported directly by the encoder. For the inverse, decode JSON + with ``parse_float=decimal.Decimal``. + + If namedtuple_as_object is true (the default), objects with + ``_asdict()`` methods will be encoded as JSON objects. + + If tuple_as_array is true (the default), tuple (and subclasses) will + be encoded as JSON arrays. + + If *iterable_as_array* is true (default: ``False``), + any object not in the above table that implements ``__iter__()`` + will be encoded as a JSON array. + + If bigint_as_string is true (not the default), ints 2**53 and higher + or lower than -2**53 will be encoded as strings. This is to avoid the + rounding that happens in Javascript otherwise. + + If int_as_string_bitcount is a positive number (n), then int of size + greater than or equal to 2**n or lower than or equal to -2**n will be + encoded as strings. + + If specified, item_sort_key is a callable used to sort the items in + each dictionary. This is useful if you want to sort items other than + in alphabetical order by key. + + If for_json is true (not the default), objects with a ``for_json()`` + method will use the return value of that method for encoding as JSON + instead of the object. + + If *ignore_nan* is true (default: ``False``), then out of range + :class:`float` values (``nan``, ``inf``, ``-inf``) will be serialized + as ``null`` in compliance with the ECMA-262 specification. If true, + this will override *allow_nan*. + + """ + + self.skipkeys = skipkeys + self.ensure_ascii = ensure_ascii + self.check_circular = check_circular + self.allow_nan = allow_nan + self.sort_keys = sort_keys + self.use_decimal = use_decimal + self.namedtuple_as_object = namedtuple_as_object + self.tuple_as_array = tuple_as_array + self.iterable_as_array = iterable_as_array + self.bigint_as_string = bigint_as_string + self.item_sort_key = item_sort_key + self.for_json = for_json + self.ignore_nan = ignore_nan + self.int_as_string_bitcount = int_as_string_bitcount + if indent is not None and not isinstance(indent, string_types): + indent = indent * ' ' + self.indent = indent + if separators is not None: + self.item_separator, self.key_separator = separators + elif indent is not None: + self.item_separator = ',' + if default is not None: + self.default = default + self.encoding = encoding + + def default(self, o): + """Implement this method in a subclass such that it returns + a serializable object for ``o``, or calls the base implementation + (to raise a ``TypeError``). + + For example, to support arbitrary iterators, you could + implement default like this:: + + def default(self, o): + try: + iterable = iter(o) + except TypeError: + pass + else: + return list(iterable) + return JSONEncoder.default(self, o) + + """ + raise TypeError('Object of type %s is not JSON serializable' % + o.__class__.__name__) + + def encode(self, o): + """Return a JSON string representation of a Python data structure. + + >>> from simplejson import JSONEncoder + >>> JSONEncoder().encode({"foo": ["bar", "baz"]}) + '{"foo": ["bar", "baz"]}' + + """ + # This is for extremely simple cases and benchmarks. 
+ if isinstance(o, binary_type): + _encoding = self.encoding + if (_encoding is not None and not (_encoding == 'utf-8')): + o = text_type(o, _encoding) + if isinstance(o, string_types): + if self.ensure_ascii: + return encode_basestring_ascii(o) + else: + return encode_basestring(o) + # This doesn't pass the iterator directly to ''.join() because the + # exceptions aren't as detailed. The list call should be roughly + # equivalent to the PySequence_Fast that ''.join() would do. + chunks = self.iterencode(o) + if not isinstance(chunks, (list, tuple)): + chunks = list(chunks) + if self.ensure_ascii: + return ''.join(chunks) + else: + return u''.join(chunks) + + def iterencode(self, o): + """Encode the given object and yield each string + representation as available. + + For example:: + + for chunk in JSONEncoder().iterencode(bigobject): + mysocket.write(chunk) + + """ + if self.check_circular: + markers = {} + else: + markers = None + if self.ensure_ascii: + _encoder = encode_basestring_ascii + else: + _encoder = encode_basestring + if self.encoding != 'utf-8' and self.encoding is not None: + def _encoder(o, _orig_encoder=_encoder, _encoding=self.encoding): + if isinstance(o, binary_type): + o = text_type(o, _encoding) + return _orig_encoder(o) + + def floatstr(o, allow_nan=self.allow_nan, ignore_nan=self.ignore_nan, + _repr=FLOAT_REPR, _inf=PosInf, _neginf=-PosInf): + # Check for specials. Note that this type of test is processor + # and/or platform-specific, so do tests which don't depend on + # the internals. + + if o != o: + text = 'NaN' + elif o == _inf: + text = 'Infinity' + elif o == _neginf: + text = '-Infinity' + else: + if type(o) != float: + # See #118, do not trust custom str/repr + o = float(o) + return _repr(o) + + if ignore_nan: + text = 'null' + elif not allow_nan: + raise ValueError( + "Out of range float values are not JSON compliant: " + + repr(o)) + + return text + + key_memo = {} + int_as_string_bitcount = ( + 53 if self.bigint_as_string else self.int_as_string_bitcount) + if (c_make_encoder is not None and self.indent is None): + _iterencode = c_make_encoder( + markers, self.default, _encoder, self.indent, + self.key_separator, self.item_separator, self.sort_keys, + self.skipkeys, self.allow_nan, key_memo, self.use_decimal, + self.namedtuple_as_object, self.tuple_as_array, + int_as_string_bitcount, + self.item_sort_key, self.encoding, self.for_json, + self.ignore_nan, decimal.Decimal, self.iterable_as_array) + else: + _iterencode = _make_iterencode( + markers, self.default, _encoder, self.indent, floatstr, + self.key_separator, self.item_separator, self.sort_keys, + self.skipkeys, self.use_decimal, + self.namedtuple_as_object, self.tuple_as_array, + int_as_string_bitcount, + self.item_sort_key, self.encoding, self.for_json, + self.iterable_as_array, Decimal=decimal.Decimal) + try: + return _iterencode(o, 0) + finally: + key_memo.clear() + + +class JSONEncoderForHTML(JSONEncoder): + """An encoder that produces JSON safe to embed in HTML. + + To embed JSON content in, say, a script tag on a web page, the + characters &, < and > should be escaped. They cannot be escaped + with the usual entities (e.g. 
&) because they are not expanded + within ' + self.assertEqual( + r'"\u003c/script\u003e\u003cscript\u003e' + r'alert(\"gotcha\")\u003c/script\u003e"', + self.encoder.encode(bad_string)) + self.assertEqual( + bad_string, self.decoder.decode( + self.encoder.encode(bad_string))) diff --git a/lib/simplejson/tests/test_errors.py b/lib/simplejson/tests/test_errors.py new file mode 100644 index 0000000..c6e8688 --- /dev/null +++ b/lib/simplejson/tests/test_errors.py @@ -0,0 +1,68 @@ +import sys, pickle +from unittest import TestCase + +import simplejson as json +from simplejson.compat import text_type, b + +class TestErrors(TestCase): + def test_string_keys_error(self): + data = [{'a': 'A', 'b': (2, 4), 'c': 3.0, ('d',): 'D tuple'}] + try: + json.dumps(data) + except TypeError: + err = sys.exc_info()[1] + else: + self.fail('Expected TypeError') + self.assertEqual(str(err), + 'keys must be str, int, float, bool or None, not tuple') + + def test_not_serializable(self): + try: + json.dumps(json) + except TypeError: + err = sys.exc_info()[1] + else: + self.fail('Expected TypeError') + self.assertEqual(str(err), + 'Object of type module is not JSON serializable') + + def test_decode_error(self): + err = None + try: + json.loads('{}\na\nb') + except json.JSONDecodeError: + err = sys.exc_info()[1] + else: + self.fail('Expected JSONDecodeError') + self.assertEqual(err.lineno, 2) + self.assertEqual(err.colno, 1) + self.assertEqual(err.endlineno, 3) + self.assertEqual(err.endcolno, 2) + + def test_scan_error(self): + err = None + for t in (text_type, b): + try: + json.loads(t('{"asdf": "')) + except json.JSONDecodeError: + err = sys.exc_info()[1] + else: + self.fail('Expected JSONDecodeError') + self.assertEqual(err.lineno, 1) + self.assertEqual(err.colno, 10) + + def test_error_is_pickable(self): + err = None + try: + json.loads('{}\na\nb') + except json.JSONDecodeError: + err = sys.exc_info()[1] + else: + self.fail('Expected JSONDecodeError') + s = pickle.dumps(err) + e = pickle.loads(s) + + self.assertEqual(err.msg, e.msg) + self.assertEqual(err.doc, e.doc) + self.assertEqual(err.pos, e.pos) + self.assertEqual(err.end, e.end) diff --git a/lib/simplejson/tests/test_fail.py b/lib/simplejson/tests/test_fail.py new file mode 100644 index 0000000..54b7414 --- /dev/null +++ b/lib/simplejson/tests/test_fail.py @@ -0,0 +1,178 @@ +import sys +from unittest import TestCase + +import simplejson as json + +# 2007-10-05 +JSONDOCS = [ + # http://json.org/JSON_checker/test/fail1.json + '"A JSON payload should be an object or array, not a string."', + # http://json.org/JSON_checker/test/fail2.json + '["Unclosed array"', + # http://json.org/JSON_checker/test/fail3.json + '{unquoted_key: "keys must be quoted"}', + # http://json.org/JSON_checker/test/fail4.json + '["extra comma",]', + # http://json.org/JSON_checker/test/fail5.json + '["double extra comma",,]', + # http://json.org/JSON_checker/test/fail6.json + '[ , "<-- missing value"]', + # http://json.org/JSON_checker/test/fail7.json + '["Comma after the close"],', + # http://json.org/JSON_checker/test/fail8.json + '["Extra close"]]', + # http://json.org/JSON_checker/test/fail9.json + '{"Extra comma": true,}', + # http://json.org/JSON_checker/test/fail10.json + '{"Extra value after close": true} "misplaced quoted value"', + # http://json.org/JSON_checker/test/fail11.json + '{"Illegal expression": 1 + 2}', + # http://json.org/JSON_checker/test/fail12.json + '{"Illegal invocation": alert()}', + # http://json.org/JSON_checker/test/fail13.json + '{"Numbers cannot have 
leading zeroes": 013}', + # http://json.org/JSON_checker/test/fail14.json + '{"Numbers cannot be hex": 0x14}', + # http://json.org/JSON_checker/test/fail15.json + '["Illegal backslash escape: \\x15"]', + # http://json.org/JSON_checker/test/fail16.json + '[\\naked]', + # http://json.org/JSON_checker/test/fail17.json + '["Illegal backslash escape: \\017"]', + # http://json.org/JSON_checker/test/fail18.json + '[[[[[[[[[[[[[[[[[[[["Too deep"]]]]]]]]]]]]]]]]]]]]', + # http://json.org/JSON_checker/test/fail19.json + '{"Missing colon" null}', + # http://json.org/JSON_checker/test/fail20.json + '{"Double colon":: null}', + # http://json.org/JSON_checker/test/fail21.json + '{"Comma instead of colon", null}', + # http://json.org/JSON_checker/test/fail22.json + '["Colon instead of comma": false]', + # http://json.org/JSON_checker/test/fail23.json + '["Bad value", truth]', + # http://json.org/JSON_checker/test/fail24.json + "['single quote']", + # http://json.org/JSON_checker/test/fail25.json + '["\ttab\tcharacter\tin\tstring\t"]', + # http://json.org/JSON_checker/test/fail26.json + '["tab\\ character\\ in\\ string\\ "]', + # http://json.org/JSON_checker/test/fail27.json + '["line\nbreak"]', + # http://json.org/JSON_checker/test/fail28.json + '["line\\\nbreak"]', + # http://json.org/JSON_checker/test/fail29.json + '[0e]', + # http://json.org/JSON_checker/test/fail30.json + '[0e+]', + # http://json.org/JSON_checker/test/fail31.json + '[0e+-1]', + # http://json.org/JSON_checker/test/fail32.json + '{"Comma instead if closing brace": true,', + # http://json.org/JSON_checker/test/fail33.json + '["mismatch"}', + # http://code.google.com/p/simplejson/issues/detail?id=3 + u'["A\u001FZ control characters in string"]', + # misc based on coverage + '{', + '{]', + '{"foo": "bar"]', + '{"foo": "bar"', + 'nul', + 'nulx', + '-', + '-x', + '-e', + '-e0', + '-Infinite', + '-Inf', + 'Infinit', + 'Infinite', + 'NaM', + 'NuN', + 'falsy', + 'fal', + 'trug', + 'tru', + '1e', + '1ex', + '1e-', + '1e-x', +] + +SKIPS = { + 1: "why not have a string payload?", + 18: "spec doesn't specify any nesting limitations", +} + +class TestFail(TestCase): + def test_failures(self): + for idx, doc in enumerate(JSONDOCS): + idx = idx + 1 + if idx in SKIPS: + json.loads(doc) + continue + try: + json.loads(doc) + except json.JSONDecodeError: + pass + else: + self.fail("Expected failure for fail%d.json: %r" % (idx, doc)) + + def test_array_decoder_issue46(self): + # http://code.google.com/p/simplejson/issues/detail?id=46 + for doc in [u'[,]', '[,]']: + try: + json.loads(doc) + except json.JSONDecodeError: + e = sys.exc_info()[1] + self.assertEqual(e.pos, 1) + self.assertEqual(e.lineno, 1) + self.assertEqual(e.colno, 2) + except Exception: + e = sys.exc_info()[1] + self.fail("Unexpected exception raised %r %s" % (e, e)) + else: + self.fail("Unexpected success parsing '[,]'") + + def test_truncated_input(self): + test_cases = [ + ('', 'Expecting value', 0), + ('[', "Expecting value or ']'", 1), + ('[42', "Expecting ',' delimiter", 3), + ('[42,', 'Expecting value', 4), + ('["', 'Unterminated string starting at', 1), + ('["spam', 'Unterminated string starting at', 1), + ('["spam"', "Expecting ',' delimiter", 7), + ('["spam",', 'Expecting value', 8), + ('{', "Expecting property name enclosed in double quotes or '}'", 1), + ('{"', 'Unterminated string starting at', 1), + ('{"spam', 'Unterminated string starting at', 1), + ('{"spam"', "Expecting ':' delimiter", 7), + ('{"spam":', 'Expecting value', 8), + ('{"spam":42', "Expecting ',' delimiter", 
10), + ('{"spam":42,', 'Expecting property name enclosed in double quotes', + 11), + ('"', 'Unterminated string starting at', 0), + ('"spam', 'Unterminated string starting at', 0), + ('[,', "Expecting value", 1), + ('--', 'Expecting value', 0), + ('"\x18d', "Invalid control character %r", 1), + ] + for data, msg, idx in test_cases: + try: + json.loads(data) + except json.JSONDecodeError: + e = sys.exc_info()[1] + self.assertEqual( + e.msg[:len(msg)], + msg, + "%r doesn't start with %r for %r" % (e.msg, msg, data)) + self.assertEqual( + e.pos, idx, + "pos %r != %r for %r" % (e.pos, idx, data)) + except Exception: + e = sys.exc_info()[1] + self.fail("Unexpected exception raised %r %s" % (e, e)) + else: + self.fail("Unexpected success parsing '%r'" % (data,)) diff --git a/lib/simplejson/tests/test_float.py b/lib/simplejson/tests/test_float.py new file mode 100644 index 0000000..494d999 --- /dev/null +++ b/lib/simplejson/tests/test_float.py @@ -0,0 +1,38 @@ +import math +from unittest import TestCase +from simplejson.compat import long_type, text_type +import simplejson as json +from simplejson.decoder import NaN, PosInf, NegInf + +class TestFloat(TestCase): + def test_degenerates_allow(self): + for inf in (PosInf, NegInf): + self.assertEqual(json.loads(json.dumps(inf, allow_nan=True), allow_nan=True), inf) + # Python 2.5 doesn't have math.isnan + nan = json.loads(json.dumps(NaN, allow_nan=True), allow_nan=True) + self.assertTrue((0 + nan) != nan) + + def test_degenerates_ignore(self): + for f in (PosInf, NegInf, NaN): + self.assertEqual(json.loads(json.dumps(f, ignore_nan=True)), None) + + def test_degenerates_deny(self): + for f in (PosInf, NegInf, NaN): + self.assertRaises(ValueError, json.dumps, f, allow_nan=False) + for s in ('Infinity', '-Infinity', 'NaN'): + self.assertRaises(ValueError, json.loads, s, allow_nan=False) + self.assertRaises(ValueError, json.loads, s) + + def test_floats(self): + for num in [1617161771.7650001, math.pi, math.pi**100, + math.pi**-100, 3.1]: + self.assertEqual(float(json.dumps(num)), num) + self.assertEqual(json.loads(json.dumps(num)), num) + self.assertEqual(json.loads(text_type(json.dumps(num))), num) + + def test_ints(self): + for num in [1, long_type(1), 1<<32, 1<<64]: + self.assertEqual(json.dumps(num), str(num)) + self.assertEqual(int(json.dumps(num)), num) + self.assertEqual(json.loads(json.dumps(num)), num) + self.assertEqual(json.loads(text_type(json.dumps(num))), num) diff --git a/lib/simplejson/tests/test_for_json.py b/lib/simplejson/tests/test_for_json.py new file mode 100644 index 0000000..4c153fd --- /dev/null +++ b/lib/simplejson/tests/test_for_json.py @@ -0,0 +1,97 @@ +import unittest +import simplejson as json + + +class ForJson(object): + def for_json(self): + return {'for_json': 1} + + +class NestedForJson(object): + def for_json(self): + return {'nested': ForJson()} + + +class ForJsonList(object): + def for_json(self): + return ['list'] + + +class DictForJson(dict): + def for_json(self): + return {'alpha': 1} + + +class ListForJson(list): + def for_json(self): + return ['list'] + + +class TestForJson(unittest.TestCase): + def assertRoundTrip(self, obj, other, for_json=True): + if for_json is None: + # None will use the default + s = json.dumps(obj) + else: + s = json.dumps(obj, for_json=for_json) + self.assertEqual( + json.loads(s), + other) + + def test_for_json_encodes_stand_alone_object(self): + self.assertRoundTrip( + ForJson(), + ForJson().for_json()) + + def test_for_json_encodes_object_nested_in_dict(self): + self.assertRoundTrip( 
+ {'hooray': ForJson()}, + {'hooray': ForJson().for_json()}) + + def test_for_json_encodes_object_nested_in_list_within_dict(self): + self.assertRoundTrip( + {'list': [0, ForJson(), 2, 3]}, + {'list': [0, ForJson().for_json(), 2, 3]}) + + def test_for_json_encodes_object_nested_within_object(self): + self.assertRoundTrip( + NestedForJson(), + {'nested': {'for_json': 1}}) + + def test_for_json_encodes_list(self): + self.assertRoundTrip( + ForJsonList(), + ForJsonList().for_json()) + + def test_for_json_encodes_list_within_object(self): + self.assertRoundTrip( + {'nested': ForJsonList()}, + {'nested': ForJsonList().for_json()}) + + def test_for_json_encodes_dict_subclass(self): + self.assertRoundTrip( + DictForJson(a=1), + DictForJson(a=1).for_json()) + + def test_for_json_encodes_list_subclass(self): + self.assertRoundTrip( + ListForJson(['l']), + ListForJson(['l']).for_json()) + + def test_for_json_ignored_if_not_true_with_dict_subclass(self): + for for_json in (None, False): + self.assertRoundTrip( + DictForJson(a=1), + {'a': 1}, + for_json=for_json) + + def test_for_json_ignored_if_not_true_with_list_subclass(self): + for for_json in (None, False): + self.assertRoundTrip( + ListForJson(['l']), + ['l'], + for_json=for_json) + + def test_raises_typeerror_if_for_json_not_true_with_object(self): + self.assertRaises(TypeError, json.dumps, ForJson()) + self.assertRaises(TypeError, json.dumps, ForJson(), for_json=False) diff --git a/lib/simplejson/tests/test_indent.py b/lib/simplejson/tests/test_indent.py new file mode 100644 index 0000000..32326a6 --- /dev/null +++ b/lib/simplejson/tests/test_indent.py @@ -0,0 +1,86 @@ +from unittest import TestCase +import textwrap + +import simplejson as json +from simplejson.compat import StringIO + +class TestIndent(TestCase): + def test_indent(self): + h = [['blorpie'], ['whoops'], [], 'd-shtaeou', 'd-nthiouh', + 'i-vhbjkhnth', + {'nifty': 87}, {'field': 'yes', 'morefield': False} ] + + expect = textwrap.dedent("""\ + [ + \t[ + \t\t"blorpie" + \t], + \t[ + \t\t"whoops" + \t], + \t[], + \t"d-shtaeou", + \t"d-nthiouh", + \t"i-vhbjkhnth", + \t{ + \t\t"nifty": 87 + \t}, + \t{ + \t\t"field": "yes", + \t\t"morefield": false + \t} + ]""") + + + d1 = json.dumps(h) + d2 = json.dumps(h, indent='\t', sort_keys=True, separators=(',', ': ')) + d3 = json.dumps(h, indent=' ', sort_keys=True, separators=(',', ': ')) + d4 = json.dumps(h, indent=2, sort_keys=True, separators=(',', ': ')) + + h1 = json.loads(d1) + h2 = json.loads(d2) + h3 = json.loads(d3) + h4 = json.loads(d4) + + self.assertEqual(h1, h) + self.assertEqual(h2, h) + self.assertEqual(h3, h) + self.assertEqual(h4, h) + self.assertEqual(d3, expect.replace('\t', ' ')) + self.assertEqual(d4, expect.replace('\t', ' ')) + # NOTE: Python 2.4 textwrap.dedent converts tabs to spaces, + # so the following is expected to fail. Python 2.4 is not a + # supported platform in simplejson 2.1.0+. 
+ self.assertEqual(d2, expect) + + def test_indent0(self): + h = {3: 1} + def check(indent, expected): + d1 = json.dumps(h, indent=indent) + self.assertEqual(d1, expected) + + sio = StringIO() + json.dump(h, sio, indent=indent) + self.assertEqual(sio.getvalue(), expected) + + # indent=0 should emit newlines + check(0, '{\n"3": 1\n}') + # indent=None is more compact + check(None, '{"3": 1}') + + def test_separators(self): + lst = [1,2,3,4] + expect = '[\n1,\n2,\n3,\n4\n]' + expect_spaces = '[\n1, \n2, \n3, \n4\n]' + # Ensure that separators still works + self.assertEqual( + expect_spaces, + json.dumps(lst, indent=0, separators=(', ', ': '))) + # Force the new defaults + self.assertEqual( + expect, + json.dumps(lst, indent=0, separators=(',', ': '))) + # Added in 2.1.4 + self.assertEqual( + expect, + json.dumps(lst, indent=0)) diff --git a/lib/simplejson/tests/test_item_sort_key.py b/lib/simplejson/tests/test_item_sort_key.py new file mode 100644 index 0000000..c913304 --- /dev/null +++ b/lib/simplejson/tests/test_item_sort_key.py @@ -0,0 +1,27 @@ +from unittest import TestCase + +import simplejson as json +from operator import itemgetter + +class TestItemSortKey(TestCase): + def test_simple_first(self): + a = {'a': 1, 'c': 5, 'jack': 'jill', 'pick': 'axe', 'array': [1, 5, 6, 9], 'tuple': (83, 12, 3), 'crate': 'dog', 'zeak': 'oh'} + self.assertEqual( + '{"a": 1, "c": 5, "crate": "dog", "jack": "jill", "pick": "axe", "zeak": "oh", "array": [1, 5, 6, 9], "tuple": [83, 12, 3]}', + json.dumps(a, item_sort_key=json.simple_first)) + + def test_case(self): + a = {'a': 1, 'c': 5, 'Jack': 'jill', 'pick': 'axe', 'Array': [1, 5, 6, 9], 'tuple': (83, 12, 3), 'crate': 'dog', 'zeak': 'oh'} + self.assertEqual( + '{"Array": [1, 5, 6, 9], "Jack": "jill", "a": 1, "c": 5, "crate": "dog", "pick": "axe", "tuple": [83, 12, 3], "zeak": "oh"}', + json.dumps(a, item_sort_key=itemgetter(0))) + self.assertEqual( + '{"a": 1, "Array": [1, 5, 6, 9], "c": 5, "crate": "dog", "Jack": "jill", "pick": "axe", "tuple": [83, 12, 3], "zeak": "oh"}', + json.dumps(a, item_sort_key=lambda kv: kv[0].lower())) + + def test_item_sort_key_value(self): + # https://github.com/simplejson/simplejson/issues/173 + a = {'a': 1, 'b': 0} + self.assertEqual( + '{"b": 0, "a": 1}', + json.dumps(a, item_sort_key=lambda kv: kv[1])) diff --git a/lib/simplejson/tests/test_iterable.py b/lib/simplejson/tests/test_iterable.py new file mode 100644 index 0000000..995e1fd --- /dev/null +++ b/lib/simplejson/tests/test_iterable.py @@ -0,0 +1,31 @@ +import unittest +from simplejson.compat import StringIO + +import simplejson as json + +def iter_dumps(obj, **kw): + return ''.join(json.JSONEncoder(**kw).iterencode(obj)) + +def sio_dump(obj, **kw): + sio = StringIO() + json.dumps(obj, **kw) + return sio.getvalue() + +class TestIterable(unittest.TestCase): + def test_iterable(self): + for l in ([], [1], [1, 2], [1, 2, 3]): + for opts in [{}, {'indent': 2}]: + for dumps in (json.dumps, iter_dumps, sio_dump): + expect = dumps(l, **opts) + default_expect = dumps(sum(l), **opts) + # Default is False + self.assertRaises(TypeError, dumps, iter(l), **opts) + self.assertRaises(TypeError, dumps, iter(l), iterable_as_array=False, **opts) + self.assertEqual(expect, dumps(iter(l), iterable_as_array=True, **opts)) + # Ensure that the "default" gets called + self.assertEqual(default_expect, dumps(iter(l), default=sum, **opts)) + self.assertEqual(default_expect, dumps(iter(l), iterable_as_array=False, default=sum, **opts)) + # Ensure that the "default" does not get called + 
self.assertEqual( + expect, + dumps(iter(l), iterable_as_array=True, default=sum, **opts)) diff --git a/lib/simplejson/tests/test_namedtuple.py b/lib/simplejson/tests/test_namedtuple.py new file mode 100644 index 0000000..e3aa310 --- /dev/null +++ b/lib/simplejson/tests/test_namedtuple.py @@ -0,0 +1,174 @@ +from __future__ import absolute_import +import unittest +import simplejson as json +from simplejson.compat import StringIO + +try: + from unittest import mock +except ImportError: + mock = None + +try: + from collections import namedtuple +except ImportError: + class Value(tuple): + def __new__(cls, *args): + return tuple.__new__(cls, args) + + def _asdict(self): + return {'value': self[0]} + class Point(tuple): + def __new__(cls, *args): + return tuple.__new__(cls, args) + + def _asdict(self): + return {'x': self[0], 'y': self[1]} +else: + Value = namedtuple('Value', ['value']) + Point = namedtuple('Point', ['x', 'y']) + +class DuckValue(object): + def __init__(self, *args): + self.value = Value(*args) + + def _asdict(self): + return self.value._asdict() + +class DuckPoint(object): + def __init__(self, *args): + self.point = Point(*args) + + def _asdict(self): + return self.point._asdict() + +class DeadDuck(object): + _asdict = None + +class DeadDict(dict): + _asdict = None + +CONSTRUCTORS = [ + lambda v: v, + lambda v: [v], + lambda v: [{'key': v}], +] + +class TestNamedTuple(unittest.TestCase): + def test_namedtuple_dumps(self): + for v in [Value(1), Point(1, 2), DuckValue(1), DuckPoint(1, 2)]: + d = v._asdict() + self.assertEqual(d, json.loads(json.dumps(v))) + self.assertEqual( + d, + json.loads(json.dumps(v, namedtuple_as_object=True))) + self.assertEqual(d, json.loads(json.dumps(v, tuple_as_array=False))) + self.assertEqual( + d, + json.loads(json.dumps(v, namedtuple_as_object=True, + tuple_as_array=False))) + + def test_namedtuple_dumps_false(self): + for v in [Value(1), Point(1, 2)]: + l = list(v) + self.assertEqual( + l, + json.loads(json.dumps(v, namedtuple_as_object=False))) + self.assertRaises(TypeError, json.dumps, v, + tuple_as_array=False, namedtuple_as_object=False) + + def test_namedtuple_dump(self): + for v in [Value(1), Point(1, 2), DuckValue(1), DuckPoint(1, 2)]: + d = v._asdict() + sio = StringIO() + json.dump(v, sio) + self.assertEqual(d, json.loads(sio.getvalue())) + sio = StringIO() + json.dump(v, sio, namedtuple_as_object=True) + self.assertEqual( + d, + json.loads(sio.getvalue())) + sio = StringIO() + json.dump(v, sio, tuple_as_array=False) + self.assertEqual(d, json.loads(sio.getvalue())) + sio = StringIO() + json.dump(v, sio, namedtuple_as_object=True, + tuple_as_array=False) + self.assertEqual( + d, + json.loads(sio.getvalue())) + + def test_namedtuple_dump_false(self): + for v in [Value(1), Point(1, 2)]: + l = list(v) + sio = StringIO() + json.dump(v, sio, namedtuple_as_object=False) + self.assertEqual( + l, + json.loads(sio.getvalue())) + self.assertRaises(TypeError, json.dump, v, StringIO(), + tuple_as_array=False, namedtuple_as_object=False) + + def test_asdict_not_callable_dump(self): + for f in CONSTRUCTORS: + self.assertRaises( + TypeError, + json.dump, + f(DeadDuck()), + StringIO(), + namedtuple_as_object=True + ) + sio = StringIO() + json.dump(f(DeadDict()), sio, namedtuple_as_object=True) + self.assertEqual( + json.dumps(f({})), + sio.getvalue()) + self.assertRaises( + TypeError, + json.dump, + f(Value), + StringIO(), + namedtuple_as_object=True + ) + + def test_asdict_not_callable_dumps(self): + for f in CONSTRUCTORS: + 
self.assertRaises(TypeError, + json.dumps, f(DeadDuck()), namedtuple_as_object=True) + self.assertRaises( + TypeError, + json.dumps, + f(Value), + namedtuple_as_object=True + ) + self.assertEqual( + json.dumps(f({})), + json.dumps(f(DeadDict()), namedtuple_as_object=True)) + + def test_asdict_unbound_method_dumps(self): + for f in CONSTRUCTORS: + self.assertEqual( + json.dumps(f(Value), default=lambda v: v.__name__), + json.dumps(f(Value.__name__)) + ) + + def test_asdict_does_not_return_dict(self): + if not mock: + if hasattr(unittest, "SkipTest"): + raise unittest.SkipTest("unittest.mock required") + else: + print("unittest.mock not available") + return + fake = mock.Mock() + self.assertTrue(hasattr(fake, '_asdict')) + self.assertTrue(callable(fake._asdict)) + self.assertFalse(isinstance(fake._asdict(), dict)) + # https://github.com/simplejson/simplejson/pull/284 + # when running under a debug build of CPython (COPTS=-UNDEBUG) + # a C assertion could fire due to an unchecked error of an PyDict + # API call on a non-dict internally in _speedups.c. Without a debug + # build of CPython this test likely passes either way despite the + # potential for internal data corruption. Getting it to crash in + # a debug build is not always easy either as it requires an + # assert(!PyErr_Occurred()) that could fire later on. + with self.assertRaises(TypeError): + json.dumps({23: fake}, namedtuple_as_object=True, for_json=False) diff --git a/lib/simplejson/tests/test_pass1.py b/lib/simplejson/tests/test_pass1.py new file mode 100644 index 0000000..7482833 --- /dev/null +++ b/lib/simplejson/tests/test_pass1.py @@ -0,0 +1,71 @@ +from unittest import TestCase + +import simplejson as json + +# from http://json.org/JSON_checker/test/pass1.json +JSON = r''' +[ + "JSON Test Pattern pass1", + {"object with 1 member":["array with 1 element"]}, + {}, + [], + -42, + true, + false, + null, + { + "integer": 1234567890, + "real": -9876.543210, + "e": 0.123456789e-12, + "E": 1.234567890E+34, + "": 23456789012E66, + "zero": 0, + "one": 1, + "space": " ", + "quote": "\"", + "backslash": "\\", + "controls": "\b\f\n\r\t", + "slash": "/ & \/", + "alpha": "abcdefghijklmnopqrstuvwyz", + "ALPHA": "ABCDEFGHIJKLMNOPQRSTUVWYZ", + "digit": "0123456789", + "special": "`1~!@#$%^&*()_+-={':[,]}|;.?", + "hex": "\u0123\u4567\u89AB\uCDEF\uabcd\uef4A", + "true": true, + "false": false, + "null": null, + "array":[ ], + "object":{ }, + "address": "50 St. James Street", + "url": "http://www.JSON.org/", + "comment": "// /* */": " ", + " s p a c e d " :[1,2 , 3 + +, + +4 , 5 , 6 ,7 ],"compact": [1,2,3,4,5,6,7], + "jsontext": "{\"object with 1 member\":[\"array with 1 element\"]}", + "quotes": "" \u0022 %22 0x22 034 "", + "\/\\\"\uCAFE\uBABE\uAB98\uFCDE\ubcda\uef4A\b\f\n\r\t`1~!@#$%^&*()_+-=[]{}|;:',./<>?" 
+: "A key can be any string" + }, + 0.5 ,98.6 +, +99.44 +, + +1066, +1e1, +0.1e1, +1e-1, +1e00,2e+00,2e-00 +,"rosebud"] +''' + +class TestPass1(TestCase): + def test_parse(self): + # test in/out equivalence and parsing + res = json.loads(JSON) + out = json.dumps(res) + self.assertEqual(res, json.loads(out)) diff --git a/lib/simplejson/tests/test_pass2.py b/lib/simplejson/tests/test_pass2.py new file mode 100644 index 0000000..5c8e9ef --- /dev/null +++ b/lib/simplejson/tests/test_pass2.py @@ -0,0 +1,14 @@ +from unittest import TestCase +import simplejson as json + +# from http://json.org/JSON_checker/test/pass2.json +JSON = r''' +[[[[[[[[[[[[[[[[[[["Not too deep"]]]]]]]]]]]]]]]]]]] +''' + +class TestPass2(TestCase): + def test_parse(self): + # test in/out equivalence and parsing + res = json.loads(JSON) + out = json.dumps(res) + self.assertEqual(res, json.loads(out)) diff --git a/lib/simplejson/tests/test_pass3.py b/lib/simplejson/tests/test_pass3.py new file mode 100644 index 0000000..7d9cd8c --- /dev/null +++ b/lib/simplejson/tests/test_pass3.py @@ -0,0 +1,20 @@ +from unittest import TestCase + +import simplejson as json + +# from http://json.org/JSON_checker/test/pass3.json +JSON = r''' +{ + "JSON Test Pattern pass3": { + "The outermost value": "must be an object or array.", + "In this test": "It is an object." + } +} +''' + +class TestPass3(TestCase): + def test_parse(self): + # test in/out equivalence and parsing + res = json.loads(JSON) + out = json.dumps(res) + self.assertEqual(res, json.loads(out)) diff --git a/lib/simplejson/tests/test_raw_json.py b/lib/simplejson/tests/test_raw_json.py new file mode 100644 index 0000000..aaa2724 --- /dev/null +++ b/lib/simplejson/tests/test_raw_json.py @@ -0,0 +1,47 @@ +import unittest +import simplejson as json + +dct1 = { + 'key1': 'value1' +} + +dct2 = { + 'key2': 'value2', + 'd1': dct1 +} + +dct3 = { + 'key2': 'value2', + 'd1': json.dumps(dct1) +} + +dct4 = { + 'key2': 'value2', + 'd1': json.RawJSON(json.dumps(dct1)) +} + + +class TestRawJson(unittest.TestCase): + + def test_normal_str(self): + self.assertNotEqual(json.dumps(dct2), json.dumps(dct3)) + + def test_raw_json_str(self): + self.assertEqual(json.dumps(dct2), json.dumps(dct4)) + self.assertEqual(dct2, json.loads(json.dumps(dct4))) + + def test_list(self): + self.assertEqual( + json.dumps([dct2]), + json.dumps([json.RawJSON(json.dumps(dct2))])) + self.assertEqual( + [dct2], + json.loads(json.dumps([json.RawJSON(json.dumps(dct2))]))) + + def test_direct(self): + self.assertEqual( + json.dumps(dct2), + json.dumps(json.RawJSON(json.dumps(dct2)))) + self.assertEqual( + dct2, + json.loads(json.dumps(json.RawJSON(json.dumps(dct2))))) diff --git a/lib/simplejson/tests/test_recursion.py b/lib/simplejson/tests/test_recursion.py new file mode 100644 index 0000000..76328fc --- /dev/null +++ b/lib/simplejson/tests/test_recursion.py @@ -0,0 +1,67 @@ +from unittest import TestCase + +import simplejson as json + +class JSONTestObject: + pass + + +class RecursiveJSONEncoder(json.JSONEncoder): + recurse = False + def default(self, o): + if o is JSONTestObject: + if self.recurse: + return [JSONTestObject] + else: + return 'JSONTestObject' + return json.JSONEncoder.default(o) + + +class TestRecursion(TestCase): + def test_listrecursion(self): + x = [] + x.append(x) + try: + json.dumps(x) + except ValueError: + pass + else: + self.fail("didn't raise ValueError on list recursion") + x = [] + y = [x] + x.append(y) + try: + json.dumps(x) + except ValueError: + pass + else: + self.fail("didn't raise 
ValueError on alternating list recursion") + y = [] + x = [y, y] + # ensure that the marker is cleared + json.dumps(x) + + def test_dictrecursion(self): + x = {} + x["test"] = x + try: + json.dumps(x) + except ValueError: + pass + else: + self.fail("didn't raise ValueError on dict recursion") + x = {} + y = {"a": x, "b": x} + # ensure that the marker is cleared + json.dumps(y) + + def test_defaultrecursion(self): + enc = RecursiveJSONEncoder() + self.assertEqual(enc.encode(JSONTestObject), '"JSONTestObject"') + enc.recurse = True + try: + enc.encode(JSONTestObject) + except ValueError: + pass + else: + self.fail("didn't raise ValueError on default recursion") diff --git a/lib/simplejson/tests/test_scanstring.py b/lib/simplejson/tests/test_scanstring.py new file mode 100644 index 0000000..89d83d6 --- /dev/null +++ b/lib/simplejson/tests/test_scanstring.py @@ -0,0 +1,200 @@ +import sys +from unittest import TestCase + +import simplejson as json +import simplejson.decoder +from simplejson.compat import b, PY3 + +class TestScanString(TestCase): + # The bytes type is intentionally not used in most of these tests + # under Python 3 because the decoder immediately coerces to str before + # calling scanstring. In Python 2 we are testing the code paths + # for both unicode and str. + # + # The reason this is done is because Python 3 would require + # entirely different code paths for parsing bytes and str. + # + def test_py_scanstring(self): + self._test_scanstring(simplejson.decoder.py_scanstring) + + def test_c_scanstring(self): + if not simplejson.decoder.c_scanstring: + return + self._test_scanstring(simplejson.decoder.c_scanstring) + + self.assertTrue(isinstance(simplejson.decoder.c_scanstring('""', 0)[0], str)) + + def _test_scanstring(self, scanstring): + if sys.maxunicode == 65535: + self.assertEqual( + scanstring(u'"z\U0001d120x"', 1, None, True), + (u'z\U0001d120x', 6)) + else: + self.assertEqual( + scanstring(u'"z\U0001d120x"', 1, None, True), + (u'z\U0001d120x', 5)) + + self.assertEqual( + scanstring('"\\u007b"', 1, None, True), + (u'{', 8)) + + self.assertEqual( + scanstring('"A JSON payload should be an object or array, not a string."', 1, None, True), + (u'A JSON payload should be an object or array, not a string.', 60)) + + self.assertEqual( + scanstring('["Unclosed array"', 2, None, True), + (u'Unclosed array', 17)) + + self.assertEqual( + scanstring('["extra comma",]', 2, None, True), + (u'extra comma', 14)) + + self.assertEqual( + scanstring('["double extra comma",,]', 2, None, True), + (u'double extra comma', 21)) + + self.assertEqual( + scanstring('["Comma after the close"],', 2, None, True), + (u'Comma after the close', 24)) + + self.assertEqual( + scanstring('["Extra close"]]', 2, None, True), + (u'Extra close', 14)) + + self.assertEqual( + scanstring('{"Extra comma": true,}', 2, None, True), + (u'Extra comma', 14)) + + self.assertEqual( + scanstring('{"Extra value after close": true} "misplaced quoted value"', 2, None, True), + (u'Extra value after close', 26)) + + self.assertEqual( + scanstring('{"Illegal expression": 1 + 2}', 2, None, True), + (u'Illegal expression', 21)) + + self.assertEqual( + scanstring('{"Illegal invocation": alert()}', 2, None, True), + (u'Illegal invocation', 21)) + + self.assertEqual( + scanstring('{"Numbers cannot have leading zeroes": 013}', 2, None, True), + (u'Numbers cannot have leading zeroes', 37)) + + self.assertEqual( + scanstring('{"Numbers cannot be hex": 0x14}', 2, None, True), + (u'Numbers cannot be hex', 24)) + + self.assertEqual( + 
scanstring('[[[[[[[[[[[[[[[[[[[["Too deep"]]]]]]]]]]]]]]]]]]]]', 21, None, True), + (u'Too deep', 30)) + + self.assertEqual( + scanstring('{"Missing colon" null}', 2, None, True), + (u'Missing colon', 16)) + + self.assertEqual( + scanstring('{"Double colon":: null}', 2, None, True), + (u'Double colon', 15)) + + self.assertEqual( + scanstring('{"Comma instead of colon", null}', 2, None, True), + (u'Comma instead of colon', 25)) + + self.assertEqual( + scanstring('["Colon instead of comma": false]', 2, None, True), + (u'Colon instead of comma', 25)) + + self.assertEqual( + scanstring('["Bad value", truth]', 2, None, True), + (u'Bad value', 12)) + + for c in map(chr, range(0x00, 0x1f)): + self.assertEqual( + scanstring(c + '"', 0, None, False), + (c, 2)) + self.assertRaises( + ValueError, + scanstring, c + '"', 0, None, True) + + self.assertRaises(ValueError, scanstring, '', 0, None, True) + self.assertRaises(ValueError, scanstring, 'a', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\u', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\u0', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\u01', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\u012', 0, None, True) + self.assertRaises(ValueError, scanstring, '\\u0123', 0, None, True) + if sys.maxunicode > 65535: + self.assertRaises(ValueError, + scanstring, '\\ud834\\u"', 0, None, True) + self.assertRaises(ValueError, + scanstring, '\\ud834\\x0123"', 0, None, True) + + self.assertRaises(json.JSONDecodeError, scanstring, '\\u-123"', 0, None, True) + # SJ-PT-23-01: Invalid Handling of Broken Unicode Escape Sequences + self.assertRaises(json.JSONDecodeError, scanstring, '\\u EDD"', 0, None, True) + + def test_issue3623(self): + self.assertRaises(ValueError, json.decoder.scanstring, "xxx", 1, + "xxx") + self.assertRaises(UnicodeDecodeError, + json.encoder.encode_basestring_ascii, b("xx\xff")) + + def test_overflow(self): + # Python 2.5 does not have maxsize, Python 3 does not have maxint + maxsize = getattr(sys, 'maxsize', getattr(sys, 'maxint', None)) + assert maxsize is not None + self.assertRaises(OverflowError, json.decoder.scanstring, "xxx", + maxsize + 1) + + def test_surrogates(self): + scanstring = json.decoder.scanstring + + def assertScan(given, expect, test_utf8=True): + givens = [given] + if not PY3 and test_utf8: + givens.append(given.encode('utf8')) + for given in givens: + (res, count) = scanstring(given, 1, None, True) + self.assertEqual(len(given), count) + self.assertEqual(res, expect) + + assertScan( + u'"z\\ud834\\u0079x"', + u'z\ud834yx') + assertScan( + u'"z\\ud834\\udd20x"', + u'z\U0001d120x') + assertScan( + u'"z\\ud834\\ud834\\udd20x"', + u'z\ud834\U0001d120x') + assertScan( + u'"z\\ud834x"', + u'z\ud834x') + assertScan( + u'"z\\udd20x"', + u'z\udd20x') + assertScan( + u'"z\ud834x"', + u'z\ud834x') + # It may look strange to join strings together, but Python is drunk. 
+ # https://gist.github.com/etrepum/5538443 + assertScan( + u'"z\\ud834\udd20x12345"', + u''.join([u'z\ud834', u'\udd20x12345'])) + assertScan( + u'"z\ud834\\udd20x"', + u''.join([u'z\ud834', u'\udd20x'])) + # these have different behavior given UTF8 input, because the surrogate + # pair may be joined (in maxunicode > 65535 builds) + assertScan( + u''.join([u'"z\ud834', u'\udd20x"']), + u''.join([u'z\ud834', u'\udd20x']), + test_utf8=False) + + self.assertRaises(ValueError, + scanstring, u'"z\\ud83x"', 1, None, True) + self.assertRaises(ValueError, + scanstring, u'"z\\ud834\\udd2x"', 1, None, True) diff --git a/lib/simplejson/tests/test_separators.py b/lib/simplejson/tests/test_separators.py new file mode 100644 index 0000000..5a21176 --- /dev/null +++ b/lib/simplejson/tests/test_separators.py @@ -0,0 +1,42 @@ +import textwrap +from unittest import TestCase + +import simplejson as json + + +class TestSeparators(TestCase): + def test_separators(self): + h = [['blorpie'], ['whoops'], [], 'd-shtaeou', 'd-nthiouh', 'i-vhbjkhnth', + {'nifty': 87}, {'field': 'yes', 'morefield': False} ] + + expect = textwrap.dedent("""\ + [ + [ + "blorpie" + ] , + [ + "whoops" + ] , + [] , + "d-shtaeou" , + "d-nthiouh" , + "i-vhbjkhnth" , + { + "nifty" : 87 + } , + { + "field" : "yes" , + "morefield" : false + } + ]""") + + + d1 = json.dumps(h) + d2 = json.dumps(h, indent=' ', sort_keys=True, separators=(' ,', ' : ')) + + h1 = json.loads(d1) + h2 = json.loads(d2) + + self.assertEqual(h1, h) + self.assertEqual(h2, h) + self.assertEqual(d2, expect) diff --git a/lib/simplejson/tests/test_speedups.py b/lib/simplejson/tests/test_speedups.py new file mode 100644 index 0000000..c0a66f3 --- /dev/null +++ b/lib/simplejson/tests/test_speedups.py @@ -0,0 +1,114 @@ +from __future__ import with_statement + +import sys +import unittest +from unittest import TestCase + +import simplejson +from simplejson import encoder, decoder, scanner +from simplejson.compat import PY3, long_type, b + + +def has_speedups(): + return encoder.c_make_encoder is not None + + +def skip_if_speedups_missing(func): + def wrapper(*args, **kwargs): + if not has_speedups(): + if hasattr(unittest, 'SkipTest'): + raise unittest.SkipTest("C Extension not available") + else: + sys.stdout.write("C Extension not available") + return + return func(*args, **kwargs) + + return wrapper + + +class BadBool: + def __bool__(self): + 1/0 + __nonzero__ = __bool__ + + +class TestDecode(TestCase): + @skip_if_speedups_missing + def test_make_scanner(self): + self.assertRaises(AttributeError, scanner.c_make_scanner, 1) + + @skip_if_speedups_missing + def test_bad_bool_args(self): + def test(value): + decoder.JSONDecoder(strict=BadBool()).decode(value) + self.assertRaises(ZeroDivisionError, test, '""') + self.assertRaises(ZeroDivisionError, test, '{}') + if not PY3: + self.assertRaises(ZeroDivisionError, test, u'""') + self.assertRaises(ZeroDivisionError, test, u'{}') + +class TestEncode(TestCase): + @skip_if_speedups_missing + def test_make_encoder(self): + self.assertRaises( + TypeError, + encoder.c_make_encoder, + None, + ("\xCD\x7D\x3D\x4E\x12\x4C\xF9\x79\xD7" + "\x52\xBA\x82\xF2\x27\x4A\x7D\xA0\xCA\x75"), + None + ) + + @skip_if_speedups_missing + def test_bad_str_encoder(self): + # Issue #31505: There shouldn't be an assertion failure in case + # c_make_encoder() receives a bad encoder() argument. 
+ import decimal + def bad_encoder1(*args): + return None + enc = encoder.c_make_encoder( + None, lambda obj: str(obj), + bad_encoder1, None, ': ', ', ', + False, False, False, {}, False, False, False, + None, None, 'utf-8', False, False, decimal.Decimal, False) + self.assertRaises(TypeError, enc, 'spam', 4) + self.assertRaises(TypeError, enc, {'spam': 42}, 4) + + def bad_encoder2(*args): + 1/0 + enc = encoder.c_make_encoder( + None, lambda obj: str(obj), + bad_encoder2, None, ': ', ', ', + False, False, False, {}, False, False, False, + None, None, 'utf-8', False, False, decimal.Decimal, False) + self.assertRaises(ZeroDivisionError, enc, 'spam', 4) + + @skip_if_speedups_missing + def test_bad_bool_args(self): + def test(name): + encoder.JSONEncoder(**{name: BadBool()}).encode({}) + self.assertRaises(ZeroDivisionError, test, 'skipkeys') + self.assertRaises(ZeroDivisionError, test, 'ensure_ascii') + self.assertRaises(ZeroDivisionError, test, 'check_circular') + self.assertRaises(ZeroDivisionError, test, 'allow_nan') + self.assertRaises(ZeroDivisionError, test, 'sort_keys') + self.assertRaises(ZeroDivisionError, test, 'use_decimal') + self.assertRaises(ZeroDivisionError, test, 'namedtuple_as_object') + self.assertRaises(ZeroDivisionError, test, 'tuple_as_array') + self.assertRaises(ZeroDivisionError, test, 'bigint_as_string') + self.assertRaises(ZeroDivisionError, test, 'for_json') + self.assertRaises(ZeroDivisionError, test, 'ignore_nan') + self.assertRaises(ZeroDivisionError, test, 'iterable_as_array') + + @skip_if_speedups_missing + def test_int_as_string_bitcount_overflow(self): + long_count = long_type(2)**32+31 + def test(): + encoder.JSONEncoder(int_as_string_bitcount=long_count).encode(0) + self.assertRaises((TypeError, OverflowError), test) + + if PY3: + @skip_if_speedups_missing + def test_bad_encoding(self): + with self.assertRaises(UnicodeEncodeError): + encoder.JSONEncoder(encoding='\udcff').encode({b('key'): 123}) diff --git a/lib/simplejson/tests/test_str_subclass.py b/lib/simplejson/tests/test_str_subclass.py new file mode 100644 index 0000000..6bdd41f --- /dev/null +++ b/lib/simplejson/tests/test_str_subclass.py @@ -0,0 +1,21 @@ +from unittest import TestCase + +import simplejson +from simplejson.compat import text_type + +# Tests for issue demonstrated in https://github.com/simplejson/simplejson/issues/144 +class WonkyTextSubclass(text_type): + def __getslice__(self, start, end): + return self.__class__('not what you wanted!') + +class TestStrSubclass(TestCase): + def test_dump_load(self): + for s in ['', '"hello"', 'text', u'\u005c']: + self.assertEqual( + s, + simplejson.loads(simplejson.dumps(WonkyTextSubclass(s)))) + + self.assertEqual( + s, + simplejson.loads(simplejson.dumps(WonkyTextSubclass(s), + ensure_ascii=False))) diff --git a/lib/simplejson/tests/test_subclass.py b/lib/simplejson/tests/test_subclass.py new file mode 100644 index 0000000..a9ae318 --- /dev/null +++ b/lib/simplejson/tests/test_subclass.py @@ -0,0 +1,37 @@ +from unittest import TestCase +import simplejson as json + +from decimal import Decimal + +class AlternateInt(int): + def __repr__(self): + return 'invalid json' + __str__ = __repr__ + + +class AlternateFloat(float): + def __repr__(self): + return 'invalid json' + __str__ = __repr__ + + +# class AlternateDecimal(Decimal): +# def __repr__(self): +# return 'invalid json' + + +class TestSubclass(TestCase): + def test_int(self): + self.assertEqual(json.dumps(AlternateInt(1)), '1') + self.assertEqual(json.dumps(AlternateInt(-1)), '-1') + 
self.assertEqual(json.loads(json.dumps({AlternateInt(1): 1})), {'1': 1}) + + def test_float(self): + self.assertEqual(json.dumps(AlternateFloat(1.0)), '1.0') + self.assertEqual(json.dumps(AlternateFloat(-1.0)), '-1.0') + self.assertEqual(json.loads(json.dumps({AlternateFloat(1.0): 1})), {'1.0': 1}) + + # NOTE: Decimal subclasses are not supported as-is + # def test_decimal(self): + # self.assertEqual(json.dumps(AlternateDecimal('1.0')), '1.0') + # self.assertEqual(json.dumps(AlternateDecimal('-1.0')), '-1.0') diff --git a/lib/simplejson/tests/test_tool.py b/lib/simplejson/tests/test_tool.py new file mode 100644 index 0000000..59807ec --- /dev/null +++ b/lib/simplejson/tests/test_tool.py @@ -0,0 +1,114 @@ +from __future__ import with_statement +import os +import sys +import textwrap +import unittest +import subprocess +import tempfile +try: + # Python 3.x + from test.support import strip_python_stderr +except ImportError: + # Python 2.6+ + try: + from test.test_support import strip_python_stderr + except ImportError: + # Python 2.5 + import re + def strip_python_stderr(stderr): + return re.sub( + r"\[\d+ refs\]\r?\n?$".encode(), + "".encode(), + stderr).strip() + +def open_temp_file(): + if sys.version_info >= (2, 6): + file = tempfile.NamedTemporaryFile(delete=False) + filename = file.name + else: + fd, filename = tempfile.mkstemp() + file = os.fdopen(fd, 'w+b') + return file, filename + +class TestTool(unittest.TestCase): + data = """ + + [["blorpie"],[ "whoops" ] , [ + ],\t"d-shtaeou",\r"d-nthiouh", + "i-vhbjkhnth", {"nifty":87}, {"morefield" :\tfalse,"field" + :"yes"} ] + """ + + expect = textwrap.dedent("""\ + [ + [ + "blorpie" + ], + [ + "whoops" + ], + [], + "d-shtaeou", + "d-nthiouh", + "i-vhbjkhnth", + { + "nifty": 87 + }, + { + "field": "yes", + "morefield": false + } + ] + """) + + def runTool(self, args=None, data=None): + argv = [sys.executable, '-m', 'simplejson.tool'] + if args: + argv.extend(args) + proc = subprocess.Popen(argv, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE) + out, err = proc.communicate(data) + self.assertEqual(strip_python_stderr(err), ''.encode()) + self.assertEqual(proc.returncode, 0) + return out.decode('utf8').splitlines() + + def test_stdin_stdout(self): + self.assertEqual( + self.runTool(data=self.data.encode()), + self.expect.splitlines()) + + def test_infile_stdout(self): + infile, infile_name = open_temp_file() + try: + infile.write(self.data.encode()) + infile.close() + self.assertEqual( + self.runTool(args=[infile_name]), + self.expect.splitlines()) + finally: + os.unlink(infile_name) + + def test_infile_outfile(self): + infile, infile_name = open_temp_file() + try: + infile.write(self.data.encode()) + infile.close() + # outfile will get overwritten by tool, so the delete + # may not work on some platforms. Do it manually. 
+ outfile, outfile_name = open_temp_file() + try: + outfile.close() + self.assertEqual( + self.runTool(args=[infile_name, outfile_name]), + []) + with open(outfile_name, 'rb') as f: + self.assertEqual( + f.read().decode('utf8').splitlines(), + self.expect.splitlines() + ) + finally: + os.unlink(outfile_name) + finally: + os.unlink(infile_name) diff --git a/lib/simplejson/tests/test_tuple.py b/lib/simplejson/tests/test_tuple.py new file mode 100644 index 0000000..94ea139 --- /dev/null +++ b/lib/simplejson/tests/test_tuple.py @@ -0,0 +1,47 @@ +import unittest + +from simplejson.compat import StringIO +import simplejson as json + +class TestTuples(unittest.TestCase): + def test_tuple_array_dumps(self): + t = (1, 2, 3) + expect = json.dumps(list(t)) + # Default is True + self.assertEqual(expect, json.dumps(t)) + self.assertEqual(expect, json.dumps(t, tuple_as_array=True)) + self.assertRaises(TypeError, json.dumps, t, tuple_as_array=False) + # Ensure that the "default" does not get called + self.assertEqual(expect, json.dumps(t, default=repr)) + self.assertEqual(expect, json.dumps(t, tuple_as_array=True, + default=repr)) + # Ensure that the "default" gets called + self.assertEqual( + json.dumps(repr(t)), + json.dumps(t, tuple_as_array=False, default=repr)) + + def test_tuple_array_dump(self): + t = (1, 2, 3) + expect = json.dumps(list(t)) + # Default is True + sio = StringIO() + json.dump(t, sio) + self.assertEqual(expect, sio.getvalue()) + sio = StringIO() + json.dump(t, sio, tuple_as_array=True) + self.assertEqual(expect, sio.getvalue()) + self.assertRaises(TypeError, json.dump, t, StringIO(), + tuple_as_array=False) + # Ensure that the "default" does not get called + sio = StringIO() + json.dump(t, sio, default=repr) + self.assertEqual(expect, sio.getvalue()) + sio = StringIO() + json.dump(t, sio, tuple_as_array=True, default=repr) + self.assertEqual(expect, sio.getvalue()) + # Ensure that the "default" gets called + sio = StringIO() + json.dump(t, sio, tuple_as_array=False, default=repr) + self.assertEqual( + json.dumps(repr(t)), + sio.getvalue()) diff --git a/lib/simplejson/tests/test_unicode.py b/lib/simplejson/tests/test_unicode.py new file mode 100644 index 0000000..823aa9d --- /dev/null +++ b/lib/simplejson/tests/test_unicode.py @@ -0,0 +1,154 @@ +import sys +import codecs +from unittest import TestCase + +import simplejson as json +from simplejson.compat import unichr, text_type, b, BytesIO + +class TestUnicode(TestCase): + def test_encoding1(self): + encoder = json.JSONEncoder(encoding='utf-8') + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + s = u.encode('utf-8') + ju = encoder.encode(u) + js = encoder.encode(s) + self.assertEqual(ju, js) + + def test_encoding2(self): + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + s = u.encode('utf-8') + ju = json.dumps(u, encoding='utf-8') + js = json.dumps(s, encoding='utf-8') + self.assertEqual(ju, js) + + def test_encoding3(self): + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + j = json.dumps(u) + self.assertEqual(j, '"\\u03b1\\u03a9"') + + def test_encoding4(self): + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + j = json.dumps([u]) + self.assertEqual(j, '["\\u03b1\\u03a9"]') + + def test_encoding5(self): + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + j = json.dumps(u, ensure_ascii=False) + self.assertEqual(j, u'"' + u + u'"') + + def test_encoding6(self): + u = u'\N{GREEK SMALL LETTER ALPHA}\N{GREEK CAPITAL LETTER OMEGA}' + j = 
json.dumps([u], ensure_ascii=False) + self.assertEqual(j, u'["' + u + u'"]') + + def test_big_unicode_encode(self): + u = u'\U0001d120' + self.assertEqual(json.dumps(u), '"\\ud834\\udd20"') + self.assertEqual(json.dumps(u, ensure_ascii=False), u'"\U0001d120"') + + def test_big_unicode_decode(self): + u = u'z\U0001d120x' + self.assertEqual(json.loads('"' + u + '"'), u) + self.assertEqual(json.loads('"z\\ud834\\udd20x"'), u) + + def test_unicode_decode(self): + for i in range(0, 0xd7ff): + u = unichr(i) + #s = '"\\u{0:04x}"'.format(i) + s = '"\\u%04x"' % (i,) + self.assertEqual(json.loads(s), u) + + def test_object_pairs_hook_with_unicode(self): + s = u'{"xkd":1, "kcw":2, "art":3, "hxm":4, "qrt":5, "pad":6, "hoy":7}' + p = [(u"xkd", 1), (u"kcw", 2), (u"art", 3), (u"hxm", 4), + (u"qrt", 5), (u"pad", 6), (u"hoy", 7)] + self.assertEqual(json.loads(s), eval(s)) + self.assertEqual(json.loads(s, object_pairs_hook=lambda x: x), p) + od = json.loads(s, object_pairs_hook=json.OrderedDict) + self.assertEqual(od, json.OrderedDict(p)) + self.assertEqual(type(od), json.OrderedDict) + # the object_pairs_hook takes priority over the object_hook + self.assertEqual(json.loads(s, + object_pairs_hook=json.OrderedDict, + object_hook=lambda x: None), + json.OrderedDict(p)) + + + def test_default_encoding(self): + self.assertEqual(json.loads(u'{"a": "\xe9"}'.encode('utf-8')), + {'a': u'\xe9'}) + + def test_unicode_preservation(self): + self.assertEqual(type(json.loads(u'""')), text_type) + self.assertEqual(type(json.loads(u'"a"')), text_type) + self.assertEqual(type(json.loads(u'["a"]')[0]), text_type) + + def test_ensure_ascii_false_returns_unicode(self): + # http://code.google.com/p/simplejson/issues/detail?id=48 + self.assertEqual(type(json.dumps([], ensure_ascii=False)), text_type) + self.assertEqual(type(json.dumps(0, ensure_ascii=False)), text_type) + self.assertEqual(type(json.dumps({}, ensure_ascii=False)), text_type) + self.assertEqual(type(json.dumps("", ensure_ascii=False)), text_type) + + def test_ensure_ascii_false_bytestring_encoding(self): + # http://code.google.com/p/simplejson/issues/detail?id=48 + doc1 = {u'quux': b('Arr\xc3\xaat sur images')} + doc2 = {u'quux': u'Arr\xeat sur images'} + doc_ascii = '{"quux": "Arr\\u00eat sur images"}' + doc_unicode = u'{"quux": "Arr\xeat sur images"}' + self.assertEqual(json.dumps(doc1), doc_ascii) + self.assertEqual(json.dumps(doc2), doc_ascii) + self.assertEqual(json.dumps(doc1, ensure_ascii=False), doc_unicode) + self.assertEqual(json.dumps(doc2, ensure_ascii=False), doc_unicode) + + def test_ensure_ascii_linebreak_encoding(self): + # http://timelessrepo.com/json-isnt-a-javascript-subset + s1 = u'\u2029\u2028' + s2 = s1.encode('utf8') + expect = '"\\u2029\\u2028"' + expect_non_ascii = u'"\u2029\u2028"' + self.assertEqual(json.dumps(s1), expect) + self.assertEqual(json.dumps(s2), expect) + self.assertEqual(json.dumps(s1, ensure_ascii=False), expect_non_ascii) + self.assertEqual(json.dumps(s2, ensure_ascii=False), expect_non_ascii) + + def test_invalid_escape_sequences(self): + # incomplete escape sequence + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u12') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u123') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1234') + # invalid escape sequence + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u123x"') + self.assertRaises(json.JSONDecodeError, 
json.loads, '"\\u12x4"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\u1x34"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ux234"') + if sys.maxunicode > 65535: + # invalid escape sequence for low surrogate + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u0"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u00"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u000"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u000x"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u00x0"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\u0x00"') + self.assertRaises(json.JSONDecodeError, json.loads, '"\\ud800\\ux000"') + + def test_ensure_ascii_still_works(self): + # in the ascii range, ensure that everything is the same + for c in map(unichr, range(0, 127)): + self.assertEqual( + json.dumps(c, ensure_ascii=False), + json.dumps(c)) + snowman = u'\N{SNOWMAN}' + self.assertEqual( + json.dumps(c, ensure_ascii=False), + '"' + c + '"') + + def test_strip_bom(self): + content = u"\u3053\u3093\u306b\u3061\u308f" + json_doc = codecs.BOM_UTF8 + b(json.dumps(content)) + self.assertEqual(json.load(BytesIO(json_doc)), content) + for doc in json_doc, json_doc.decode('utf8'): + self.assertEqual(json.loads(doc), content) diff --git a/lib/simplejson/tool.py b/lib/simplejson/tool.py new file mode 100644 index 0000000..c91a01d --- /dev/null +++ b/lib/simplejson/tool.py @@ -0,0 +1,42 @@ +r"""Command-line tool to validate and pretty-print JSON + +Usage:: + + $ echo '{"json":"obj"}' | python -m simplejson.tool + { + "json": "obj" + } + $ echo '{ 1.2:3.4}' | python -m simplejson.tool + Expecting property name: line 1 column 2 (char 2) + +""" +from __future__ import with_statement +import sys +import simplejson as json + +def main(): + if len(sys.argv) == 1: + infile = sys.stdin + outfile = sys.stdout + elif len(sys.argv) == 2: + infile = open(sys.argv[1], 'r') + outfile = sys.stdout + elif len(sys.argv) == 3: + infile = open(sys.argv[1], 'r') + outfile = open(sys.argv[2], 'w') + else: + raise SystemExit(sys.argv[0] + " [infile [outfile]]") + with infile: + try: + obj = json.load(infile, + object_pairs_hook=json.OrderedDict, + use_decimal=True) + except ValueError: + raise SystemExit(sys.exc_info()[1]) + with outfile: + json.dump(obj, outfile, sort_keys=True, indent=' ', use_decimal=True) + outfile.write('\n') + + +if __name__ == '__main__': + main() From 34e2d36c2fb5076d37133d716a0bc7a5432c0a2a Mon Sep 17 00:00:00 2001 From: Isi <86603298+Isi-dev@users.noreply.github.com> Date: Thu, 1 Aug 2024 21:36:15 +0100 Subject: [PATCH 3/5] Add files via upload --- test_func/save_targer_keys.py | 108 +++++ test_func/test_EndDec.py | 95 ++++ test_func/test_dataset.py | 152 +++++++ test_func/test_models.py | 56 +++ test_func/test_save_video.py | 24 + utils/__init__.py | 0 utils/__pycache__/__init__.cpython-310.pyc | Bin 0 -> 190 bytes utils/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 134 bytes utils/__pycache__/assign_cfg.cpython-310.pyc | Bin 0 -> 1443 bytes utils/__pycache__/assign_cfg.cpython-39.pyc | Bin 0 -> 1661 bytes utils/__pycache__/config.cpython-310.pyc | Bin 0 -> 6934 bytes utils/__pycache__/config.cpython-39.pyc | Bin 0 -> 6798 bytes utils/__pycache__/distributed.cpython-310.pyc | Bin 0 -> 13026 bytes utils/__pycache__/distributed.cpython-39.pyc | Bin 0 -> 13630 bytes utils/__pycache__/logging.cpython-310.pyc | Bin 0 -> 
2608 bytes utils/__pycache__/logging.cpython-39.pyc | Bin 0 -> 2522 bytes utils/__pycache__/multi_port.cpython-310.pyc | Bin 0 -> 663 bytes utils/__pycache__/multi_port.cpython-39.pyc | Bin 0 -> 605 bytes utils/__pycache__/registry.cpython-310.pyc | Bin 0 -> 5564 bytes utils/__pycache__/registry.cpython-39.pyc | Bin 0 -> 5584 bytes .../registry_class.cpython-310.pyc | Bin 0 -> 837 bytes .../__pycache__/registry_class.cpython-39.pyc | Bin 0 -> 781 bytes utils/__pycache__/seed.cpython-310.pyc | Bin 0 -> 498 bytes utils/__pycache__/seed.cpython-39.pyc | Bin 0 -> 442 bytes utils/__pycache__/transforms.cpython-310.pyc | Bin 0 -> 15264 bytes utils/__pycache__/transforms.cpython-39.pyc | Bin 0 -> 15998 bytes utils/__pycache__/video_op.cpython-310.pyc | Bin 0 -> 9287 bytes utils/__pycache__/video_op.cpython-39.pyc | Bin 0 -> 9445 bytes utils/assign_cfg.py | 78 ++++ utils/config.py | 243 ++++++++++ utils/distributed.py | 430 ++++++++++++++++++ utils/logging.py | 90 ++++ utils/mp4_to_gif.py | 16 + utils/multi_port.py | 9 + utils/optim/__init__.py | 2 + utils/optim/adafactor.py | 230 ++++++++++ utils/optim/lr_scheduler.py | 58 +++ utils/registry.py | 167 +++++++ utils/registry_class.py | 19 + utils/seed.py | 11 + utils/transforms.py | 353 ++++++++++++++ utils/util.py | 16 + utils/video_op.py | 359 +++++++++++++++ 43 files changed, 2516 insertions(+) create mode 100644 test_func/save_targer_keys.py create mode 100644 test_func/test_EndDec.py create mode 100644 test_func/test_dataset.py create mode 100644 test_func/test_models.py create mode 100644 test_func/test_save_video.py create mode 100644 utils/__init__.py create mode 100644 utils/__pycache__/__init__.cpython-310.pyc create mode 100644 utils/__pycache__/__init__.cpython-39.pyc create mode 100644 utils/__pycache__/assign_cfg.cpython-310.pyc create mode 100644 utils/__pycache__/assign_cfg.cpython-39.pyc create mode 100644 utils/__pycache__/config.cpython-310.pyc create mode 100644 utils/__pycache__/config.cpython-39.pyc create mode 100644 utils/__pycache__/distributed.cpython-310.pyc create mode 100644 utils/__pycache__/distributed.cpython-39.pyc create mode 100644 utils/__pycache__/logging.cpython-310.pyc create mode 100644 utils/__pycache__/logging.cpython-39.pyc create mode 100644 utils/__pycache__/multi_port.cpython-310.pyc create mode 100644 utils/__pycache__/multi_port.cpython-39.pyc create mode 100644 utils/__pycache__/registry.cpython-310.pyc create mode 100644 utils/__pycache__/registry.cpython-39.pyc create mode 100644 utils/__pycache__/registry_class.cpython-310.pyc create mode 100644 utils/__pycache__/registry_class.cpython-39.pyc create mode 100644 utils/__pycache__/seed.cpython-310.pyc create mode 100644 utils/__pycache__/seed.cpython-39.pyc create mode 100644 utils/__pycache__/transforms.cpython-310.pyc create mode 100644 utils/__pycache__/transforms.cpython-39.pyc create mode 100644 utils/__pycache__/video_op.cpython-310.pyc create mode 100644 utils/__pycache__/video_op.cpython-39.pyc create mode 100644 utils/assign_cfg.py create mode 100644 utils/config.py create mode 100644 utils/distributed.py create mode 100644 utils/logging.py create mode 100644 utils/mp4_to_gif.py create mode 100644 utils/multi_port.py create mode 100644 utils/optim/__init__.py create mode 100644 utils/optim/adafactor.py create mode 100644 utils/optim/lr_scheduler.py create mode 100644 utils/registry.py create mode 100644 utils/registry_class.py create mode 100644 utils/seed.py create mode 100644 utils/transforms.py create mode 100644 utils/util.py create 
mode 100644 utils/video_op.py diff --git a/test_func/save_targer_keys.py b/test_func/save_targer_keys.py new file mode 100644 index 0000000..3526861 --- /dev/null +++ b/test_func/save_targer_keys.py @@ -0,0 +1,108 @@ +import os +import sys +import json +import torch +import imageio +import numpy as np +import os.path as osp +sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2])) +from thop import profile +from ptflops import get_model_complexity_info + +import artist.data as data +from tools.modules.config import cfg +from tools.modules.unet.util import * +from utils.config import Config as pConfig +from utils.registry_class import ENGINE, MODEL + + +def save_temporal_key(): + cfg_update = pConfig(load=True) + + for k, v in cfg_update.cfg_dict.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + model = MODEL.build(cfg.UNet) + + temp_name = '' + temp_key_list = [] + spth = 'workspace/module_list/UNetSD_I2V_vs_Text_temporal_key_list.json' + for name, module in model.named_modules(): + if isinstance(module, (TemporalTransformer, TemporalTransformer_attemask, TemporalAttentionBlock, TemporalAttentionMultiBlock, TemporalConvBlock_v2, TemporalConvBlock)): + temp_name = name + print(f'Model: {name}') + elif isinstance(module, (ResidualBlock, ResBlock, SpatialTransformer, Upsample, Downsample)): + temp_name = '' + + if hasattr(module, 'weight'): + if temp_name != '' and (temp_name in name): + temp_key_list.append(name) + print(f'{name}') + # print(name) + + save_module_list = [] + for k, p in model.named_parameters(): + for item in temp_key_list: + if item in k: + print(f'{item} --> {k}') + save_module_list.append(k) + + print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters') + + # spth = 'workspace/module_list/{}' + json.dump(save_module_list, open(spth, 'w')) + a = 0 + + +def save_spatial_key(): + cfg_update = pConfig(load=True) + + for k, v in cfg_update.cfg_dict.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + model = MODEL.build(cfg.UNet) + temp_name = '' + temp_key_list = [] + spth = 'workspace/module_list/UNetSD_I2V_HQ_P_spatial_key_list.json' + for name, module in model.named_modules(): + if isinstance(module, (ResidualBlock, ResBlock, SpatialTransformer, Upsample, Downsample)): + temp_name = name + print(f'Model: {name}') + elif isinstance(module, (TemporalTransformer, TemporalTransformer_attemask, TemporalAttentionBlock, TemporalAttentionMultiBlock, TemporalConvBlock_v2, TemporalConvBlock)): + temp_name = '' + + if hasattr(module, 'weight'): + if temp_name != '' and (temp_name in name): + temp_key_list.append(name) + print(f'{name}') + # print(name) + + save_module_list = [] + for k, p in model.named_parameters(): + for item in temp_key_list: + if item in k: + print(f'{item} --> {k}') + save_module_list.append(k) + + print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters') + + # spth = 'workspace/module_list/{}' + json.dump(save_module_list, open(spth, 'w')) + a = 0 + + +if __name__ == '__main__': + # save_temporal_key() + save_spatial_key() + + + +# print([k for (k, _) in self.input_blocks.named_parameters()]) + + diff --git a/test_func/test_EndDec.py b/test_func/test_EndDec.py new file mode 100644 index 0000000..80461aa --- /dev/null +++ b/test_func/test_EndDec.py @@ -0,0 +1,95 @@ +import os +import sys +import torch +import imageio +import numpy as np +import os.path as osp +sys.path.insert(0, 
'/'.join(osp.realpath(__file__).split('/')[:-2])) +from PIL import Image, ImageDraw, ImageFont + +from einops import rearrange + +from tools import * +import utils.transforms as data +from utils.seed import setup_seed +from tools.modules.config import cfg +from utils.config import Config as pConfig +from utils.registry_class import ENGINE, DATASETS, AUTO_ENCODER + + +def test_enc_dec(gpu=0): + setup_seed(0) + cfg_update = pConfig(load=True) + + for k, v in cfg_update.cfg_dict.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + save_dir = os.path.join('workspace/test_data/autoencoder', cfg.auto_encoder['type']) + os.system('rm -rf %s' % (save_dir)) + os.makedirs(save_dir, exist_ok=True) + + train_trans = data.Compose([ + data.CenterCropWide(size=cfg.resolution), + data.ToTensor(), + data.Normalize(mean=cfg.mean, std=cfg.std)]) + + vit_trans = data.Compose([ + data.CenterCropWide(size=(cfg.resolution[0], cfg.resolution[0])) if cfg.resolution[0]>cfg.vit_resolution[0] else data.CenterCropWide(size=cfg.vit_resolution), + data.Resize(cfg.vit_resolution), + data.ToTensor(), + data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)]) + + video_mean = torch.tensor(cfg.mean).view(1, -1, 1, 1) #n c f h w + video_std = torch.tensor(cfg.std).view(1, -1, 1, 1) #n c f h w + + txt_size = cfg.resolution[1] + nc = int(38 * (txt_size / 256)) + font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=13) + + dataset = DATASETS.build(cfg.vid_dataset, sample_fps=4, transforms=train_trans, vit_transforms=vit_trans) + print('There are %d videos' % (len(dataset))) + + autoencoder = AUTO_ENCODER.build(cfg.auto_encoder) + autoencoder.eval() # freeze + for param in autoencoder.parameters(): + param.requires_grad = False + autoencoder.to(gpu) + for idx, item in enumerate(dataset): + local_path = os.path.join(save_dir, '%04d.mp4' % idx) + # ref_frame, video_data, caption = item + ref_frame, vit_frame, video_data = item[:3] + video_data = video_data.to(gpu) + + image_list = [] + video_data_list = torch.chunk(video_data, video_data.shape[0]//cfg.chunk_size,dim=0) + with torch.no_grad(): + decode_data = [] + for chunk_data in video_data_list: + latent_z = autoencoder.encode_firsr_stage(chunk_data).detach() + # latent_z = get_first_stage_encoding(encoder_posterior).detach() + kwargs = {"timesteps": chunk_data.shape[0]} + recons_data = autoencoder.decode(latent_z, **kwargs) + + vis_data = torch.cat([chunk_data, recons_data], dim=2).cpu() + vis_data = vis_data.mul_(video_std).add_(video_mean) # 8x3x16x256x384 + vis_data = vis_data.cpu() + vis_data.clamp_(0, 1) + vis_data = vis_data.permute(0, 2, 3, 1) + vis_data = [(image.numpy() * 255).astype('uint8') for image in vis_data] + image_list.extend(vis_data) + + num_image = len(image_list) + frame_dir = os.path.join(save_dir, 'temp') + os.makedirs(frame_dir, exist_ok=True) + for idx in range(num_image): + tpth = os.path.join(frame_dir, '%04d.png' % (idx+1)) + cv2.imwrite(tpth, image_list[idx][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate 8 -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}' + os.system(cmd); os.system(f'rm -rf {frame_dir}') + + +if __name__ == '__main__': + test_enc_dec() diff --git a/test_func/test_dataset.py b/test_func/test_dataset.py new file mode 100644 index 0000000..6f860a3 --- /dev/null +++ b/test_func/test_dataset.py @@ -0,0 +1,152 @@ +import os +import sys +import imageio +import numpy as np +import os.path as osp 
+sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2])) +from PIL import Image, ImageDraw, ImageFont +import torchvision.transforms as T + +import utils.transforms as data +from tools.modules.config import cfg +from utils.config import Config as pConfig +from utils.registry_class import ENGINE, DATASETS + +from tools import * + +def test_video_dataset(): + cfg_update = pConfig(load=True) + + for k, v in cfg_update.cfg_dict.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + exp_name = os.path.basename(cfg.cfg_file).split('.')[0] + save_dir = os.path.join('workspace', 'test_data/datasets', cfg.vid_dataset['type'], exp_name) + os.system('rm -rf %s' % (save_dir)) + os.makedirs(save_dir, exist_ok=True) + + train_trans = data.Compose([ + data.CenterCropWide(size=cfg.resolution), + data.ToTensor(), + data.Normalize(mean=cfg.mean, std=cfg.std)]) + vit_trans = T.Compose([ + data.CenterCropWide(cfg.vit_resolution), + T.ToTensor(), + T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)]) + + video_mean = torch.tensor(cfg.mean).view(1, -1, 1, 1) #n c f h w + video_std = torch.tensor(cfg.std).view(1, -1, 1, 1) #n c f h w + + img_mean = torch.tensor(cfg.mean).view(-1, 1, 1) # c f h w + img_std = torch.tensor(cfg.std).view(-1, 1, 1) # c f h w + + vit_mean = torch.tensor(cfg.vit_mean).view(-1, 1, 1) # c f h w + vit_std = torch.tensor(cfg.vit_std).view(-1, 1, 1) # c f h w + + txt_size = cfg.resolution[1] + nc = int(38 * (txt_size / 256)) + font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=13) + + dataset = DATASETS.build(cfg.vid_dataset, sample_fps=cfg.sample_fps[0], transforms=train_trans, vit_transforms=vit_trans) + print('There are %d videos' % (len(dataset))) + for idx, item in enumerate(dataset): + ref_frame, vit_frame, video_data, caption, video_key = item + + video_data = video_data.mul_(video_std).add_(video_mean) + video_data.clamp_(0, 1) + video_data = video_data.permute(0, 2, 3, 1) + video_data = [(image.numpy() * 255).astype('uint8') for image in video_data] + + # Single Image + ref_frame = ref_frame.mul_(img_mean).add_(img_std) + ref_frame.clamp_(0, 1) + ref_frame = ref_frame.permute(1, 2, 0) + ref_frame = (ref_frame.numpy() * 255).astype('uint8') + + # Text image + txt_img = Image.new("RGB", (txt_size, txt_size), color="white") + draw = ImageDraw.Draw(txt_img) + lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc)) + draw.text((0, 0), lines, fill="black", font=font) + txt_img = np.array(txt_img) + + video_data = [np.concatenate([ref_frame, u, txt_img], axis=1) for u in video_data] + spath = os.path.join(save_dir, '%04d.gif' % (idx)) + imageio.mimwrite(spath, video_data, fps =8) + + # if idx > 100: break + + +def test_vit_image(test_video_flag=True): + cfg_update = pConfig(load=True) + + for k, v in cfg_update.cfg_dict.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + exp_name = os.path.basename(cfg.cfg_file).split('.')[0] + save_dir = os.path.join('workspace', 'test_data/datasets', cfg.img_dataset['type'], exp_name) + os.system('rm -rf %s' % (save_dir)) + os.makedirs(save_dir, exist_ok=True) + + train_trans = data.Compose([ + data.CenterCropWide(size=cfg.resolution), + data.ToTensor(), + data.Normalize(mean=cfg.mean, std=cfg.std)]) + vit_trans = data.Compose([ + data.CenterCropWide(cfg.resolution), + data.Resize(cfg.vit_resolution), + data.ToTensor(), + data.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)]) + + img_mean = torch.tensor(cfg.mean).view(-1, 1, 1) # c f 
h w + img_std = torch.tensor(cfg.std).view(-1, 1, 1) # c f h w + + vit_mean = torch.tensor(cfg.vit_mean).view(-1, 1, 1) # c f h w + vit_std = torch.tensor(cfg.vit_std).view(-1, 1, 1) # c f h w + + txt_size = cfg.resolution[1] + nc = int(38 * (txt_size / 256)) + font = ImageFont.truetype('artist/font/DejaVuSans.ttf', size=13) + + dataset = DATASETS.build(cfg.img_dataset, transforms=train_trans, vit_transforms=vit_trans) + print('There are %d videos' % (len(dataset))) + for idx, item in enumerate(dataset): + ref_frame, vit_frame, video_data, caption, video_key = item + video_data = video_data.mul_(img_std).add_(img_mean) + video_data.clamp_(0, 1) + video_data = video_data.permute(0, 2, 3, 1) + video_data = [(image.numpy() * 255).astype('uint8') for image in video_data] + + # Single Image + vit_frame = vit_frame.mul_(vit_std).add_(vit_mean) + vit_frame.clamp_(0, 1) + vit_frame = vit_frame.permute(1, 2, 0) + vit_frame = (vit_frame.numpy() * 255).astype('uint8') + + zero_frame = np.zeros((cfg.resolution[1], cfg.resolution[1], 3), dtype=np.uint8) + zero_frame[:vit_frame.shape[0], :vit_frame.shape[1], :] = vit_frame + + # Text image + txt_img = Image.new("RGB", (txt_size, txt_size), color="white") + draw = ImageDraw.Draw(txt_img) + lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc)) + draw.text((0, 0), lines, fill="black", font=font) + txt_img = np.array(txt_img) + + video_data = [np.concatenate([zero_frame, u, txt_img], axis=1) for u in video_data] + spath = os.path.join(save_dir, '%04d.gif' % (idx)) + imageio.mimwrite(spath, video_data, fps =8) + + # if idx > 100: break + + +if __name__ == '__main__': + # test_video_dataset() + test_vit_image() + diff --git a/test_func/test_models.py b/test_func/test_models.py new file mode 100644 index 0000000..35cb708 --- /dev/null +++ b/test_func/test_models.py @@ -0,0 +1,56 @@ +import os +import sys +import torch +import imageio +import numpy as np +import os.path as osp +sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-2])) +from thop import profile +from ptflops import get_model_complexity_info + +import artist.data as data +from tools.modules.config import cfg +from utils.config import Config as pConfig +from utils.registry_class import ENGINE, MODEL + + +def test_model(): + cfg_update = pConfig(load=True) + + for k, v in cfg_update.cfg_dict.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + model = MODEL.build(cfg.UNet) + print(int(sum(p.numel() for k, p in model.named_parameters()) / (1024 ** 2)), 'M parameters') + + # state_dict = torch.load('cache/pretrain_model/jiuniu_0600000.pth', map_location='cpu') + # model.load_state_dict(state_dict, strict=False) + model = model.cuda() + + x = torch.Tensor(1, 4, 16, 32, 56).cuda() + t = torch.Tensor(1).cuda() + sims = torch.Tensor(1, 32).cuda() + fps = torch.Tensor([8]).cuda() + y = torch.Tensor(1, 1, 1024).cuda() + image = torch.Tensor(1, 3, 256, 448).cuda() + + ret = model(x=x, t=t, y=y, ori_img=image, sims=sims, fps=fps) + print('Out shape if {}'.format(ret.shape)) + + # flops, params = profile(model=model, inputs=(x, t, y, image, sims, fps)) + # print('Model: {:.2f} GFLOPs and {:.2f}M parameters'.format(flops/1e9, params/1e6)) + + def prepare_input(resolution): + return dict(x=[x, t, y, image, sims, fps]) + + flops, params = get_model_complexity_info(model, (1, 4, 16, 32, 56), + input_constructor = prepare_input, + as_strings=True, print_per_layer_stat=True) + print(' - Flops: ' + flops) + print(' - Params: ' + params) 
+ +if __name__ == '__main__': + test_model() diff --git a/test_func/test_save_video.py b/test_func/test_save_video.py new file mode 100644 index 0000000..5122a8d --- /dev/null +++ b/test_func/test_save_video.py @@ -0,0 +1,24 @@ +import numpy as np +import cv2 + +cap = cv2.VideoCapture('workspace/img_dir/tst.mp4') + +fourcc = cv2.VideoWriter_fourcc(*'H264') + +ret, frame = cap.read() +vid_size = frame.shape[:2][::-1] + +out = cv2.VideoWriter('workspace/img_dir/testwrite.mp4',fourcc, 8, vid_size) +out.write(frame) + +while(cap.isOpened()): + ret, frame = cap.read() + if not ret: break + out.write(frame) + + +cap.release() +out.release() + + + diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/__pycache__/__init__.cpython-310.pyc b/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52a43aa6b08ead9932948ff62645c984647db9be GIT binary patch literal 190 zcmd1j<>g`kf|lOtX(0MBh(HF6K#l_t7qb9~6oz01O-8?!3`HPe1o11;*(xTqIJKxa zCbKv*D?i3LKR2y1)HA+3GcP5-yg0rfzo;ZJDJK;s5tCe6T#}y~pO>GKS_~7656#PT v%*)J8EJ=+iEy>I&j){-Y%*!l^kJl@xyv1RYo1apelWGUDx|j(_urL4s^l&qM literal 0 HcmV?d00001 diff --git a/utils/__pycache__/__init__.cpython-39.pyc b/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0d4a3ec81a75a432c5c101a28d991935add1188 GIT binary patch literal 134 zcmYe~<>g`kf|lOtX(0MBh(HF6K#l_t7qb9~6oz01O-8?!3`HPe1o2DT*(xTqIJKxa zCbKv*D?cVQFVitEGdHm$HKw#AGp9HvK0Y%qvm`!Vub}c4hfQvNN@-529mtT+K+FID DCIlWg literal 0 HcmV?d00001 diff --git a/utils/__pycache__/assign_cfg.cpython-310.pyc b/utils/__pycache__/assign_cfg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60b8b2294e90a30101b6f27b4ed198265812eec3 GIT binary patch literal 1443 zcmb7EOK%e~5VpPE%_dDs%A=(~5Gto|Koky0s6wrZgpg1#RfPya(GqW9L-v8~q!eu~ zlxxM23y1XBzvL??&Ky99O3dtTL~=j^NB+j1nRx7(Z#J&gDg?%dqffnuWkSB-W_B>x zY{FC<0D=fQAgx@|7KJ???X{*muYhoW0SUG|n zu`y5iq;yOkkmsvyHnzu(c2h@}51G83+9??kkTkVd$)u7x$CQ#cjnqNS2CSaI8#v+Q z(4r)A*NNKN2M_!`g_*@($YS65&)7*3J4xAG>}L=gG!XmQY;3p>>#GW!)Hsi0Pxx}DS(R8Y!{fw{OgWK3EIZxaOT!PmaA5cV+zgjKU}e|q`T5_B z9cVvPrV^+iQre5UzTqP1YU2zO;c4Hvqk#zVpk~QfAS6(Q$g>!V_RIoWI)y8JZRqHP z$cte7+18=D-cU-2ia8OBdOBT4!fzomLq8 zb>(!iJiD))y5;(Jw=6G#`(ODeIas9)P~=%6Q?oQk>>{~E1XA=u4wMveue5`rC&IrCf=GyK%c{9Pv&6dUIITx@3*()+mWmUfgc5Zzs`9Vcl@S_-9wQY0?q00NFxg^F5<3+g4R5K&OEvQBZZov^!!C~_{O zM=EY0aY~N;1N||3s>F%EzyUF1yU?gYDx6u%Z}!VG>wWXat5iw^+QH%H_LCwZUvY4@ zF&J#XtJVPo5i}%?%+dyhF$-B_HLQp=7$w3I{D3sLu!RHG76svgbp)-GTJalrAlo%& zT;cn@R=2lrIAV+9Lf#5sU3fR)RaIEgBq1X@q`769CZxjz9kGuuhnBWFJh8_1gh0;` z?2rw4!pDUP*&;6PlRWlG!lq&$Lu{CV*vIE$|2P{PPRQD@#Br+K-ee z1uBS@_TrXrxCmO>IQ^dRv~S!&C?a%JD;NumL~`UDhN5j##+BWCD}8O~;E2dYSp0r> zbFI0fe5sn?y3=j$#KGe@h`?pDuY*uEJ*9%Z7{9*Se4GxD!Vh)mrJnYe;)l!9!U;YmS>Gb)up6Lr~yP`NswiP8H5VLEW#y(Du5{r zRGMG4f}F=m%m}%Ra2a6%;R?c41Wbtxu0hW%=WM6Bg^qp`ejC;=Lilyq*RlXk+f?EB z0<(F+ur%$kv4!$kA`>b z@)EKIGKLwD@zyyRf0&Z7RVUfiS2Dq}t}?vm=~u?_-v&ynjEoENaC-fC-wj*GUgT#~ zq%8cAfVrFilwc;b4RD~4NfcTYo?k+vausGF&J0u&bS~fvzK?=yoYl0&q-Kju_4P$v)yATsYID)# zIyY)o-Cner_PWMRZavVr<)t=_MTfHmtzd6qHLbX%qbx9)sx_9EtNwDCpo2dHoeTJa zk5G6T?TFn@aD$r<81}Qc{XoNhDej=Qc$#NW+kAv)QK$GxKE}s^>F{Uy1fN7J%@6P? 
zjAi&iehBplKg_35XZaC+6!j>7iXTHg#-HZLQIGRycpi0*>!ACTwC5XsrMi5xSSHT& zekn98eDnB%HVV#SaHGXUn(MJ~So0Z28EdW+LmhDvBk}f z4*XQi#MM}9I~|&x<`&8f`XjNnng+f-&<1`x8{4to(Ko0)iWWnQ))_;8ocavXBMB6x zQLZdYHY>I0>9RUhFn7zZUi7Q=Qs}u=zv7A6^-{fdvyhXvTl3bun$#PCG@GSxMW%RF z_@%lhQ>EL1f;g^Qsn$H#mF8-r>dWzD)dW~;l;w!qEQ!E#OJW(TsPcgCc^rFH!)mEk zjY?s);meF$Ym~T4d&?=e?ulj31+8vHH0ti{QsBu6x7FlCr2?sy9Kzt$awxOCLFyG! z(hR&>MVhpxI0V_qG@&NvnJBk}@ca-ng;x$6Vtp|OB8MPSElvAA)Wp+R;0t&D?9#2k z6Twn7sIE4a<{S0O`mL+(y{gX}_X4-s5Mk+d&Fc*;m0Ljw+WiLif*#P_VfIohtk#01 zgkQ7GbqQ7yPh58d+XY8a9A#5_mgShKXPAu^{n27$EE-Q#a<(_8%z)qfk%ig#lx7Mv zLT#0C#`RCx28*GKWJ{z~s|L3}HJPShuqU1lO`hpPU{QROtf>#lDagbCWD>ST>sx#( z!D`zb(zS!d=*&g2tb{M}m4?VCCv6GXd_G+9(B~d87`|WxEGmm?lQbc+W@bo9l>kV$ zqT_Ev_wr$*JBv)QC6w0WX}n-a6Z1o9u6VU3bkwVqTD345otaVmtb5@~gGaAi9a6Y_ z0{(iw>X&OR%+KGPe>)E{gB7R(S1Zp1>p|$%qp7gmJUcfBbuQIb8bNsW75oZ@cowV` zr%*^UTyJ_&YGwvUYAwqQBriNS6fN&cG=04lHd|p{4PiU*0>*}uf^fT3rj}YnEf*b~ zndxfOOc0iQUaB>GZ{~KZTI11ouli7Lm_U%PRunfBJYNm+rTh+dK7kU?-%XTLnvG@{ zL{rzn_Z>kfB7&iFSU%Ym?8iZ^WGBZU#a120@iy79kl7OFWX%dD#_9A%5DJ;Tarv!_ zSBsY}-;fzKQ<;207t@4dCyJm>0)m4P~;2-j6&ueabcv{NT0WE!idQeP_df*EG0(So@B)q2D!QigCur`VrD+ zvna^nf5GOqw((PED45CN(&P%=S`}=dz$#gD*E!+c6biY>*MJc?C`{(SlCvzD*dZCq z5YCe+hTVd`kj45Y(VgBfLh|>6&yASP1ySESk?6K;q_(>+pmyW%1LV_+VgQ1SS;Xb{EXJ;cX;EaZqvViw*8 zzGzeb?7#*KtyMc_?NmoY%i-pyh-dF2ZkhOR8t_qw#V=~XbN3L<=>Hx*f`v&WO^2Bd zYmab5S8gBCcL$2!aR1)M7#E4ecW+C1-c^;001ou0sHmq@JrrY_v^?R!5*P zAX;-!w7vtuo*rbsKOfqp3T-F?1Zo{oyL5XH?;TxSqkY@y-FQ5O6pB&GrRX9G4(?3> zjSs7sqN6%U2X|t#o#|+6Onk!A1MSRCTP0%*i5yN9n+qE6jcMY~F~TX7aRgr$H2*&7 zw-Ht;?Mu?N?x~P04~|Z2{kOVl`)nQb0x%$iB2%eF`xzIl~?l(%4s`aCJNeg(~SVHq1WH ziN;O9(?c}!1hOxU3O^a`Q;QY$rIyCN617OLyZi)ibKrAd3kta;V2HD@7x6tR$We%I zQbEaDq|eSRrrp;rE?jnR+$t`PB+a+3U%LF}qNO@pgeNGI1M=)mBub-A;UNgGc`79? zOo$VN@hTM;s31ooE>b~ZMtL1+!|#{xaX}$gri9n5mCBwp?|AEhJia^AE;s6RWF6gX zT3Y2s%MZo3X|V%TOwl5T_nB32lxs?YN?P3u=(Zp((@b(!$P9$H=Bh;(M=2vGpQ+jDD-|LhW>SL37uGPYaI;%__-xmKCIb9 z_+s!CnmB=gY=sCv;u!P-@YsYcTOFNgNGmMrBTSnJh~^p#o{J6f!VZ*zcR?>olT@Q% zNvnz+G!WOYrOe>&;ZvqZ5~Qn+5=g;L_)aWbOc7gs@1CIOE!iE()^@X%g!mc0fI?Fa z$IpN%89l2zO#cw9c&Zl*_PB-(#RAHR`WKBDosj%A+r(u9-Y%`ea}S~)LaK$xaQqDf z<~DrvD(hi)kjNl%0D?=3#`TYlBN~!1>YM$(86t+n=B7@m8aQsjx8d3wTj1nNv{S(< zD#H?#?tO9U`Wj+$m3w*0%km#=KWYSqSO1Z zci8rZ7)=4U&*%p@V2^d2fvq~&dhm&CRj)O>*veF2v6b@3&)NG2du(mvc8DO8!Y(NY z)4s!k*RX!7ttXHUL=1tYIE5&^ZIRLk3&3z7SOrZmy7NgH>VT6GbjRUnlUcRyEIXOu z$K5$;%=zU$bAB>B2Uu)ynJmrsC}Y9=Hd*Q+wknd~j?4-&XjqJ{e!l z!^vo#GPq1p*;-j zf8C?IF>uHvql3kUKciCQk{{+_WXe#tCN02b4e1Q_chJXK(^P)hc*o zi9MIJO_(M<1k!`!z|)^q99oiGB>9sCTJjAP&@W63SxXb+G17E|CfwHWVq{e4j+KlW z7(EfI4JkS07IgJQi13CjfgLl##Vg|#+)-2}LN9NEUa}I!aM6l%Z>C_0(*(Rp1)WUH zppaGtp*O^f0ZB-jUd^lT5{|U+VAm27;{jc!vJlBJL@kF!CwmO--b^Vt&Yn|GHo^H> zHSMuGDsBA|twqVkk8w+VKLZa4_I-l0;C8LQgFEJ3_|64QybKgY4PI^Z*>8wj)PD!X zLne4{Owd9yeQw$Fz57jZ{x9eTw2S`XF)cdtL=yXxa6ea!xj}2;QIJl8_fJu3sY-FGZ%EnFJsyD@*y|#8; zWV6C`|K8hr=wUP&+kb0%d!mO?vK93>O3G^_E)A)+ji)obdeQ6RYJdl7i|7cFSCc)c z5p^v0j~*f5`AC=NH0?1W$duLjNsB|AJyvH|T2|1dS#HpSF!#JBwe~!HnnsaIH`o8! z+cm?YTyNLjV8OUqEJB2osukFk!UU-lW#v$*jO*4LyrrJ+v#xu$RjMT;;u&D8$GIf+ zr^g_1g^KC zC+mzmbIyokI7Ze`uAz`l-cByif=wy}6<IE%fz1{##W5&I6>GyVk6p8j~< z{rbJv@2fmL?HhPD?)ND$%VV}!J zA#)#l8pR5g{~i2GJdgh@ zH-YCN<(_JV)!OovVuk26d>mBhJB2UWK*3o88g-bQ<7Q$l7$NAfU?fIhBuuU*MvwKW zWyV&|=626XjJjEO`>anCK`Y1r#-rXoTeHXNChFP5PRzc!Ms*)Gh8pU@6xumzV+{k? 
zx_NkDWkh^zE$S6a5-Pm2gp}ttC@dyZRtv3y5t58u__vBDrlEwB!aSBj+CuS zx6xW&7E;-@u-Z~?wI!S7Sb11BsMZ=n?Fxn9K`pGswQ{4@EyuN1sQjRT&4SzIW<$*f zO(B;>pl4vu;O%lG)Xt#O=H*xfghslbUX#}uFSTzgFVKhw!4$EL zW*clDvzgE4nagte^-=5YNSVGggs(7y9XuxfEoWK_|rQJ zrR=3M7f!rZJbmVp@;7f?!IX1!1~+A|7U{HDp!1QO?ZFqlgu;NRnxDBA%Jluzbw96e zf9g8jxuH03?pq-#9zLYm=+hGY6e^rF3_ERHpIx({X)JCo7~e3~%$r_97QuSo+)I*a z7bPjdPuP+BpW)GX7wq&v**3%~Xw*nFsH#|d%i*9kG)%o!3%`zKA`b+Tt%x>dzZs0}i0bV5z!KcYey%MV~UU?gHFi5|$@ zXY#GsCS$s8{^HOYi=DcguwJHbpa!FQ2S(scD0>_4y2Xr}A3SSB=btf{e!qt@yb{}b zwQTJ7Sq}!d&$zqS;F$$5?*CggoW-iY8(%dcW9Svnc6Tiu8R`R4ELsOBChsxjlD!d9 zK8>T@{}05WJNF=HAt!SfS7s}cU!{sGZb3v~d|GW0Dw|Y_+(iSn6lJdb%jK#_7gip6 z8d{W~$4NB>fpVgDqZWgu`Vz6KBr0#=n%c2xuC!S?0hoyFiUD1p>irkYgX}rh-96;w z$iC?bTFNW5^zb1@j+l86Bhd0-ONT`>=ZyCtgq+khZT zN_&EpAA#uS#)%&OIEIJShhM;LFwxnocNw(6Qg~RuqJ4YWt#CZkhuD)?X}k!7!*n4Y z#DK;pMM^XOuj6=dFR^=m-&kdLncU6cjf}PY@fQ4%DO@)jzKWgLmkmA~H8Ar}iOqco z;uOBz4HG@@r9C~^LijZBf_l@5IXEAAMH)UG-D8l>8T4p+Bo5bo3kK%8;ZuXLEYGAg z(dCG8ob}0ra+73q= z`$9O@HIC#f+_Bc|5%rC`5PD#goxt(Kv}ao!HT)KeVufw7rLixB?IuM*-Pv(jD9onT zLOzGPBcG>&RD~>1LH_6zxX2B9whY+>X?N zt=?;;)ohkSKJZYLQ)zX=ST53H(^SmRA`9EhDtKBkr3NLdbrv~k@>?{M^hwDF{;Ppr zv@}g#E@@S*-Q;5gL{4g2Xn%mVJdZDW3xxqY>A?h%@;ZR}JUe6_Hgn82XV@It$4w7@ z!k+)yHq6wZ?}hO`z-)yhCwSBLj=cJ_6(aiZA+KZ<{aix>x7O$zFuTQTxcRurHjEY` zo>fDhgc){X7#{f&Bm(`3jeG3$O&G7Mi9>A|XAfr3US&}+v81=zOD4Nu7E3k>1p))7 z29Ge3XRxL6;qHe~Tn>qM6gX&B6x@{JM7&al2pWoa3*s z*_`Qvj6Ul=KC}l9327$m0eKXom>@wTCSA?e5f?y9&yo2Y^_aWiU%@&YI|qZ@gPyLl z;aDHe3#rJzufO)S{TLZ8f|E-81QtZuDc$ZX1R zYF;(@;XXAYW#Q5&cuI|}bUlmdKgSo%qaZ=G%{?p!=XDmMILCam`{XvXoxE@pR3?i% zqVg|sx*^>-15HiP^zb8TssL$tPPM85}w2d?Vg6k;%+B;h0Ull%|Hp?3Mu66YHF5{D zvf5Ife)zn*ESj5~2{<;semZZs58ywZgb#l^IvrW$kGI{YJ2~%j_xZyE@USJD=N{sf zuCqOU1MHAZf=xD|SI7$*P!IBcX_K}XKRh`ONKNfAIHIO0OAm+B@SxrpybGRPdNj12 zg|pEtef#{Pd>w+Sub{k0HM(5dYTehS>)G#!+XzY92)l7JTXHGyUnt+IH9O6;_h|lk zKmV$Thpl(lx{noG;i7hx%VC_yrPH_eKA8|`yzkQbc`Ed+&596lr$h!sn>ej5$1Tz3 zEtZ!tNm(_1OI|^(l*P({cB!}C=Wk>9H~6BfC=52b%NFi3jEIYyy<3+#m?HBt-h)Ms5JPJy{Y#2X-O8ALUWt4*1 z@KPgm+E~Y>qU0IHSYzojr#P`BX-m>BM3bn90!Xm^iE zatu;D$#qfLW6L5Gx7zafhiLvApV58lk;siEAUM|~IY;Xt_@gr*_;b@E4Wd6yRHTS& zm2^|e%t`Gu?L_U*_#)!2!Mw~q*4_1xohRr(o+f>#@ac&Zt`AOJVA`X@87b`C$_o$f zrOt5vjiwSCjm?bWaO{D-h~|2)8B{pyf$ro%^-~Y1YrVVXLjv^uwt1)7mKDG~uyqi? 
zot2;s2DffK9NT_tiis(iNjF6&G%4!0QgsGj+pUEhDK1L+)gxqm9l|3@h4~@aNP*@4 zF^0(0faWyq0VFH{>hsePW1KzFXV(%|FqK_tQ396>(Wcr%PoJex_|om${~X(OcDUf BB@qAs literal 0 HcmV?d00001 diff --git a/utils/__pycache__/distributed.cpython-310.pyc b/utils/__pycache__/distributed.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1148ab6af6ccd6b4a4626b6b03392d5c454a36e2 GIT binary patch literal 13026 zcmbtaTaaAGSw6R!IWs$Z(MnqDW@}_yw#Tt$S+*QYactdUCyugCEGa1_a)#0L+12dM zO?}Qt*81#ARv|K=K%hu~qF^=^C{%$!0pU{I67E%bp(vn=!wW^IqWp?pD2lTAzQ1R7 zW>!0~5_aqKboc4gxBvG2e|KwgGH2oM#;va|{L}|6>w7E=|9L2!z!Q3oWhqP9b*p54 z?S|u6w$*S;o^5JUjdUr~%ljo?T4zgH`OTFk@axp`jmgrav~lZGjp@>~Z7o{L`y~;=-LD?FVXxVx{ej;t zpk}`++^|qSpbn@9G3tZrA@wkRAHtXi)g$=!Vf6Q?dJMUP>XWLd9!Ktxp;14eooVYdmMGosArM;0P2pa=a73seNa7*k)Bk?)N%YCQXf(; z;Pf=&#T-oQW;`w*TAM-^!nRrdBKnLnf zts0@|zuak7qgtz3+?Q;xg=JL>qjIfTi)xj6Z6#32)M5~oueEetl|xh{KFhk&Tu!ph z;97Z6w>s@)rwYPIx2~7lx>XIruveYARMEN?=p=hdw<@Yy38N%esn^RGzf%p8bpLm; z5?w|!dDqKNn0J^^HA;4J(8YS|Ql-wHRlZ#& zzE;EPu7%}xOGoUa|Keh`6Gm8evjyVzE6W#}wUfou zP1W-RC(aej*ix?A$oK!4ckys_cXl9&tTjvfk==D-YaP5we!cC?zr%(_S31xKkk=_B zx9ub)EH_w~KG=Wx6f$=>?hacTj~!P(f<{GWUf(rFntc_SY;2tApD&KE{sT-{8tI** za_Jw(6YfJItI9|A5>~Zf*R0t7y8R`t4)oInG#FW0h80edqFjYZdih#KFNT|^z(z;$ zgd9C>&)S}?SvQ_}=I~Y~lI8ZUL-BL?t*79l?a+v_eA!vPIF6?`Y~J?X)2ti6{zvdO zvSR39WiNYr+vZ&un#QAOik%>74Uw~~aP;G>c#_E`I`I9Ac;DHydYkHnAN$3ohv54$5Q1!&=b*=Z4Jkgke!k4>aP$HKoMpE_O7JC!9i(wtf7k?FVfT$JTsz zeuO$C*9LWR(2!sr(#4BCq05A!%S|y#;r3Sn5ZCO_+p(>m6dJk86B>DjMyn~(Cr*Jv zLjnitq|R9pYM0XJI~`>}KWEKZ&7fajrPl4YttcDix)X7F9pG-&S84QX#Tn9lH5=vQ zEPivbALo``eFZooKt7Ois#%{^$dn_@ku+y%pHSU^1U+nI6#Ts48ibr~) zu47NK7;cd)X6mB5YOgx0ZVa^$J6~{iTU{uF4RJ1@ArP;3wH921mNl!161$n$ciCF5 z=_(#s)5BJAn-R}R;Hv>e^+~iavP#KD6_?=~uL z8UVc#=vJtYNhJ_!;;z)%#k@#;Dc@<<7Frt66UciY{+ifHpJDPkzst!GBKnCZAU(^} zRQMRlAoonWqExN9xLG}L?w`RE9ztT2!;I}Yb2P4Y&duSOvZtsv@H=bIxpT-tjlAbU zNppOUXm&)05Fe00qT&A@nb=YGRX~VWtdHC|1j=lS7JNzYRnBG^}p|A2c315sRn1m)suQ?nQGs(mm(F72PH@JE?>@>Ev zF=SKPn2<16STwzgb;kB?Ye>?B9DV6>mxJq&om{Qi?nL7I7-}6v526BojQvs2c33O> zhg_tsw17yWra8Xt=z#^{tTCRT*~gI|5iP?>& zXrheKB+5#>=sM($EL`+;h3~kmBrDrP-_{^Wk(veh0Fw+$d9!_LXq$hOAIn}(Be#Ml zd=N?Ap;%BD@^;>tg)r=baQH~`6blLh+gI^y{L!`GW^&ZM|TxpoasfNhata_v*we zFnf^t%xkZoJ)>Lvl8SPOn`Th_X*?nK(08~JJdj5etE5H$kHSpICMzKwxe|&L7{ZA4 zgow4%by4PewksFV%tH$|^13Pdb%1veIfzsW9>LKl)Age)YGGEpIhF2CMENzV-(oMd zHuN@x*IAnEPQ~mg9XZ`;f`O%(*j1TZeL?i+bsck(eM#JEyOVg;j;hcgome^kI&f0q zt<^%LuS4_ECc|((X~0cG|EUXq@IaO#a6jUXu6k5%K3a-|w!CR}a6ls%(^5+)o zD@>@Ubq7h3Zr7^IRQIaWXorbg51L7)-dbG5KuH$lIEy^u4Su=YtTcjhImy7SYw~TH zO1%;Z3nvq&u1CQ+V9d&5kW4ISxZIaI3kwW~@XdMSdN@(5=)>p7ER7*(4<^!Up61~A zECaWloee=#AwBsz*?};IHnE*@)G8aS@(Xw)pdbepn6qyg|a~!GKRC93a5Y$;dnsjJ4?{()a8rJ3+@BEyQ{9~bp{cAU_c;A z(d{1k@`3zm%I#Y&9L5a8jIq~ynqo70)kvm3oDus!=CANV2M0DyTvDXry3U&nIk2$+stI@8|&zzV3A`)N~ zfd?sp3W!lPqvxI#9vll4e+!@8p15@PZ{emY#`hIK^P+m~hE?+XgU#c&#}w#VmOg7CJUxWaL~N<5(aHbZ%-X9ZPEp z0b}*hz7Q_N0MO!%UFae8wtds?`s-FVizkPuIUF(~=2d8EYGb%l(6!XTSB}NT8f^`7 z`iy;NGh;){3}f41Y^%@M@R-TdxBEOTPPP6LssYoIp({p`UxZ}j}OOq z2V%PwJByt65QmsIO=Tv;F}x${U08`ALWf8a@*9G-nW%LMS{PuY(r&}vYI+4ps*Z4W z2p0Q1+Y5^cN1frzES0oz#D($K*hV<)=a>_CLFeGxc)|h_03xys^e$N_?>b~B*8|Yl zWukm+J%X=|X!qFYAMv*5qq{JTjFwFFSjiDqLdXZKtqEFECs9$CQ5;&kTVR? 
zi?ozsh7n)c#9DpK?JQ;Rjl^2OTsTAmb7dmGkC27Ad||F0V|6ElxqM+R_|x5dM3WBV zFU)m@Xc}4?y+SKuHvL`&0DXU!(_hBQ_0KT*StjB(%N}poJHgEV!I#1`j|?--_>hdg zgfEKICM-!jWHb^5cQflSej?CQxJG}4#ji5?c_w5QP3v0B8CE&UiVc;dPqDbcIrB+ILVhnwM!3c|d;@ipx5+uPL!2|~AjHG(#)bz)pMiVcKRigA>8>1^f9?E7 zjsqywuNT5jn}-Hr;q)gj^rIe#d^O;;!jlE=K9noqPl$*U3$vXfYqW$wu9Ti6-#aQX z5_1GqW`mw&P6#=mCcpbHJmN{@T?sF;ZtU{(v+U(jBqQU~AuFbmbgL7=&+5Bd5<7c_ z9|(oEngBoiZ?<@j2}>h#>7jzp$MM`@6C;<9Nl)*NO~j0tNFrw7dY8(5I4)+Q^F_B8 zXL^cpCLT%MV}uD%LCoeSF|^^?oybAiiYgs2-&RXw%m2l5$48R4XXyUFt64jCiTaoE zc672lIQyX|JBRM{VJ4gC&KEF6mhSQ;OZPaZ!`whQ)cc_)bsvZP6(*Y|#TV&~NwM?; zP3r%(Q}oQY+$p}8*q9VcBVdlWgU@L^;R8qrLA&tc;9&t_fEg_WVFkYL2I2#4>uk{d=i5#howP?yBJAY>FBM4@n2TFTNaGj~^bZ8sXtV3tOXfu?|!ST{1) zzA$kDA$>gKLF*~};xJ2oVbH|jnIMul7zCzb=p1n^1Q|uC4c7v02;GVo*c(xjweG;r z%)s4%|E*HD93mpDJpDcg2U|rj;g{QQ8a5FtPsFm1#5jS>%Ko8a<1JwdT3QAI=Zx?^ z$cl%Uh*XRHImx16m1E3(8cC8H#3VtUUuC^m7E&1qMm)$K4ZhBooLn+7G@biM3mUq8 zaQkRl7~RMNd}w0S5vp*KqzL<#h~f96N{pSvdjJj^dzOqYHbD!$X?+vuQp*HKC9OAA z6RtVHsZvF8s`OIzZ4^eP$mx-JWQs2%b1zc_GTLCA*jfU^5LDkn%nM=Uw6H|)ENKu{ z0{}dL?l-T4bV^dy%bn)3{x#$w6{9rSq)C!=Ic(Q!(W6{i4j&Atpud4)70<{Y7kC2C z9a^`*D#2qYL553ifRHMV$l#KdQpkg!C7(tQkKH|kV7?uw=Sj>g-VtOarfcWGZI>d{*(!NZO6;^ihqW$oC<5H_$ zt3Hio-(h#Yvobd_t3&-JV@KED#CN*Dgg!6EbgR*u`gd5diHR}@x*!0fj$%0dduXt- zb9AuY`+FD*D|u$TB=iL z$S7YfU+Glp{Qy$AtXfr8BN!Uhs*NDJ+*0~?@s<7qCV$Liux`mw*!6EBImCUYkoy*5 zI}Th}#P%J;^yTeXCey!xPlg!gb-BBs`p`-0m-n9lwJMXOa`--33E*+?%Fal z_7B7sf630jwZ#bBurnJY^m(B-!Z=<)d&Bh}GBKjHhWBLhf3(FQB*;g(@8oZ9z~8-L zHfXQ^Wc1?!<^BRMKFn!;hb>qdOR4aE)xkhD$hAGF38)3A13#IdESxh}q*nHAQU*PQ ze$Rx}f5^FxBkOBijR@e5p+SEjqDuc;5tVFnuHXMXNU~@RJ&tmp>G!lUGSdyJLH|=? z^`A2NGbVq|%ZcZ3QU9x$C2@Koc>><*~+&5%*NAjT<2TS zkb-$>qw|@K&d13lMe8)DJFo>J^y8?yz`lk51{3e^&HQ*l{96q1S$_H*j>XbDwACY+ zpRP8s38xdzqj7960#bLab?a;P>lO~E-Bk~9YevdF`hSS{^6b(>d?pj6km7a=U?~nf z6hJA?F0%lu{$9YRHQPi`@Kr8Cd?(KChM6}Q1;IO4-sUh?o>B32XBBOhklz>uaTj13 z?>>BmAmD}7NsKgwesbOEb%5a2sd#E>29UwMWg{Lmy*d-4Jz`19wwf=-6WwhH6wWFq zp6PDq(N1@V!mW;VUfOd3pm)X-YlL`twv*pEwnY%pyv>`pI2G_#x5X3jw%usUw0j>; zaJ_X0hd#M@8(=vu1IIWL!b~!V;ODNKMg2@Xz2H($;RdU|iOn5gFo`bckD%@yGD6Ww z?6#)=UECv~l*H#X)mAg;9kP4%7LLFP(F9Z{I5BK093M6$Zf$%UuNu_q0<8_wzQEQe zkt9x|p+CXWd6pz1Pit)SW*O(}0!AcG3tls5V-E9{qX z1AqrXyl7jTO;QVJ8zy-K9?R7_P9Jbr6D?3eEFD>USb3ua34Uda>kxU>BAq)s_FrOKjD+#&+^W{rFJ#>2Q#lw~h5ZWioCV4RNeF&Pav zjrgU@FXasmD9xL@`@8ZIdN8fUr9^wIg^|#KV=0cFu<>~ulaOJP%yJO4acc(GKj_;f zZt!N?aNFll6fQE+LE*>mvM$AhoT>i-30Me$>8Lo7*q_4J?NI**+6%*!K0kZdz8hwdLI6GH{@%#y5p)oRs46N~){$CLQQM74!;J-PZGnyUv>ro&1nYSDIT za%++)i^iEmKdYwg56ci|I;UU@c|WI~@k$uxO)9i8*O zK$5NNPPtmC;#xRy`A}?`5q}BX>g~+!V6u}5V_%xFCe7H61Wm--GD!LxtRvPJpeVRP zELD=i=|%mz-osaWne1b7KNE_mp`DofL2?i;68{C`i@ZqUg$jv}^I*C5ARz)?8SwYx zmJt5^RL;wJv-rnvHvJIJHFhGjIR?_OWw#PTwTecKCD%f4yZ`&&t?JZN&cg5N_3tcy^W&EFuPluJJQPmi2|dTMl%?#3RW@(C z={S~cHQlmjo0?QJUCs>he%Y7a*>YCixpE$Fr%`B5m8YbS+n8?7lxJ*f#Zul6?DDKi zsq|H=yhpvLGRnVdm*>=HRaWJY+pCsTUKNnrr>4|2dhJ(d)Qp-%&AfU^?NM{c9SAb3 z2Uia%d)ZNY)jrf59;?}}=23G*y{rzXgUBtYSJYv31i4!RtAFdJrOv8b)U8+T4ZD0> z;P;EDxlI+XS}5PHZdZ3;);rXl>Mp$RROi&)>L|XwOTDV@QTHNuxB8qasr!&SIyURa z)cvTrN4=&VP>aajt3IzDRL77jsV}I9)N$nQQ)TtAdIY(TsV}NW)nmxruRg9G$4n2X zC)5eN7u6@!lXyS)yrn*=p1NwOr-GRcw|orcr`1W6PfGbAl%G-0qWr9skE47_oksbz zlphX`tv*tIGjrw|^lIfKos=U$G4OIzI zk@zg@YHKaYwt_b*E4tn7B>PnmM!J2a($Vc&5Qc;5%!R7f^*|@t3%XrZwQ3k8xoV?P z!TjA?kfew2m1=Ym-K1@h&zm*?REv`RoOGqpzEExO)gFvLzY??pZ2+xs+`+eVA<=F{A3s%j0v6Dh7)w zNRQ#@St-@poh!*~*uf&}HLPK+8B|+g;)nIt+QqhR)rg((ANbZfg(sXv5?OtF(`rxS z?QB|StsC}I$xS?>Dk%VG)yvg-qk5qcB-1F3A@7Ez@51+z{P*Ii6X#zI109~PhxOI= z`Bz)@XIk}UH44smqk1De&+&BqLN^N3V&{sc5V@Zq5Ejdpa?NIb@MGS^BXI6RCbBjx 
zJsa75C$=_0b)>u-&eA*VSaM|meLM0xh2(~vq=a}zAnLn^jSnI-oR^(H_`D@&N#8VY zn&aQgWNTZ@oNPFbXE-Oz6LWG8xFMfmrv)Uk)`O9~inT7=bt|@iX#Wb=4E*RKI!r7$ z!wRRIBHx5bdhLy>UJ17Y!A_6j2|2rEr|mggvu=0dm_w@oGR;ljgkax!dLRA9W|PIs zUbEo@o{($fl$tfWulx?QMpg{zs_ZpS-@1MM0|1Vq8`e+E8H8t9;ph@8?qjlT|JnW& zo&g*MTeGG{>KgtCAs@j2CjcDI0FFBh9I+ic%Pw%?j^W}aSg1O0cu6`6T47r!Zm4U! z1IEtJ;u!(PnvVvsJVq{gnmm^{?auB%vEy@iMnJKqCOAZfKW;rA*)38$cFP^?)?K6B zirr-o`{ljYeiiU_DP_QN6-2}$Ds~5s-Cw{n0**BkJ4T`=aD+N6BOwj$n_!Zm9J3>AAZODDboHdBsrulkV_r#9@I~+~uQ)VGuRCvB0$KL`L8fw?DVmK_aW>9u zxH2PX(unp_-q6$fqwP+(I5IhkI=Iw@)esh8$P^7Wic!0GAt=_`t*E}zZFj?aOCOpH z^)cW@Kg>ihD>xSP!v43u#QMjPtlv5UX(wuX=Ux160)hlcVhs-?5nh-hFNh+yx&LGQ z0K>zrx1z*pbrKKOo`P6^1fL~#*AU($WOXzNWTK3r^IcVM-UPt=Wybg31wEW=Vtfy^ z9A6el;ADXd=iQ#&b9!!UE0E^boI_S0f@6#Ii|7dJ*t=X0-hfOtD~uAmmDo4wZDQau z9$9B;k74C0SYjj0>nGSQwhy@wBBk4bxt4-~*$MO>R%JjuQlcKzPcu2mdbe49)0o=p zfo_NTA*qBdo4D)sPN^UgqFm^<>dS2nBNaCKNL1)gaGWQZP$rjiGKGkV#1nHq%hXi( zn911qnNC%ydTV97C?V{x;RzR!SOY;a>v;A&{yaBl=kQG9&v$*Kb1v$UOTX(oTGs8+ z?wyb~v^R(z(FOm2OzbH8vK1Cyu)Yc!jLH3!Jq~n0@fxrgK|)8675cHrh5Ue=>7`l) z%96gfN$HRp$)PwU+?$TltB^2sRj)gcHvVchPD7idfc(t&ouR&n{C+M5y{i=F`Yzz` z<(-bbTc~CHRY;@3SJ|vc`B%&@tpUcP+RAe#+Hboo*yvlEK#97(tUHzs`&Zh1uUbQvDXH;ZHwbEis zyyyxz?F_OdUw@u$TvC#it)c&G>{8*E*X08M7?ulW>(tl=lL3rook4CLPe`-aatgu_ zWP<|ob6|jjU<)59TF3(IKVi(%LE%Q8!V^-)7*DSUb3~Y6ADIAo692uCwF>h`dDmeS zArG}WP)VsJ*drd@;>2s!Y7HG=B(0EuKH!4a*5Kg6&Y<{7^7Nz!b>dyfK9p0hoq6TO z(|Va-Qb>-`5D5TcxSW6X zN)()hEm~a(lKiqpMB+ksc{$MK99zyAcgBg@B_FW?Kxx##$_NnCJxOzNe3l8#U}Gk% zpq>&W*@tk4HZh<|)UFtd=6h%n@7JEAMgaBZ6xmF`zA)yJEzJdgnikJr}F%=Na zYDJGdB2>7m>G-GkEF|Zq9p|T>U3ho>z6!KIao#|X0%5%Y?srlgH^EL#zD*+#_Sj|s zYtoK%G!31^t4Bd|pwskAsM#)lsLFoAdaAOKrYcPtX5);8+3rWKoN9HOK|{x^3wKg0mVwOBI3LT(tFuo4`an+B-Jb}bbt)^q?O$J~@AF>yWg*X9v zynYNLq~5l#*?oW0>Syuf5L<`iMq9ZD8BJLXw+oV%68Msip^r+bUEJ+mi5e_sh}s5G zTSKCTUrlHliMn2#;8IC!!W!AJ`tElS z8GIv=84wwcoIqrm$RAqCLS()W*?_?Mc_A`ihz$OBzYtNe!vqi_yGDx|k{Y8zG6I_6 zs9UDc&i3^OBtQL&*d+ZGCJ!?aOl;{!v9A9O^^+uabezO?5BTUW;LFmC369c+GIEUg z*-mJT+>B^UUuE%YOuo*9q@}-sq?9v+^&l&@1es=BS8vgozQ*FN6gG#}f5a2shh*s4 zF-+u>%HYG%D6I8 z1yQvo{9O3Hu-6xJMf?QaI1Mq|DV#~&32T?URxS*VVGMt5c``f3p!wZ@;0Zk>7EBq< znxDhspb_;^3^f5I9lcU2Nw>QZ{I{VSC-J*QG#NY+8Qb+CKm0E|2298T&6){*=K}bA zhTW*I48McdJ$*bjkvrO361f96xl|tG@wovvxr12OgN$|YVC;QHxnPEfTK+VqHq^Zz zIS6wR)gv<9X?g7VKX?X*AO(BY)~x&J7sw;O552&7fT%yl2T8($-pJRk9#9q_-`Ib^i|w`k63vkZaiqLu|-zTM+9ej9Xe_d`{sB zZ$(0byzeM_x3IauJ{H2);>;WZzn9#l2D|g;7^@Xb!Jr<_yyY|*@jkp^hl;v}XrSOm zOb%zvCVFQVYD(D5gaY!g{sCH5PCZs0r77cY# z`HnbRVk<_eEk_GZ8J(;rIU4OoYtw<>oq_X(n1o7QcW5JG<>_54KuQD()c@Qw@DOu!36k6> zZVRM+i}j*Q85=5Rj3+$6%s1Ic;FA0pI0uL!s&>m>fNMb{J$dhFCIs#e{aYuP=`wwE#FNk_Mofz}D>q)kjYQD#+>Ake>jN;E|*R z#OIKCKOn+#+#;vgS_K)=`M8eAAOh`aA(X+%+?Y$3*%ePos&=v4TGQXg^x&0AoNN=h zNxBkt8ujQNE{&$|2n+hxQL3?O{Ko~J#3K>7z$(MDDuc7jZh){eJSI4<=KW zPpgdNalizc&n8o+aV_MTM&o(8$g+MLi@0TjlONv87z-vDvt`2DqtnkvHki0D-rnFI z0bC(c%?;chSkcwaMSAtH7`89k53e*Yv>Wx>LtHk8vz+z$38)r_T|U6<`a3c=0%p;h z`g_Q4pErXL6~a~O(XnawPfk1N@d0MVnu#fRc;CmXoX5gd#aK$)=l2=QBmhs00kCb| z!!AoVeBHr-Nv=|9wpF*mU|XS5xzw#ThM~7gMYU_JMkqO|*P20ev90uPa^=6xq|ao; zP?8hr^c!ecBn*@$zU?^?jHJcbSaVD@)~WoZMl% z4HO@W%>IDW{MHV0annza%{8R0!Cbr17QQh2M9CE+Vy=tGn1%fAR&Ip7KI%R_HqDTN z2GeX&klvj9cr?O~(C|r)umk?t_A}srB!c9!CBzZL55$h26D-;QM@F z)a&2j8h1ndnXz5kH#Yr;Lc7eg6WU2)XGb%A0NEiDOYer`*G6NmPe8wgVT5TYBlHiL z{2r4ZGWmTbA_n@*3EoGEu*Hwj@Zv5Mz_y=(0(MtCK2dNA*9+HAZ$m%t9O5=YHG&Vo z+m<1<5JJ#ZRfk>*i>j-6)G=K6E@XoB?8->Hk;s?EHhIqwzTN5Ly{6cSHYgG4S zHN|eyV0Z1RMeyuKEz4uGj21H7v8COG!3}`;MSi@4Fkss|v`cC(lHrem6zcv{c*2+P zX>6~+spwmq)(`AgEL>M`dmhd)cyi*=XU7Qu&$2zl>M~IZDegtVFvk&?f?1BU>@3W6 
z{|HRQ4ckQ1@Kr9#!`IIqf=OfSF$8p7Y0Y7-0#8XcogVtEBEK~YVsMPK9l%!z4F6GY z3NuY(oLqlq6K+6nI-Xvgg?Z#&w-Fnh>CMLIkGSIB+19u0OV7pm{vL#n=aduA_V@C* zufI>>PS7T=iMnti_Q&}R8ngxWllMIPBD8E;3#Jw4L|$)CoR9Y$LSLr+131(7HXR%U z=HfkYO>mJs#xWsa$>78&cj>#RpN(ghT}n&be%9;Q-jT&=PEJ0Fx_3w-B`2{vniwb@ zmQY#Z^NMr36%4Ljc#SqLMWi`K3Q5_g4~JvrrkrIM|IVug^@cbS#;^D+d(+KGoMuyh ziY1;zl;w=3GKx^)#ziH`c3a`4ZV=#}HSR`t0-XKPj4$JqUj=Vg*hin{SoU+c6T-tm zUeYejl~c>;i{om9p)0ip&ZBV87d=qI`3kc3nx^TV*tNuy^_)P-9C&282;3T0Vc8#k zih;23VJS)1+FghYxTjIjDW}7F6Q`1R*V?$WQqEMvu)c!p%wDq^u8oOCk&1MdFOd}( z{IQ&Qd)7_cX}4(KiV)i@;?}bW@gUNO8Xt$txI2P$fpKhp`S1M5vi|w$cgQRg9K($v zDuiSH1sVD!jAW7K>FDwpT^w2*F7xsuU8dN@T@Kc#HZ9PqaFpR19BEkn4EZPaj)146 zgqvt_D)Kk1eHQ9`@CEozIXJh>VocmY1NCN&i~RbhaUU(qQDWRRIfCO71_z#y*>F#d zU%LEKT5xD>T5cZi(w8xU=`C(G`eQAOi;tXZachN)&*A8d8&J-y1wjXQt8gKRo^0X< zZ*~l|y^5kxk%|5aJ^nH4NW$iP?9cc`Y8l5AQ7NC;FXQV@sDFamate1=DlZR+r>UP| zPh0E1U|Xiv)fKcSc6~e6Z$tZci1j;=v2e&mC)fwZd1kQ>Z@R-gu{ex3yiWEn;adCn zKSx%G+cW3zEXfrbwhyZjox}wEKMiaif7$vEIh>>a5-oI%PNvi37A> zbjxVC26VtWkDXzUg{%wXI*ZsEjD6N)>o9IGdf*KDv+%wckMnB0{Rh17?Eisy2E;@v zFYAI&m8f90#&V@)e%xrzGCa5X2oOHByU8aAwEXTPB(In_QAsVNPLB(|+BnJ@rA3)kf^oH4Ob)C3lPWJ$ z&5s(L7wV{3D!e{5krh3uQ>9*k7ZobsBeT~NwbV{x4{YBqM3xrn&i2#u`&JL9$^=R? z(P$GMH+cbSRK~Zb21NE(qO?M-!_;g%bN7q*E-L7{AGz`+ATmU;`cPH4YYqGx5ZXB- z=b*STm~`qUUgc{45vW?zv*6rW&^bNpz)pBhzy-;iCS90CjG{yQL>KOo|3g?^)4l0B z^M4`>Z|+&HK5_xQF?TIR&o<|tr3eceA~|wHd*R$UCvvP`oW68+ZSG88o|ALB@Bzd1 zxqnU;!92KhUvTBV0PkbL!9^Q%X1}y9$+e1`+Iry9NbyL>xNJbL5DY7K! zuh9I@od*xL_m!#AT&}oo#c9Rjm0{z{51`9eqI&Zx13Y7|*_(az_14W)WjZhUKz##~ zg8#saDz(#i>t;*s09E7_uf*V8hLaLPCmjRJtiM&hwZ-DC&v&-|v$J&o5#3?nkfTxq z2@WKATJ8^CF_bmKey)-h2sod}OlOJIcv1*(vk8b%Ku}m;G2JoO_ramLG@$|#ff?`K z{@w%A1@<}}!#k$CCnQ7wXv3@@*w;R9Z;>0|QgRbIXg)J2+C3D|fEj%E`R;t|xSp`*#uY#?p)tOL5HPL?XU0l3vI=t4iDfg92u35Z81 z{DLQ;-zPm1!S6cVgdXWW3ICGYz>r~XT75{n&}@uhGD0t(gNI|rn*Gl@!IvEJFDUoI z*}$<*ay_wgRr4$#<($QcI~MX*oR3;MSMk26_^lXXr;$~h+S^+By11d4K$ycERXY&x@ zvs#VNvMtjt9J6UK>d3c&zM?2p(nYno0~4$^kD!N^_ZW=FKVM zt{|cQOuxs6Qchq7ZV$shvFO@`F8lN>~t*oIMNn$*!!uQ|+fc5Qb!|3VD@(tK|7JC0G zx_vmJde8+yK!YcnUP!JH`7taEHcclHfe;nOTn6g-@!mv7ziHZijb^+J<@YSH{1+w H^rC+N@zuTL literal 0 HcmV?d00001 diff --git a/utils/__pycache__/logging.cpython-39.pyc b/utils/__pycache__/logging.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e3b71dc834b69ae83840174febe3eb2468451bf GIT binary patch literal 2522 zcmZWrOOM+&5EiM2W$iObyXmb$+67iY;^xo?Y6RP+*=`dcb&6~j1=I}$tw`H+WXUC| z1nbzR>RjAlHwCYHpBO&XxjAAc2u@a5so-3BpeWSc{vPapN3h*IV~itlo{-4 z8Hsl&R4CYu9AyhSQZ_A!a=I+1%I)Tr=${zofF0T>ZQRe$gq|Q49F97FSdq&$&-D)O z>8=4?+$awRG?Rk+dP;ZiuHS7R>LGnieH%Yzl-zD84L;uxt_Ed})2#5NS3 ztxp_7A>vht|D~m&U7A~S&0QW`nSC&^1|Lq)98X=aVPoRX(bSuGXY_gJ^m$qz@gT2_ z@tYb{REn?~84N9w6Lw1;4l9V~D5uAOqM{W(Eb^m}7U>~OYs{V954_PgOogQUSn5WY z79>1Q=tY6BD2tot1>hK?W_SF}FPmGb$Rr!mNIU^1r_X3!q=pSQw`!^lRL%-ol+ouz zGeZD>+5?gC>gMpXO%iV2Yj6J8-rS3jY7hWocPN1b;BcA_JJD5zNfoa$k<^9&=*mnr z7Op9WmnFcf4{qcDYl9Kh>{0mufK0ja1(4b><=pS=K2%K*AEiAwM>Tg#4#saI#|z_d)>ppXq9dws zw}0{0tbldZC|QK$8rUtaLua`@4r~wmXc>9PK^WnGJ=aI8zy~P68)zLyP4g|-o$1TW^R_gsT$i8z_qS8<39}q9#P2|g!Tk#0ITOr_6;vaG-F*xNvN$3OMW9{-I``1 z?34w)5ppn_7h!6&)!v8C`z-5+utnon=hk2nG?!g7)n}BaJnJ9TFucQiqHW}%{?_WZ z069>3qxRzDTCP@K3q_b;0hjm&bO1F+TX-?1fzE3!v;*7f;@99huIB=!gLLv*Yy+cT zi&yYE%&fV!W4%OEOQHdu=$AJzv_tq?gTZCm2oJ~Vbu(C+`kq(w#WkND2%^@plG{I^pOC8ZV5bt@S--tbE;udra zym+#8fT!A4C~Olf(M!y?ppe+<0TRPBj~hERp}g9x$X2*@OQEL8 z0G+%-X9Oawnxf<~c}kB&yi^IAENmD%FWO?-fm`S@NryrOf(@aXYK6)r)o>(Cu9cr8 z8Xifa923PBoFCb-YYvTrQv>*?uyWD*)$>Mnvc8froIXkQ*CGf}Rl;QgIGzHPzZ{v?S m?2MJ>pMiOL{oBKmR5{(&5f&do=RoB+^Rt(2f8D?42mb?4;Gls3 
literal 0 HcmV?d00001 diff --git a/utils/__pycache__/multi_port.cpython-310.pyc b/utils/__pycache__/multi_port.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a38934394287ab79c1d35a8b65cc55287816203c GIT binary patch literal 663 zcmYjP&ubGw6n?Y2Y2r4mKMsNy$-zsIZK%>7iijbNP}Bsn4PsnaHanAMn%SMWGiz+o zQ+gGTUYn!;(p)|DUl0V}L;`)w_szWbzM1#tv2tg}2ehm2KPQ(2;J0;dn}g0PglZsB zphUpPu8ts?Gja$}mwHz)aw!R+S-C+K^qRy}Qy~@4&r%!M7#4sUqJvNrNrNWWM3Zv| z;jWw|`SJK0e1&U|k`*C#SHa8{qDq~WOF-8!c4vhRJc1>84-Piqfd@;>g|H9m=>pH$ z?*oFm*pzqqu`$(pu8vwQrIU0fKeJ*Y5ZZHm1iQFTQ^Kr1yXh*Bv^V|VqM#Km?QOmOP1 zh$Gka$iL((r~M0@cn(6^v%b&vd!PJ%INjb30PXtw&*>Eb`0XxtjpF4wLN}2ZFp|T_ z?KCF^9Z>?zWBxUaJVs&&YqzL^UPx>`%cU0OdFCAd4HG~E(Lrd2WWbOcV#ozWcx$>M zKTf{ESGe&RSrg(;HH>T^>Wr>E0;Yk!duuG<5v<5NpjdzpKCCbn!vW~GkAYO{Ju>sK zDF5n1b7ss!AGccCq}g14;%b`9%l%9it&dBtjgX~o9Xvfcd~$Tyl4X?3EX`+9n`kC4 zBPJseEkrht(rBtUj}}sys9Y8ku9_jWpj>0Oj(TCB+{o;Pmf4&eTW`NeUiP|Yb~EmG zUM2C_pxb_9{fQ`<4Kz1SCKtvAasPGVBq-gA`^lg?jJxgA(}DH0QFinEpR`O1j?*Rj z7pNV~^C#*YkB6EoJr-JA$nmfg?NSt};p3$dxgHnG+=#?gyubKj?@uvt;+m*jK0cw{ zj{tp00YhqgDL}MjERflm7abgMZTd0q*1pGmO;8jg;-1Ni#JCORe14XP+Af z=05uk6J~B~GkGqx1^UZgj9_U6BNi~i#t1eXkya0Dx$%_kLF%C+Oi)&C?b#S(rS*MN zZszQ;kz44ub6DIn%z@tN0QP;s{=n`U zw+!1z&+Pt*8GAKmWR3KUIF;Lw)y!*{v63e$d$nKHcGwQdIjiLS?F2c09g#EjALN+g zG&~3CrB_=s-aoM%{Hnk1^}>`t=mnwhC13AFoivD}cB8>@y(>48Wku0bXMKY&ilCFW zmifD!2)^?$=se_|I7+=BN_f|k-j<*GlE@_S6Q0IA_2h=1@=k~z9*^oEq2={qF9A;} z=wpAEFQmI&e_@FS5mb38347?DD2+$Z+#c8}{f!_=9VV3E|bPIwfju$AVp9&9N0^u2h~PkS;NapdhWk387);Ym0q z)|6E2zTgjbA#XeWIEXfmF3e3YfOm&Wa~rO@CS@##T_52Q_Gc;uPQFI>1F5wk{%6|V%*x-b2xtTKy$+&2-g2W@ z_(|F6-id;Dbu``Ur9qh7D-CPI@zV+=-F1m!dQ-d%^~rl^8q+h(VNKRxGpu1YOdC`K zJtjNNCh1qJva`%F(Q8s)lg%>c*XF;*pJpuiDui`#k6pAv@8C{W(f9@>WF%v@W{DY< zh}oJgs-lL}T%lx!|EhRNG{nSXwpIf*DNccML{m)RU0s|O)A(-)6KQ)^QiAq=0UZGK za=LuW>n3`a$2xccza!Y4^r7wAC*D>!^oOX!S5{W|hc~WX``IminU|5fxS+chS_6Pq zjPK~~wpQy)R|U)B5iXrMZu-fppDt>UY#q@^S%fmqxOnL&5i-x^zCmrZeky}4QfOVC z*D~i@;|j-x^dTvYkt?2k4$3#E@Jg)Bv5=+u^7*ohymtMQVQ%CAu~fK>Sk33gn> zXpp*Yh^K_GqRRBI$+GEk>DndNSou#=BJ5JrBCO1fP2-T|>?!L`fHpOK%aHSU*|OwC z{K+Y5UZ)OIE)b3L%2BfvQ-c2Iq~m3lO%5!P%q(59-~RgcR21Ge|UPrHf#@9t$itsJ3#f7_CS;hsAoFPwASy@9F*4Y5~Wn50HU)@nH_gYKLJLQ*^}Aa z3EEB!KDmUU{|}+oI3d)~KMb{t<@Cmbtpw42yj{SQ%6P+(PrgadP5rWD&Fa_tXIn1Y zcrU+6ewfhlGe;fA#E*Q7UMN)h7E#}(Mnh9!ljNdKtqL5eR%{S4l9uv!#Q%LX*_-7= zhP6&<*-3#Re?X&;1jF?#-rK6JABjG#8;DZfE~aJ zjnz;1+9B)%1u&0hMdi?;FDHPP>S|^#^X$|zzf?|;FBGhunF}igYqaV`H9;lsi^2}+ zfKO{v*f@?RYF&z|*X{a|C>pNY-_)*J#KPW;gQ#WejPes2HHpH%BLh{SC4~hUUjlfS zbbJQpgwTXD8zp^%eU;hfMOMXjAAx+uI_74=?<#KP=1U;(fyj7j>>=qK0`}=s8a_`U z1w4go_t65EjZI4-Hd4Go@XCRmvpK@u>Z<%PWPC{jE#@uU2?5~-bleBN3rLM`R}dWoF-B z;tkXy?_s38hNh?_IMm}<>uU6a1htpQkNvz}OpGkRQIi1NTIF9bAR#w8fZFrpC5Psq zy!{#&bP+M3^;|_qG3mOGdS0kAKv8pD5qBVWk(gDL@di;pq~n(TSa1v+@y<9k$7)!ODQ9YOLcNCzCWu#MB?TFDpew5S)S_69lU+EMbhNHe<)W%A NXU0LWIo9mTzX3}1!7Ts) literal 0 HcmV?d00001 diff --git a/utils/__pycache__/registry.cpython-39.pyc b/utils/__pycache__/registry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc83bf69ebc388ed310c31500e85ee492b196714 GIT binary patch literal 5584 zcmdT|&2QYs6`vuwyWEv5OLlB0Za&7UQ|%_o)~M40b#1|JY_~z`HDEh2qBX<6z!>pRE1pR;(wrj46i*E=&d&cMf!U)%l%NKHrK8+!^in}@6DSx zzxRgSW5*f{zuV4V)_(IH#{Nz(lRpP9S5cC8QAs9w$X3mhuX6ovtyIbPK(O$M<22U%P(ZLJ{Gx&H$THUYZ>>1wy*L_!7pzPe${dJ za=zclZL~W%G#)TjWbdXkvYOjpGFjWRcR6R$mGy9T^Vm*B(~WR$^Y~7cvvfX3|BiE? 
zMT?x#U+U<2GdbT_!2H=ghxGoKMb-tDTj0;tKcUHq9GYPC$^Dmd=NbP4S2wfy+*xA* z^!$qdp5JG8nZwdkf8g6cuSG>s(&L+BN6<>|@w?r}nE_)G*Hr>;j8I za6UccyniI;KTmKz{w+D5fQ{#2V@NOUFDuyFh3>zxdE&ai)*XbYcsK|`=_|1|h8kwo zAW7A>XxZ~Z=ff+Zwo?aLzWlZCn=7~Y0P%B(sOQJg|oTEkkf zuFW&F;vGL7sA$5F_ogiJaKncsVVKxZN)Pv?c(@IITk)qsw0>}7?sNlKcf2+CV5%ER z#cJI2fxmsA<;G_{U+ZoLWv>ra-6YdLglWM<>%A=2v6jN2hHPuk=gl? zS*LiFrZ=NH*`=FU<$dgD{3H>B1a7h{p7_H5tRDuwAPuqJ&3>%X@h+AAC^AQ!J9kdO zXIYCEuLaSXul%UzdxI$GM!_ch!}Ft5b8l`0SKDbEhe><1 z{t5ChTPYjjRxf_LNmKSy6krO+-|B{iLx-=Z4*C-NP}H!vK71NPq8DD&N3KU}FE_R2 z7B(=Ya7*tgs^+9o%VUr1xH?x-df&FSx<)ehpELC;ba?UhdzT)3NZ#}ScHM{{d>94S zj7vNiq(PWGDEYTxzXc6--kORxJ>%IggGX`|6snhFDoy5w~t%zwwK^;)*C8ueoHJmRcip)dR1Z=C%>* zD^Krr^S~$r`*-}L4-O*-Tnph`- z?O1@(x^-A0K&M?afjkI-rI#Ftzr$08bV4-ArpuxeM=*kTe+YaOC&6KW&{x{LJxWJ{ z#AGCs$8;RqyI~kV(eQwnoDSf^ns1pUJc^$XaF6*@+EC){GDX!<7e!X%pg10f#9H#q z&U34%7#MR)O?!lW<&S6r?e|cU@1e?h&O9Wns5vEwo~3I$xB9Om;bdM#B3*pWJFTi( z!1!WjVnsz=%kprB3oVZ*jVwiDY}A%hIEf#w>G_I;s_#-ocxbBEh_UOzM}yS!LcAsV zz5oO>f7wF0b?tK0RHjc$BJ9%3B8<%02HWR3f5!W#L0g8t%hVz|Tedonhnl16O=_^z z647U=npR5>O38Djk!g9hg!^yw9s8s#vnb)lAOTH=j?O71{jytQ+LP835@N| zlhu@UG;2RF^N>`Du+n_7BH#pRWXJxNKw~{Yd|L3rKbq9NZ zp4t{uTdBKW$K40Hy{u7N&R@YjOS<~DRi!&i&P@AdP*Aeo$=wVE(@u&c9N;Um88cQ-%c}3`-2TNGj8T49ZVy8dQZ${#zBo;Z zQy5Z`sr-J}?V0vgLy>jVOH|dUBBL2k)&~RrOIYK1+PMlBd3b~~n-mUFEerWWm=rSr zO)N+ceQuc}p?y!gojwS4Pw8ubuQEtTPNQNJW2zIGAuv$sZGP|h%Mt$jI+ zd%*RT_K1~ssprPij8wX68%MiHlopK)hOQ!IeB4GyqTz$j&1PqgkoN!Re2pWV5A9>; zJ6|q%+Wksk9eUt77^(C?+~L&Q^gd@g%b{8Q=J2vA)G}H79aOCu<59H1hk2B0Q-?;X z?-O-_Dg#%AL!65`)!LPm-oud(DcaoMqWm}Sd8b_7I4dgkJ<5U94@n0tX6!TzBc~~+ ze?gK*eslI{Ia44+6DUR0ApScLDL&R{@V1}fLye}LeNCUi`wm}cKr?=weew}~c%hMi znLD|)U(p{*^xWJgklDjW7|tFyM}7F%vw_bkuK=O|K{VIA;p2yN_B0qhZpzBOOCM7J zMb%DbT@u-`OX705NU>D#c4jTD6ui-@7u5upL0=Y5NOyVKwZg&8GtoO%RJ;AYAIYNO zc|!p3^!66cMjS*f$Hbo%nl&r^q^AO%&LxEn9^VKD9?AF<25Wd{Tp;;JqK|an;pe}u zTG==}=%G_)CfKT?Xft01fp13a8QVdK+DC|>Z*2H(g>dlug-6 zb@t-T{>oS2@-GyYy*ctGAg$k5o;P;Z@%YRVzs#IZmPG@3%XQ3C*H9Ie1h;qGU_FDA z;GoYK;rB2HC}t+ie^4Pnyb<|VOi0Lx40b=ur_vKMf@1s|#1@$o{%4G!t7s`^J@4_L z8=7EI)I3kdJuq&f&!i-85%nXgeoU2C{}xgD;EgMSY<;Qv2t+cE%H__fM$NTdyK%g6 r+?~ZJj$b` zrJ&&}ZA-;hpkg*S=)_q5`#1ZTof&UA81xX>-Tl|}a|5BD%6Vu_a9+X7zkwtu!5PZ& z7U$Yl!zj^`#vR%=2u_-$k+kl$Z48#`wB6ZiZhMgVX;7bA8*6QE*0MWy z77KsvtXyf#o%P0ZF8z)2bbEPe&tZ>j+wPU)+R|{`h3!G8lh;~8z-y_$@$8LfIj(HK zb=JOhEsfL0rR7>zw&zYqtOaQ>mC~terv#X~%B-)%P%=<5R5DUBRx(kd{xf@|3^2)?tVB$EW=t`U@ub$)B$M10*4+cF1adZ1M|EweQQzdt!1<4C|`8SXhrMN&l zyuv$erD2q6seXf2dWO?hrl;*&ZPfxc8g2u38a9B|@Spxg{|vU`JBrB#lMP zGx2Ydc$<}kW#Yg}0<6Z9SP*gCXKFeXb@eG`%FI-kdXjIY$<7@Z??1$B%YPu&f$zVc zTs~b#JY#%KIek~HqmtSs-NhnX9|SG<`n+(H6kzS~{wRCR863()ZNcjUXt6G^wCpHc zZXEA&IdgpQP!F0jJG28Q3}koa&gW6!Ex@Duh6tSh_%$C@6AcS2)CWWIL)VF?MBUEhDewm5G&+ ziF0UC&$_$s?9ZqBQCXHEknQW+>Kpv04~`{(;0DxOfT4)u78QRU7kG$+BBTKgx2T9H zzDIfV1xz%rv3Gm%I0T)6nkg8IcK8LoV+^PXcHtgJO}rcJ@js#%IO8U%PAG!>!Tcdl zU2K)EpWLY+vL|h27^5z#drI8-(N#s-)&;V2BTXb#!bJlyS(5shNorgSU!-&yWtM5d zB{!DWIo1i(_<{5U+|T1|S}qLJrsReHHVaIQ*}yPvTWJmp{#pjm8hhm z%p5MOh2*!A3u0N>Ti%+|S8~~HTxyu@yP<=TACKQhiYIVVe1`QIgn6v}9`qQDi)Amm gjrQE(aVYZ0WS*2<2}S#sUHi!!-%WxNoZw0L3y^?)lmGw# literal 0 HcmV?d00001 diff --git a/utils/__pycache__/seed.cpython-39.pyc b/utils/__pycache__/seed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e0b3222c6264bb5cc9dccf395724f1378f79564 GIT binary patch literal 442 zcmYjNy-ve05ccJ#4WWUTKtfq~fC?c{_b$*Sid31L5e|tH+0Iaju`4rc8F?kIOstHZ zIEM=LPIvd+`SbaFJj+r>vVD7h`l0yclVb@exg|3f%?GrAFAHA7`2PkB*y# v6lSC;=rgM2vGyC)Q*tg=L)&k(=Z=Ttb4M=oq~yBoaA?Gh-=Xo`1W3bQp#EpL literal 0 HcmV?d00001 diff --git 
a/utils/__pycache__/transforms.cpython-310.pyc b/utils/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4576b9a153e8b548e94c52cc0782b745f67ea4be GIT binary patch literal 15264 zcmdU0TWnm%c|NzjaJi&b5=D`ce3dPG9b2&-qe&gdqD5P>T!pqJ%Sl-|+tr>WwbXK# znsX@Gn`P=oY9%#wfD~!drWccTT7+#{1nonSqD|4f1xSFlKoR5sMFRA})ms6fFKrO_ z`~EqXJ*kQoNqQU7d3OZq5lNSjF1Vvek4kzw(nWU+=`l&~KziKW zhV(W`??k%fZby2%q$iNx;qF9wr=)iwJ>l*`dY7apk)CvSBfVSFyOEx9r;(nP^c2#2 z+`UNemGm^y``kN_zC+S`kiOHs3+cNgy%*`b-FuL}N7DO{zSrH4^nOX-f%JXu{Yc+0 z={sHXq)|R_og+V2wt~#DOl_wwQ92- z6lUu!l$llSm7ti!Uwu3nPg0*~xF~s|+H%{=x?~bZ$x)qp{z$WdN{&`nJ+D!19d53w zV05Cmuhd&!TLmL?ZM9r&lB$xXE~z}&{oL!ms@D8lJa`5z zF1AQojRVnpSxC&__8vx1H|9;(n73ThwJsa;wrjf%{sAaAi~o$9bMyGmx+56Wf+B=# ztRT7==RJ%22I1c|yG8{VK51;4bDL&R0JSY|0X_0C8ID(PE~qJ#+ceZ3{QXaD_Q}(y zyt?vEH@wDD`!wcd;f+(rDi<3V+lyXhrLFwxnPxqDaJshY`R(OOtL@gkD6?{^)p(}W zSg!i@)2n`?>77nmJ-C9gtyCJVhF__SlT6ZEFgv9X0tcgqTRLW=AUo#Wi@-qWDQeO`=%Y2@2>*%k!(PT0O{k%|@**<0joMTWT7mf(*J>e_e(^ z2YZ=2TB$77{f1v(u2i-o!^g-I&jmy*^NaBm}XevBT1t>O_Owy<+vjV$W zZ|TuDg^d38bCn88R4Tjiwz>mBXIs*$JV}ia$0r^mPAb|no;6;}u325HimF2d(+GM;~ zshnS}HuV#=4;86<7~IF;0S2tRC*aV_3>6WPk;@x-D{tqsMOz~t`KgPYp22_*E6 zCqX`JYBZ^X$nNhDt7N~!QOuiw{5CB$h9jQ*O`^Of5X1zPt{4lZNQivbd={uDX+XZY zZe0Z@_|8&hL&!yU?YR^v2}pu?Co3gMivY=M1y-J$5CcKzP+Sr`cj3OlQOp|!tC$f$)sReGoVFR<9&sLH(yZG{ z&Vm&pGZP{+E6B_#)+tog_;z$L&c4s$zEMJ;Wja{es6bhSXvCm!l!7s!v94Cy&8pvM zw>C9@>j8Wk&j-O0#*uM@P|FgO)NvT2`e0PD|GY_*H;bie0w}!%a3N(!Da*QMUNf#3SIjHcf-OUm!gJY~Qxu}Z0ZjH20FQNA`7mxl?VU#FzGO&l0$WE< zfW6#ksYmd3kWuZ`mRru~>WS9Cycn2hRR&m3#WToOF?30tKzH0MnfIELB5V6whpg=f zZOj|DfG-h7+C&QFjgnP#1a&p!r7q6z3~u!S1Q41z{9XScf|SN#vk^4aE+j-SZkifz z#ML0YJ;i)cXFC%UZZO;;0Vr_$B;HHIErJ>(mvJuRZGXY|EWpRq6s+?-#x=8R_oS>C zY&5IVV9RK*U1xW;AZ`Si6qz?DAp~*TW5iikfx0Dgw{evX@OorSrYl7YXBGCds-+9~mm;f`dW)tT06 zDzWp2qcawUe}@{VqD34+*ts-ItqI^#-KEX=NLddqFYzmx#nA;34!Tgg;;pxRo$yDuQcn>xRl00=3>M3&j#7E z^~T~^UvVk}`|yjORAiCi+oPW%h;o zDvaBVqhGM9TqEw6sXm2QH?vx^96?-F8JuAteaRg@HhT)U^amo52-Dq)S^X7cNspe` zI^?m)Ie?7h=kR?PMpfN-V zq9qkM1*V^@se>HC`YGX+(a_H=SXNGzdWi?ms zXDr>?=UE$d_r3IsJ;wqE8Bi03jxu#WGh#}SE2?ig&W+jhH<0CVwVtr)4yZr4BGOu7 zz+}zcf;q`)pJ9u_oJ?(DPWEC3xA=Gx??b;+84^0H@~nU_L>J?{XK^R1I`vCT4qMax z4j;;zZdbBqdG+yxNBbLydGyclW-o`iux*yiTieZ1DdADGw}!zy8d~O3d$>A;Q&g;V zXs;#M%(7y%>5jN!!y%F4XFAw!&qlso?)`m^m}s+d--!JjS@z3pbRf&dD(vGdB8)-n zO*4zE8`I44HXoiuP?N+ORB5SY?2z_`4|W?LaVPMW*6}J)xJR z{g8@u*pn!(Nfp`>0tx9autEw6c#lnZjs@mOrY>rW8Qij^5b=rl=(K~?#FmTRg2}oW z;TDe5?I^d8p`79hyN5s#4m26R!u%Noo!!aYa}#SgzmR-V>phKyFwriXA0Nl&M~Gn} z89282zZcInvRy08#s6O8lBFGP@Idx?K%}%2y$@c6M@WyRA*zeC zAZH65$j$mWgs>X-aJj{JM#tgAGnevi?k#a+qeNI|LKjMgwMI0-gFS+Ek*tk01+qd4 zOX+nI6$EB6)WX*>l3Vpk4w<66oa4(=5>CEGZNN0A1O48&Sk`61^n1th`%J|0(Ss7z<@=m_=5eA7WO-3Ra9T^i@v%b#4G8i4k zVqN-Gq{QScn2*C%HDynUTzJTO5UC0C0JZ8?Mhd0G5N9b><}IA1Wc7F1w3Z^{CY+_h z{>V#Rkras*m>hNW0)qFx{0F?gf9N(O2F9A!Mv#a9xKdU1YGiVV4YI~6WoQzxP?CI; z`7;PQJ8$K!BCy#Lks?1tB$N@M>z+>g3A*gzxuWxfp?1+klg%dJ1 zE%hcdejP!L_G@%8&YNT5aoV(g_N@gA2NXtSzs_!D@DF8yT{~`x85IXxVMYZ`#b`e{ z-1p4*t8kNQ$rZqLxj&{HsckF;ByXlTZVE820bP!xdu*R#mUx%M>JJWe@GVG})voIKZM z-Q4GqvzGTq`~vbuyV-6&5*KSDy}6ONSSz@8cf^7GI%lYtAUI5Qp=FuI`7z`d!u-<^ zQE@)RNfFQRln(H<(J*awM5%sL3DeQuvKw|I)<%yFA72eyZ-u$>`H2j zgsW2#0qV1iQ9%dxa`pA{xW?BT%uST_Z}PcWQCIjsf=WX{aFxZ5GI(A}Gy-!e$U}c6 zazbO1u*7j1%9E946B(>I^(nVX7~&&R=98)qR<aA;$H&y=bFPjF15*a0ttjxMgcZz?0WjxsL@#%0ULmjL6wJ9*)%eO52VcKlW1P z_;W{Ishl`=^!Zn`iO|wMjl@bRJB1Y`9^9u@XTdW0jyqs*lk!f|_qo<8sp& zXz3fa2rQ=hLzEe=yg8XTU~8yTEXLH(`S((48@A~pR_y zw1@<+8KLI34Dld~)0)dwPjM6?GyG2xmyAIWDzO_gDAC!~d%Scwh6rrlp}Ek5p-CUx 
zz|bW2=Ob4$y4z>i4!~3nm5D&&vP&=(9*QBRuHppfTV{l;vRO8{v*$KeJpe4bAg6E@ z$&sNGM1gZMG77r8GWSH)Zuaz9L^|cZo(=~Utd;~+m)ML9W&*FTGB1YL1ey@pfE5={ za6udmY-(;ew1|LlRzsBLjUBWLnODP|y4dL%-0B(vSbVg9w9WYy#t7gV>xeGKBmNhs zQSV_aG3c-)5)&^&6E89DGC|>t;!by~?s-AJ*2ZRP)$@abSM^s#UA6-I>}p+OE~_Qh zpQAM;{(!GB_LmHXNwIzGF4YouSoyb2$%ZsR>48%#w`}#IOm7ZIg9aj2lo_T?^LYD0HoZ=d`rrqc`Ut z4jnS71rd58_B8Q)gK|L>E6yiwSdp!4Df&-??NxUP2aqet&{ETeMF3965}n_wntK=6@np!`!bPzdQK5JS=hbm^KO8dx*gf~z2rLKBFH6w0p2k%~*b zN20Vy`4Ea=Xb#}7PsTig`o)s_6V6^92O`)1_-UqAm^Q3U07(=BVm3& z8=Z<;##%PvH0iGlr%)D4rTrvo60h*|LcuD4zi<@#M}hgtc$cW}v*Ef_qlSc5}6YEsdV1PW;r7r{u2W-Aw#6BTsG3FiY7{S|(Z4#w~d` zc(5(+KtW}#HpD>X26C!A4ehfr}#~!sc#_vo9OJ%ZBq-E6GNDMm1O*oyzz2zLY zi-q+%zGsT{y%m_p1N+3$!_fOVh^7m5IDp80{bW+XA10X}MkI1(g=B7(juGdncUV=T z&693(W@cF5yR0wt^Q7vVg9P{nUw@K8GHOX3oP>|E4zUuM+A?n(su|oKrB3YJq1hy# z&;am`D&f(V1p~#^EMI&ALA(1Yv)8|sejwjS<#!O_cNrUotPW*Iasv)8 z{lOCtBwn(-y(mG4n+=|a(`|#_sMt~>AC`cLB_;AbC18N9k8I#e#dC18B3@t(c%Op` zCA5Jf%y7vWOZjvjOga6b3E$1&UAO05{U+&%-i&kfdvInYF0AYCuuJjjM$;}u%plPd>@@2$=Olx%+U(H#qmB4=K*l}?K$Z9RX z2V(MhR*-F z*_mZ^uO5q8!!~J?v?9e;62~!ucvm)(5XnbA6Tr|%J_0zffy6*yxPd?h@}clAK#-p# zfO5{QeoW7Uk|-;154xtRy1MGtt$R=1x>a4AZFT8!XW&Ilyy_bu`EN=fx z%CbDmQ!Q)W{HtZ#wv@H(%)83ujl- z6PD+Cxr>&UYizmT%ugbn_X7A0EMta)Yh4d~-??QUFw+HDxlHQH< zjJFr*y^`L8^gizFYk3ZIGqLLG}Ro`#c+Q(a~ zIvk(tv^x5UW`H*08cyiit6}k(&NGd+-_haNsg7Q*wMbRjP?uI7?t1igv*vA2dGg zdDd@e|4h?wE_Kd4+ipJGZZ6k?#+lWi+49e%?;l+Om#WofyBSoglWb5ZQETOBFptI$ z2S{{KT1dPNkpM&n>VmCz2X@z5a=L1dbpb*X$4?9SdMAoj+^__?+I~=L*BfEZZ#C-; z;evFcf;vzt%r)Cy;}z+U31*l(Uac-Rf@aWIu2#1p!_OnIR7p+RYdcbslQMO{LUJ*l zVJxV&mK5nUl-Z%vYP1b6m5|ZjexX`LiE5SXj#;%eZC9RUI7KccoMOk&qVm)f0Y@*s7^`I?-iwO#%?u7QxRn6~sV87tGDZ`K_(cRoyXX}2MaRf5GBOvt zJd4}kk04Qqf@gcqMFo`*bcIUfp#nvnL!oN`>p?w8Sho*F5+U0TjZh2rIzniMP0cU^ z=xCM;on||LUO?Ng8fns>Kz4t3*zH?zCp;{wwXFsNMoT@ll|=dXAxH??f5}=3 z@1l-B@9zXfy$y1L${Pa~jw2$5ss#VRNnI4Pclpbs~{h$tne&`6E z{3mW}?NHKNBQQrFY;_RxA3`K9UCA4sy3El>S*tLU?P^TPR?Ol<%n-4WrbS{XN`wcI z7^>-%UG8OaET>9WSW`~3u)duEsWg%~QXVNb6kDHzQnX7s5u`d2nTx|Ui`zepAdyAY zaF*N!JCev;B$0WM$b#mW2u*dq9bb&I-*dSAG6I9lP-LqL$d9BD3?YZYEs(;7Ug@-I zL9^4oYP8xg@6&ibOdQEi>5ZlGP-qMfIMf&}jqE*d66Mb$=t2Vb0A0@5>-M^H*4!d^2FpJgkF(6U{bdo4Ao%Ict+&J0H^ zw}pOzO|;^DM3?e)7uztRTXsd_>d47l9NSsk9}GqdD0%Id2r@GeNlJt^YfcDgUR6dW zX}S!P$p?`iPjx#|b8Wa#vO6T2xBonXtWd@R31qX*<@|$BSYLyPGPN6)_zc9b>-41T zL?mfKWkoV)MDiNDHXxT{$P`HB_={W~OyttO4B1>#yR6G>82=0Irs@S)m{fHbd4?-S zNX>+7zjAMAzk07>QX-y@pa~9SJo6p&!DVSKU-9^(?D;n1SHvC^ z(=SDG3il+1GCLx%Tq-@scSUogsx*-LC@G)89V4fsnN!D|xj2fmxU=F{6oZztxG`v5 zuKNBvSH#n)ccrGpVKa{b)%| z;PZ!gmy@{wDB#hf5o=FDM)eV6zQ@!0c^YtbZ`r-p6&tXrz>QuRg=j&V87B1_U=ExO z*Q3;{-PWdk#fj38dr7+@^UfuA0oGBbMgnocNWFd(1P@3(LFF`VS%rwyAEwkVfq*4U zPc!_7u`~l@4lwz7>h;L5z>LKTv9u`ahF@D=X*Dnx6RE@8%S|sh8|Ke8nu})xO>GOE z<4--WsTlo(LDYPj`GVGW4H<~hN-@wIN~Gv8E0f;bX?hbtf7uhxWKrB<9KC|2P!~V@@$zV(Attc1 zG0@}$B|ak1P%6^KZvbfYZ?lG{7z}}v#h6}!lLI511H(~;h?1bnkC5m^ieE=mfrPra zDXT7JblqU6+WcIuRSoD{r{g>qQ6LrXwVp%=NZdtsbF;wq?18dh2`ws;G z)YS%CeWrBV!g^QYcdxFcgX(i=prVRQC${JWO>njQcfH-PdB82j+NUIR#cG zNtxp(qdDHnEE-Z(ttee%by;wDePjzX+N6bph)4@FaSKbbP=Q&~?*r0YIvOClg|Tc~ zZ!n8M1V>`lT!=i2`VdPbR3caG1LA?nO>pAt$nqx36SV#>#8@CO_2rS3iO`^r+SEn{1ZtTWW1*y1skv zzEu6?DNEpDJW>B|AYQY@7gcXIXq8&PN zE;(@JWS~M^IUmmo;sXG~cfkz<43$1U{+gk|_hLc^&>*q+Z@{gfA7iyLnajF9tFzya z+pM@fI{OTBxmv*jjoJyv_VS`=JKP4x`A~Frh++Rnc<7fASaxLCJ5!j};V|eI;%GJ> zt)o<^Z?ihlQ3a2`s2J<+y*S~`;`R?Bh;>wati~~}u<>O!!t!2DG?m=E5##++cu#+Z zf$%V8Bz*xt#2wGwYHe34Y202Mt+VWsOxeek(OKEd`W`m3BJHA%LBkyW?~PaVaFs^6 z_}^A+C$GAZLIM9< zHE?rq;e4HyJ&1%?P6(#w0?TK{gO^Sw2CG9^ME?#N+TH3kxee#*%6NFOl+@;ug()Y6{MP-M|Xh 
zzC#@WYD}ryhN)779F^ft@aNb9HFETRA_3qBxFcklbhSVNaxzzfEc=liBMZ)7P{#(n zkQ@*3e!ly?KjHO@V>f^?7!}#VDPj+yTGNeMY&kJk(?Ku{d)`KViak5t&jU#{p!FxX z^D{&utcf;oxOFqdjT?9)>F8o;Hfk=zdw94-=a|DC zEo?jHgIoY4*eG^$-TV}+R&M`XQN6i)<6P_f5xhB;@lKlbaPMRx^~uY^1elWMh_x|h zmLc4cU3eE=#PuJ?qR77=d-qwn{_EK2nf62a)>1oM4_Pw*6Ph=Ae|u(dM(g|lVZVBF z{hwo^pwc3AxrJ7z7U(Y_KP;@an>^@YWR7Ns;9cnAG{;$Nai7c|_9+$#6pYK}EulaT zY2jFb9L|ElQG%f{p~q1;CE}{BmzZ%5L7mOT7vsE1uFn~S_IQA;(ghm_Tc%~N*J+pV zj|mAozT8f=sxG!stt#Bk@fn0@zdRSL!i{e5IYA9@`TSxzFtLh!yx77q=l^D zrcco>aa(w2%H~na5H8sWFVdB-VS>S28p8_6j$j2qfRiXcGjM^e3fHqE00C%h}cPF}~?%0CM1AnJcQ?R8= z$lQXxZ-u?@E!{1_Bzto~owL5B^c`K>(AFJ`^I^P!8|$U^11P&SD(j-`b6sb>#C~*3 zalY9>Umt51Sr4BZ+Gy0Rk3%Rxm&AK0qrrRX8V6E08KiUmqexcDk>=}Kd$A#!ua{YD zkwF^)4kmRLMM^@a@yd$qYa7d_!y<$LO{boi0->{9d!;gIWQ;o0v+c2hWT>k8Jpac~ zX#@n~y_{uP0S?s+)lyi*?3dz%kx!ZlIEF<$S*hvTvi}Y;L|te+6zvJ%!T}({w3@c> zAvUP(4ntRv0-^09K7rrb?yMlERSvA*>4u^1`wt+v>}*6+<)%90;^YF%?ki^742xcD zoaH{m2?h?yfcqt|HlXA%-e8N;J-%db=6iOCD^92$3n#EuhJ8w}+P7JpGdsBJs91lQ zfvEZ{V_7RyK?tG+ah4I$`L3a-ag_cFYx-nT-37)HOYS|SqVr7CYGx2CVw=rL@lgAC z2aXd$YX8xjuQ2ac210roNQ1=Hf59yc0-TD)1A?^_xy;3(o5k&CEVxmp!6r~LG7)}x zY(NxjJe^}*9659TCc$28V5$}1yKwv)c-}*{P)vI?54Oi1ohie3IaE4&ahNoD0Ix$f ztW7yL2>jY-&UnU~!(#FbIpIa|J&jvj?y^QGn>9j!GL_0H$9{|q%Ok=$ zXi}~9F)bD;r7yCokyAW}6*!+v@s4KBXj!2=CR04D-$ksdv9jjL!OT<-XdTmAP&hL9 z|Lz!3lh>?JbVKHUn8#V)<(e;?iH-l?L!E|IecY#5mE*3-(XE}m2b4!k78QecWN!3G z)|~;48DJiA-iY@Bz=A&GdWfWP(ib?Rrir9Q2eK$>mvIR9brs87MO75{kGpb|tb_>FeoeVT09^3Y$!C!q8OgexFSzP&P$|(l0}?l zboxyOf56}m8GM7m5K-MnmKdXwyZ!vFP4%3iA1<#H4E9OnH9#bTgezTZ!``$oAM-4s zgCjA?iCEd8f{W^65iV&31!l~aoh`($;f#H+4Opd{m=L=*EYfMr<}RMonokaWoqt}? zwU|9Wi=K$hE;3I%k*OOeXE*KG9=G*ZksofWc}uH)P)%9O*F*-ez z6%#g=dY-z;Q&*l+`_gxF$BokCL!q#_Lxq#j}K z(%jK(wuh8dT#5KIq=&4By;~?wYUb5(XD&kIEbb3x?`}oO$aeYhhqHI}*YT>p#6YM_ zNu@>dIg(06!q#Qvn&DSb!_l4Pc#A|GKm_A{<96;oR3I7~eaYa#Nw9&h^Ag*`4%L&P z^Yn@1ut5p{zLy$sWKoIxJ90nrxF?fw+eO_N#Nk+j^*32nYFVU=;Rw&hV;C8{Am z1<>TrSjk9xmc#a0)w?8*8>W)W&XlAhhx;_{%xMNu@ML6!|CMOY91A+t zKt7-wIW{MgvAH$=*bV}AosmI#Aa&??&1EPd1x0*hVJ_(CI*;TBwfZ@kvh`bhIBF12 zkoCqqEU6j0WUp--Am9{udmAgcoxu?X9RD5}O9g{t@F?pfql&f}6GSstqS$vJJ08qM z2=-eFq5T$oV?(zkih3Jfs5pzIpZ{RxbES8VE&lkWx4%Dg_rD%n{Lb8Kt#|+D7sv4M zrg`tP?>zO>V~ej|yY1fMU;pIT;!l4f*Ntsqz)9Sw!Bq4EFP^Oa98~DPU@!zPzQEp7 z>SP78J^5mCG+>bZBBK9@2)|g`v}9c?JHDH?vdqVTJeGUG_CJXdbSgT)0yq{o`R$e? 
zC5llAn0Qj6*i!-q;KtY{zQ;Xh>o*ZEu?DolS?7BuwQKL;x<<{rg%tv3G zmfDS%BjpL*mzs^24c~vjGOoY+Vl{NV=1bL)^8GAFV;ZF`7zGnB6nNx+3bCWOkEpfG z_(YgB8muUOPsQU~w5MGokT+|g`4-W=m+)0)r+u{Nx|g?#;qfR$c^KQBvtik31(T4NKm9@c0lqZuv(MaO0vVzp`C)%R9=G aN9rxlHCywJp!Rt8pbvEnG#!f;Sho;=8x~F=o zt9njV&uyzs7AM46@FsvFEgL{Rpp`IhL_$cgLWn0GctbpJB&3yi*c}N8BqTt7*zovX+!#s+N*rx|WuqQ8Q$isbyrCtz~7HtL0>vujOS}s1;;b ztQBQARvVLHsaBHVcx_yU}!%u$uf^H^`cUp-lp$*p{k1Xvm2ouoc{)**-J~@w%Dt_?;@2!oO8=-+0}$d zp6^&KTV!PLJh#?{(HDI$Y-rTh@X8=Ga3^pFe+?00nMC*tCan0`a+Ct6BhkCf+?x1abx4HIphC8=Ihqv^hp=qJPEAQdII)P9N?gY^5#`;>2PVX~l}LFz*H^F0iMp;#EAMHu z${SdxD;iTU2=&5ErK$;S@ez>>LW_qtwVL7eI`zhq#f2W)>meUQL$|Ewp1S;cVDsRz z6F6=E@^gNtx$*k>`kLc0e=VqYeI8m@Tze)4%l7Q`EdBwFF%Xl-r>9{T=W|pSXT6co#xJWo&$O&_V5%QiLZgg!<6DJa$y~a}% z)8$&!+rHz8q+F~b;o2UbpiIr{iX^S!4Uw|CUE5=F`Wnk38T!1jB(hCTFW>8QH>w$7 zG^}pu_+B6~@;&4S&Cr`9h zZzTvd(0&2A1(pLs{7x^~b`A!3?WJW1lF0fH?zxO|)mDF$O5Vff)(sXhay_!NI~zZKE^hfi6-Jv+&(Hc)~~#Y zRmjFJD!LLL_(wIMPa)Vek8!hI-N~YIO?tFcHIKE;wW`?w7O8w+?7oQl z|1k&ABJ9q{rK)Tx1C`<9{kFEEU8JEVdqA)!G03FE7@$xTAU;8vP<%mv;wFJXv=4Xh38{_-v3$Rnp1)w*?}+3a-fmbtKDE;LOE=O@ht z*KfJ@s_mL9J;x4_Lyd7L3GFc(%E&?pVlb*`+RJblVsXrQ}CpD@jZ zuCrcmc0zNbw|eU2BVDt4P$W9mI;Xkh1Y7(Fgh++97xvD4bJxcF2))yD>rfwz)kUh|TAgm46EccK!wR_>$5u~@w9{GRP|>1t{_AhP zcK*@}^_O3L;idXFUw`J6^OxQdnjZ+gYlTakuvw&;!#%vKFs@jEEn}%*BM9w|Fgn(< zjclk48=qKChseCJ-ms-6sZNQkRD*SDjTluk{stm5TJZc~>O=3?mM0QH$oMgt_PD<$ z4EfcY-2l4H4*7;Kut(~$XqsYF{5Ta)$&VBWwJtL9yVNl#k!U*5crp!=(8UN6jSj<5 zcg81F`_A<@cb|oZo`P2)X2#~#c{QtLNl|Gf$UKx*RxN2V)GBIGE8^DV^V6!K9Y*<* zI;Eb(J*VoLfs{_U?`S{O|BXjzv_Z^q=bA)hq_d7noz+&QZi3c=Zb~rlM%?&8H)TQv zxul!+%gu)RP&ehG+#o5n4`S^U^hKKKqyt$J+8_sQ1I;(oMZ3A+p+j9{kfs`BqRfu| z0yEWa(@kh&q?`7ZCf$@BmEO@!xuI^#WBkyJQ9jCzbQ82EOU){Mkaw1USBdgcE9F^% z^b|A`bRm%cDl}6oEVPS*F)2$@j}&DbnrWm*3WG8m+oMMc?Fs0K5*wF#q!Q(C(<8Dy zmGw2>Z}gq9*bg8rdFMv^6ldTu(LY|lW?O~iU`<8^zX+3-E)NaD8`tDMKN`uHglPmmy<&S4HK zF?0Pzik%_x6%uD5?x`a9IV$oD38G*8MG_M1M~KhAif13u^Yaj!N96upYZ;~J|<-bJYEfNxVYZUus5|>F3 zaPqfFSR}5HAUxy^5~NwUO`=JnMdDXT)JaG*|5b`DK~yKXgFqM(rRomj%aldH$6XQ~ z5`=raOTs6yLV}YZgyR8;O%frA9*I>FYb4f5Y>@aCiQ8vbjNdh<)NAlE#5|yE2goV_ zr8|MlAVv-_)70A`EVT!~vbwJ7pYf|`u3oksNJ#7~O3^2g?nq?I{8t)B>R6pc|&GYb0pL{ifd8(U5SXiCfKIRI7!vV1`4j{|7Nmv#Qh%GIbi zn21V5JY~GH1(=XfQ2|s;Zu}L1qAY=;+|Ns($n?us%PU{QTOyP|fw%3c{BNT?U?KyU zDBt>?-02b}e9zt34<(M9%r`whG->t70^D{%cbEw^Zgbmf*n6PkxC9j|lm!E@1v7LY zV6irR&@IdJI~Mw3c4#*UYfStRG`XOJrc}}UgAb8jK7yPt3^YHa<|NR3`L0b-53J^1 zyM32Ua@r~qT||0F>G-2~1jkT9vJ=VnM0%xXxlXvTO}lgZ1Bf4g|M%$nr>Cn^`@+zE zkYl3Weo#WJ4UPrJ7ul8_)+K@CG4MQyDj-k2js!B1i;ev{d28Twac6c zN{LI205k^Ed%@_+$6!ur$(%6Kdorgz?=fOdGhj~1JH=U=G{i9W~Wb_h?dE;%Ej^Lg;+F5dGlHTfQi-yk7r&>akD_6tML z=Vm|zM4W&xW>kB~&(97z{I{r`ds3e7QtthE&&<6h`5#d?ATD};)H$q?)fVGUY`ek6 z1N!XRZbRxr^z}CT&BTKd_2KZj+?U+wG8f+deTfmCgZjW%59%|t&h}EDcKR|bz<#!$ z>*wLoGhm-3q8x0gLKA+&ej-W?vRg_&;VAtgoixlUgWML`PTw8w@)Ep;d{PJ9#iF$WdU#7RjmcV;b~J7t85 z(ECe3?;ilY4@o+Eh;raLMf&QYm4A*yNoVJoBJJM42#>G>gQIw#`7I6o93(A92S)PH z=z$CSqB)U`*rb)_>mx{#@rLW)pB_e(b8wA1J(Ekwn=C}H0(?WT6+_X9O z`BQj;K_;7SFIcKR{5ez>&a#%Qt}WCT#e~J^=n5WUYqf<#KHq(e!PnRAi6p$7ogmgm zWXpp@2lYmeS+QR<4f~|iga6c&dGM&>K+fKeea8oHa4y1*p%XR!kgP#|O;`l{cPNz# zdd;S@&fiCZUne2;;bDq>hYIL42O|51<@W3s=&dC!7s)QClfI8pq-sivUdzXeAoc=_ zdBKmHs8 zsXi*15@sn_l>T#a8i|A-bYY!s1ae&oz2*8>;P58@y7cwytxcyZ^oEbqWu2^Mp|=Cy z6RED#SaxlZBDeeLQ*v2}^pX|OiW3I6aS#yTQy5nq&k7ojBhqMyHX(l={mR8BG}vRX zesJpA<1Xqg4y%b&oUTpCqu@o)TWt0`X_{LuPL4es4v*6O_t8RaT$WtL$-3_?%7bwn z@X+}<&VX0rbR|YD?lF%Q;q(8Q`pRxP|6i#mQ^}p&{C}rT+?IMAhU#bk1H~uck^lez literal 0 HcmV?d00001 diff --git a/utils/__pycache__/video_op.cpython-39.pyc b/utils/__pycache__/video_op.cpython-39.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..bf39bf1d80d9280ef78aa2a38ef1ae851b1b791a GIT binary patch literal 9445 zcmd5?U2h!Md7iJ`o&8=em*289mMvTBm=rC`it8$}{B608s@SqsZ>4TWduErjoSj*o znWac(#t14qmfI?3(!gjMG*}iX5ap`qb&#w4fTCyvbkK{Sy@(0)qCkPRAPS_`b zWy)|8r07!f&N*{F=FFKh&-1+Rk(tkDB>X7K|`UwhQBnc|K0uT(ScT*AflS@pvp}8u(sDKbo!tjwf;WM`1$g`eY>E7SnWFnv&jEW~8^V zN>>ynVF>c2yHZt&m4(GP;RhxU?kY9i?RDynC6mWmV66qbh;DYi@WM0K-tsN(U$cF? z?Ol7zwV!kCju}|jdV%fu*H&$2d3CS5&|Qzy=USF)t#$c%dcOa_A#FT}OLgk6j#e#e z;hf_&OvgXJFsjSxL8^*42@@x7IA&uRG?wE;;PJ*% zoN02eQ^$04*Q@DRZVv8Tk>p3`RsA_t zlcb+L%TM7X%r;)Ow5=h#P2bKoO;AHXF%tBf#F54E0uKKNSU`ak$f43#7*_h=Yhvmx zB}&nE6KbK-PS7{iPcoGyaNut%5oJVa;y$I>AQh@%620WYTwg~EnNsbvxQ~MSs4R&u zBp1pkMLxoU?PtR5mfX*VnL&PEZG5+=O?7e%ZMPJezWh{0isT@Vl#V;mGwkm{L4L2c zB5EtBt%%xrjK2U($cMR>N?$(u&!aTYviQn#z=*tnT`9~H1|U~#$zi$zOxcvLcca3l zys;1z!~89bXIa^l`PV~I`N)D7DLIJJwRR zKz8omqjg!d9&cCr69L)DP{K@&i}VzwC#5YVDhA{IX;uk~-D|pdF7{UI@G@_G29Dw zHk5GxWQ5W6r^E5VI1Y}B6{E^ktj0}c<@;A9_i0TE<~Eh}@5xd$Axj%a!wD*X2Xhon zygPwAEbisp%;4&xw21j$S!6|k*WAW2BV*u4AB@1CXN>39URZm1?bWrbqCmjNHi6?~+1MuL zxhwPE#%aSi$&Ge(w}|RB>C#fwIN3ICR*eRbO6G^6^P6bz&%(3_pUYw)OIuQ3W_YRJ zR93V%X{d=F@a<^~G9eHM808bh-BoxAO@V%$)&M#Ppa;;AIY*e&@H*Y|SLh;r+ZzkM zxoQpOr_QY=n$=C02`t`+!lQ}A+uCMmi)5#JW+#2ZuQ$2bvHYDB?-=vEXY4-QhRr{k2}xyeDnapmHdufB2d%1iavF2D3j{VQ)h_xi;v z@5GAd$6D76miRGhl43S@@vO0a!}KkYO8V=5V0B`>V=h}L2Ax=V!D2e%^h;|EOR!9J zGR_E!tW$5Kn3{uk5T{4Wou8z8XdTORW7Q8BuZq0Oy_>Nv-g>j^gZHd}ug5yRkh-Xv zq7+r1r|L=ZlKfb%$7%6i>KIh4Hf^w>$U_EnF#@&GVHoP}_+r_*clFJF#zMoL{MTTl z;$4-enM~(Qtc#I zB*iphMwS3aWWhn;!y!ZMl|o?-8A>Pq>Zil>E~`OvDhC(~eT*1ue{Eu{%&7J*V`Ybo zmBaYKmtii7vI_V@u*wWrJUKHP&ks1}O@M@P!wZ&-V96rJ5g#lw1n(FC zRB_#QyS>0@crLSrR9qO)&{?cZ{uCLwz7jt}=C{ZYEAew=zC`9cnP*{sRodXsQ4b{L z_{(GjHH`Ryzlv*{bX^d`3Guz!r&mXSaws40*Qv#qVXp8?l)OxaK4$(EGH;L}N~_Ka z0OoJe1wvo`HW>n6ewEDckr6=s4y9^jexD2hEx$(Q56PHh2tWA^G7U0BN8BRQBy*k2 zACM8~U#CzIvJl#K&D4# zmCQ{tYh*Ua9Gu~J>{D|}H}qeGk&63)uOig^0pJ7N6##D~Svd%KL%=%;aO4sCY3S73 z@JF7^2=u;>^(0_T4axoeH0bG~RJ%`U8epaoQ1uhSO(lVyWZDR>5<+0709PdCGhv2l z0GkX;F&$u&5nz)E0AUcwS(XM2r2#rg0-ftSQGVmcVU{V&3jY^~+l&ymX-IUmq<$c} z(*TH(9RUQ~?>2yjWrxz(A(%od$a0e8ceJPg;f;IZF0Vr57Nb&Fq&rPZLWn;j#BGTc zmlFV{(z1L302e}T&q>i3fN5-5gIs@R<$uCLe>^OabQS5!7R0K67J21cSEZGYAy^?x z*S`U1DFa%LgykqFz$6`&-z~3vyg!Yn$29_$m5%|+X~2ri`(X~Sk_N1lcm7-;gFqF} zb=D6cM>9J3JM0MY6?vEMlsa1ZFnh_es zsi9=#r;$ToNfqixsL43B(lZ@9Sl^}(dHX|{pZ~>o==V?0RwoaIt-~PdL9{yzdPo%_ zs9}3?rey_nArSchsE(rv#96N+fG^Gt0u6PF_8^Aic!;S8+!`0S^~<3Z0acu)aE-YA z9%jYK;Sff!MbQtu^+Pdh^6=^YYxM3DUr7Z5u3dcwkOEMVI#X~ag;oY+733p;DN=F+ zb@niU)i==g09Z}n!U(J`>?*YXcVGpN{@YO=9{t-x!20CZ0jy2gMV#Rtt(Ch+Ye9QK zYvsej$crbfRbT~*9KwsQ3<2);ApqA6NMjMS5x|X&0B&jsaAg5-UGkd;<3PWp;(_3tgQNLU#GDKy#$NJSbzit_;= zRrt9PkWTS;v8MPv0_*<_v;K{4F2vO1;v))9uP)gmIrVu^>r?9L7Ipa^nfJ*Ev2~Bm zn)$T>>vQWYf-ioXjL>cSB-SqvRs2t=pZf}{J5>7F>T3Fall+fp>tBF59IMD*B&>Q` z8TgWfR)wKQNj~U4%=!p@9Pla)To_^2x7DG*x}Sv7hF<02@(a5T5Lnn7fWR91jQa&v zJ9Q0yX_SewQ4Y~Z9bTgvX5r!Gn}~5nYN+-zTT-OjQdFRwjTxz*-5Pjw4-avngcv8q zQ%Pq19mbLmOY{d1ap13(M6acNy`ps3CMo37^Za?IC@* ze?00XH56Sm`UCfIo1;3qsM81H4DO+^)ZiX0Jq2kfaBH^=;&Mm3D>T z(`ym(EDzB;K{yPFJ0T?QL@*XkK#*dN;VZroN4XrpyYn}gMQMbl z6wRGP`|?oG{sVTWCaLf1v(Z#Ah4|?dVg=LTWPc`{COL|TX$>-b+`$%OG=o_8Oo$kC z`$#k!&I0vm6rThSw2$J>I;-4KK3a_C!b#EJk>FV02#=t@2g14jgW&_i{(g8Bqtozr zi=xftCvXq{pBEu`>F*sJ?s@b*fxcCKG@PX8X&(m*P9lvpLEA`diYcK!C$VX~6OGN> zNure=Jj7;i$&k5HbW~uzx&r4<_~j38s_PEK^D*I!;_A>Z z|4De1Nf6K~-waN$x&9+~&ZAoj?tg^1QEX&AaOqGEAYXdmZVT7?31kVv@cZ{?hY9Hd zLb*1Da`%ONDT#}_!CVUE+QLu&4eH;gz*++?c}>8ah%D3*mJIB!W7YS93U&Ad^_JbN zFL~VF!2Xlz)Stw5;9=f7{~bK*o<_!lXqORBtjhdrDA>>z&`(W$<>Jd%>`#Dg)nu$R 
zyK$0S1~aJThiY{gA{@jNO>D*KtS9uO8*7f|bvI7F<@&vD*W=`!V6zU+2SU{g2DS_h z3LzUFH|SOi8^`I{hEbEiGrk%bO%wIm18!kC}rQZwg*gr-xWD34F0B zfQ?7%Fr`1(U__u4-W2T|@`ptWiqXRR;NPTN((g5!_8R{>GW-KFf<+&q6nP%l&!vcS z=55pISufF33l|_xbUE#4{u!!N4dD*da)U)Mh-k;T4kF`Defg%zTYfF;dp+K;p#59^ zzSu^GZZnE)#PW?;Zp5-1%S*9*vo^863)z{k zLI$}U;?blFD7T1|4qrz>S{yriTF>o9p-GXT#Q#lPj8)I?!uwnI z#afEBmgC((z@GL8#9oEg+OWH^*6^@vu93eUYi-|i<7C%vEIU@5q-}$-^x<&|HHifmEAaQ=V_`?R#~rj=7&eQOgM4jVY;!NT z?n1NYilBh$U{~D5#`h@Czk?oXW1{9N_WnJ0L2USAdy4k`vA?}KNK^($#l42*`?lL! zpzn_ZQ?TCYzkS3>%XYo4KipNheDQVhEs1@4v5P<0rAfJ7y~Rx?!uulJCEQ3(x=c*; zzKENMy@@?@{T$YH&#}&9El55^?8p%%$@dw5c=1%H3`>6%RWBuzNjWKAHeUFD=DxHq X|6jSMlZjF3f9Fme$UT)*%D?zu5TO27 literal 0 HcmV?d00001 diff --git a/utils/assign_cfg.py b/utils/assign_cfg.py new file mode 100644 index 0000000..24ce8ef --- /dev/null +++ b/utils/assign_cfg.py @@ -0,0 +1,78 @@ +import os, yaml +from copy import deepcopy, copy + + +# def get prior and ldm config +def assign_prior_mudule_cfg(cfg): + ''' + ''' + # + prior_cfg = deepcopy(cfg) + vldm_cfg = deepcopy(cfg) + + with open(cfg.prior_cfg, 'r') as f: + _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) + # _cfg_update = _cfg_update.cfg_dict + for k, v in _cfg_update.items(): + if isinstance(v, dict) and k in cfg: + prior_cfg[k].update(v) + else: + prior_cfg[k] = v + + with open(cfg.vldm_cfg, 'r') as f: + _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) + # _cfg_update = _cfg_update.cfg_dict + for k, v in _cfg_update.items(): + if isinstance(v, dict) and k in cfg: + vldm_cfg[k].update(v) + else: + vldm_cfg[k] = v + + return prior_cfg, vldm_cfg + + +# def get prior and ldm config +def assign_vldm_vsr_mudule_cfg(cfg): + ''' + ''' + # + vldm_cfg = deepcopy(cfg) + vsr_cfg = deepcopy(cfg) + + with open(cfg.vldm_cfg, 'r') as f: + _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) + # _cfg_update = _cfg_update.cfg_dict + for k, v in _cfg_update.items(): + if isinstance(v, dict) and k in cfg: + vldm_cfg[k].update(v) + else: + vldm_cfg[k] = v + + with open(cfg.vsr_cfg, 'r') as f: + _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) + # _cfg_update = _cfg_update.cfg_dict + for k, v in _cfg_update.items(): + if isinstance(v, dict) and k in cfg: + vsr_cfg[k].update(v) + else: + vsr_cfg[k] = v + + return vldm_cfg, vsr_cfg + + +# def get prior and ldm config +def assign_signle_cfg(cfg, _cfg_update, tname): + ''' + ''' + # + vldm_cfg = deepcopy(cfg) + if os.path.exists(_cfg_update[tname]): + with open(_cfg_update[tname], 'r') as f: + _cfg_update = yaml.load(f.read(), Loader=yaml.SafeLoader) + # _cfg_update = _cfg_update.cfg_dict + for k, v in _cfg_update.items(): + if isinstance(v, dict) and k in cfg: + vldm_cfg[k].update(v) + else: + vldm_cfg[k] = v + return vldm_cfg \ No newline at end of file diff --git a/utils/config.py b/utils/config.py new file mode 100644 index 0000000..1a077b7 --- /dev/null +++ b/utils/config.py @@ -0,0 +1,243 @@ +import os +import yaml +import json +import copy +import argparse + +from ..utils import logging +# logger = logging.get_logger(__name__) + +class Config(object): + def __init__(self, load=True, cfg_dict=None, cfg_level=None): + self._level = "cfg" + ("." 
+ cfg_level if cfg_level is not None else "") + + current_directory = os.path.dirname(os.path.abspath(__file__)) + parent_directory = os.path.dirname(current_directory) + self.config_file_loc = os.path.join(parent_directory, 'configs/UniAnimate_infer.yaml') + + if load: + self.args = self._parse_args() + # logger.info("Loading config from {}.".format(self.args.cfg_file)) + self.need_initialization = True + cfg_base = self._load_yaml(self.args) # self._initialize_cfg() + cfg_dict = self._load_yaml(self.args) + cfg_dict = self._merge_cfg_from_base(cfg_base, cfg_dict) + cfg_dict = self._update_from_args(cfg_dict) + self.cfg_dict = cfg_dict + self._update_dict(cfg_dict) + + def _parse_args(self): + parser = argparse.ArgumentParser( + description="Argparser for configuring the codebase" + ) + parser.add_argument( + "--cfg", + dest="cfg_file", + help="Path to the configuration file", + default= self.config_file_loc + ) + parser.add_argument( + "--init_method", + help="Initialization method, includes TCP or shared file-system", + default="tcp://localhost:9999", + type=str, + ) + parser.add_argument( + '--debug', + action='store_true', + default=False, + help='Output debug information' + ) + parser.add_argument( + '--windows-standalone-build', + action='store_true', + default=False, + help='Indicates if the build is a standalone build for Windows' + ) + parser.add_argument( + "opts", + help="Other configurations", + default=None, + nargs=argparse.REMAINDER + ) + return parser.parse_args() + + + def _path_join(self, path_list): + path = "" + for p in path_list: + path+= p + '/' + return path[:-1] + + def _update_from_args(self, cfg_dict): + args = self.args + for var in vars(args): + cfg_dict[var] = getattr(args, var) + return cfg_dict + + def _initialize_cfg(self): + if self.need_initialization: + self.need_initialization = False + if os.path.exists('./configs/base.yaml'): + with open("./configs/base.yaml", 'r') as f: + cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) + else: + with open(os.path.realpath(__file__).split('/')[-3] + "/configs/base.yaml", 'r') as f: + cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) + return cfg + + def _load_yaml(self, args, file_name=""): + assert args.cfg_file is not None + if not file_name == "": # reading from base file + with open(file_name, 'r') as f: + cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) + else: + if os.getcwd().split("/")[-1] == args.cfg_file.split("/")[0]: + args.cfg_file = args.cfg_file.replace(os.getcwd().split("/")[-1], "./") + with open(args.cfg_file, 'r') as f: + cfg = yaml.load(f.read(), Loader=yaml.SafeLoader) + file_name = args.cfg_file + + if "_BASE_RUN" not in cfg.keys() and "_BASE_MODEL" not in cfg.keys() and "_BASE" not in cfg.keys(): + # return cfg if the base file is being accessed + cfg = self._merge_cfg_from_command_update(args, cfg) + return cfg + + if "_BASE" in cfg.keys(): + if cfg["_BASE"][1] == '.': + prev_count = cfg["_BASE"].count('..') + cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE"].count('..'))] + cfg["_BASE"].split('/')[prev_count:]) + else: + cfg_base_file = cfg["_BASE"].replace( + "./", + args.cfg_file.replace(args.cfg_file.split('/')[-1], "") + ) + cfg_base = self._load_yaml(args, cfg_base_file) + cfg = self._merge_cfg_from_base(cfg_base, cfg) + else: + if "_BASE_RUN" in cfg.keys(): + if cfg["_BASE_RUN"][1] == '.': + prev_count = cfg["_BASE_RUN"].count('..') + cfg_base_file = self._path_join(file_name.split('/')[:(-1-prev_count)] + cfg["_BASE_RUN"].split('/')[prev_count:]) + else: + 
cfg_base_file = cfg["_BASE_RUN"].replace( + "./", + args.cfg_file.replace(args.cfg_file.split('/')[-1], "") + ) + cfg_base = self._load_yaml(args, cfg_base_file) + cfg = self._merge_cfg_from_base(cfg_base, cfg, preserve_base=True) + if "_BASE_MODEL" in cfg.keys(): + if cfg["_BASE_MODEL"][1] == '.': + prev_count = cfg["_BASE_MODEL"].count('..') + cfg_base_file = self._path_join(file_name.split('/')[:(-1-cfg["_BASE_MODEL"].count('..'))] + cfg["_BASE_MODEL"].split('/')[prev_count:]) + else: + cfg_base_file = cfg["_BASE_MODEL"].replace( + "./", + args.cfg_file.replace(args.cfg_file.split('/')[-1], "") + ) + cfg_base = self._load_yaml(args, cfg_base_file) + cfg = self._merge_cfg_from_base(cfg_base, cfg) + cfg = self._merge_cfg_from_command(args, cfg) + return cfg + + def _merge_cfg_from_base(self, cfg_base, cfg_new, preserve_base=False): + for k,v in cfg_new.items(): + if k in cfg_base.keys(): + if isinstance(v, dict): + self._merge_cfg_from_base(cfg_base[k], v) + else: + cfg_base[k] = v + else: + if "BASE" not in k or preserve_base: + cfg_base[k] = v + return cfg_base + + def _merge_cfg_from_command_update(self, args, cfg): + if len(args.opts) == 0: + return cfg + + assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format( + args.opts, len(args.opts) + ) + keys = args.opts[0::2] + vals = args.opts[1::2] + + for key, val in zip(keys, vals): + cfg[key] = val + + return cfg + + def _merge_cfg_from_command(self, args, cfg): + assert len(args.opts) % 2 == 0, 'Override list {} has odd length: {}.'.format( + args.opts, len(args.opts) + ) + keys = args.opts[0::2] + vals = args.opts[1::2] + + # maximum supported depth 3 + for idx, key in enumerate(keys): + key_split = key.split('.') + assert len(key_split) <= 4, 'Key depth error. \nMaximum depth: 3\n Get depth: {}'.format( + len(key_split) + ) + assert key_split[0] in cfg.keys(), 'Non-existant key: {}.'.format( + key_split[0] + ) + if len(key_split) == 2: + assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format( + key + ) + elif len(key_split) == 3: + assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format( + key + ) + assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format( + key + ) + elif len(key_split) == 4: + assert key_split[1] in cfg[key_split[0]].keys(), 'Non-existant key: {}.'.format( + key + ) + assert key_split[2] in cfg[key_split[0]][key_split[1]].keys(), 'Non-existant key: {}.'.format( + key + ) + assert key_split[3] in cfg[key_split[0]][key_split[1]][key_split[2]].keys(), 'Non-existant key: {}.'.format( + key + ) + if len(key_split) == 1: + cfg[key_split[0]] = vals[idx] + elif len(key_split) == 2: + cfg[key_split[0]][key_split[1]] = vals[idx] + elif len(key_split) == 3: + cfg[key_split[0]][key_split[1]][key_split[2]] = vals[idx] + elif len(key_split) == 4: + cfg[key_split[0]][key_split[1]][key_split[2]][key_split[3]] = vals[idx] + return cfg + + def _update_dict(self, cfg_dict): + def recur(key, elem): + if type(elem) is dict: + return key, Config(load=False, cfg_dict=elem, cfg_level=key) + else: + if type(elem) is str and elem[1:3]=="e-": + elem = float(elem) + return key, elem + dic = dict(recur(k, v) for k, v in cfg_dict.items()) + self.__dict__.update(dic) + + def get_args(self): + return self.args + + def __repr__(self): + return "{}\n".format(self.dump()) + + def dump(self): + return json.dumps(self.cfg_dict, indent=2) + + def deep_copy(self): + return copy.deepcopy(self) + +# if __name__ == '__main__': +# # debug +# cfg = 
Config(load=True) +# print(cfg.DATA) \ No newline at end of file diff --git a/utils/distributed.py b/utils/distributed.py new file mode 100644 index 0000000..c5765f6 --- /dev/null +++ b/utils/distributed.py @@ -0,0 +1,430 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +import torch +import torch.nn.functional as F +import torch.distributed as dist +import functools +import pickle +import numpy as np +from collections import OrderedDict +from torch.autograd import Function + +__all__ = ['is_dist_initialized', + 'get_world_size', + 'get_rank', + 'new_group', + 'destroy_process_group', + 'barrier', + 'broadcast', + 'all_reduce', + 'reduce', + 'gather', + 'all_gather', + 'reduce_dict', + 'get_global_gloo_group', + 'generalized_all_gather', + 'generalized_gather', + 'scatter', + 'reduce_scatter', + 'send', + 'recv', + 'isend', + 'irecv', + 'shared_random_seed', + 'diff_all_gather', + 'diff_all_reduce', + 'diff_scatter', + 'diff_copy', + 'spherical_kmeans', + 'sinkhorn'] + +#-------------------------------- Distributed operations --------------------------------# + +def is_dist_initialized(): + return dist.is_available() and dist.is_initialized() + +def get_world_size(group=None): + return dist.get_world_size(group) if is_dist_initialized() else 1 + +def get_rank(group=None): + return dist.get_rank(group) if is_dist_initialized() else 0 + +def new_group(ranks=None, **kwargs): + if is_dist_initialized(): + return dist.new_group(ranks, **kwargs) + return None + +def destroy_process_group(): + if is_dist_initialized(): + dist.destroy_process_group() + +def barrier(group=None, **kwargs): + if get_world_size(group) > 1: + dist.barrier(group, **kwargs) + +def broadcast(tensor, src, group=None, **kwargs): + if get_world_size(group) > 1: + return dist.broadcast(tensor, src, group, **kwargs) + +def all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, **kwargs): + if get_world_size(group) > 1: + return dist.all_reduce(tensor, op, group, **kwargs) + +def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=None, **kwargs): + if get_world_size(group) > 1: + return dist.reduce(tensor, dst, op, group, **kwargs) + +def gather(tensor, dst=0, group=None, **kwargs): + rank = get_rank() # global rank + world_size = get_world_size(group) + if world_size == 1: + return [tensor] + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] if rank == dst else None + dist.gather(tensor, tensor_list, dst, group, **kwargs) + return tensor_list + +def all_gather(tensor, uniform_size=True, group=None, **kwargs): + world_size = get_world_size(group) + if world_size == 1: + return [tensor] + assert tensor.is_contiguous(), 'ops.all_gather requires the tensor to be contiguous()' + + if uniform_size: + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_list, tensor, group, **kwargs) + return tensor_list + else: + # collect tensor shapes across GPUs + shape = tuple(tensor.shape) + shape_list = generalized_all_gather(shape, group) + + # flatten the tensor + tensor = tensor.reshape(-1) + size = int(np.prod(shape)) + size_list = [int(np.prod(u)) for u in shape_list] + max_size = max(size_list) + + # pad to maximum size + if size != max_size: + padding = tensor.new_zeros(max_size - size) + tensor = torch.cat([tensor, padding], dim=0) + + # all_gather + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_list, tensor, group, **kwargs) + + # reshape tensors + tensor_list = [t[:n].view(s) for t, 
n, s in zip(
+            tensor_list, size_list, shape_list)]
+        return tensor_list
+
+@torch.no_grad()
+def reduce_dict(input_dict, group=None, reduction='mean', **kwargs):
+    assert reduction in ['mean', 'sum']
+    world_size = get_world_size(group)
+    if world_size == 1:
+        return input_dict
+
+    # ensure that the orders of keys are consistent across processes
+    if isinstance(input_dict, OrderedDict):
+        keys = list(input_dict.keys())
+    else:
+        keys = sorted(input_dict.keys())
+    vals = [input_dict[key] for key in keys]
+    vals = torch.stack(vals, dim=0)
+    dist.reduce(vals, dst=0, group=group, **kwargs)
+    if dist.get_rank(group) == 0 and reduction == 'mean':
+        vals /= world_size
+    dist.broadcast(vals, src=0, group=group, **kwargs)
+    reduced_dict = type(input_dict)([
+        (key, val) for key, val in zip(keys, vals)])
+    return reduced_dict
+
+@functools.lru_cache()
+def get_global_gloo_group():
+    backend = dist.get_backend()
+    assert backend in ['gloo', 'nccl']
+    if backend == 'nccl':
+        return dist.new_group(backend='gloo')
+    else:
+        return dist.group.WORLD
+
+def _serialize_to_tensor(data, group):
+    backend = dist.get_backend(group)
+    assert backend in ['gloo', 'nccl']
+    device = torch.device('cpu' if backend == 'gloo' else 'cuda')
+
+    buffer = pickle.dumps(data)
+    if len(buffer) > 1024 ** 3:
+        import logging  # local import: this module does not import logging at the top
+        logger = logging.getLogger(__name__)
+        logger.warning(
+            'Rank {} trying to all-gather {:.2f} GB of data on device '
+            '{}'.format(get_rank(), len(buffer) / (1024 ** 3), device))
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to(device=device)
+    return tensor
+
+def _pad_to_largest_tensor(tensor, group):
+    world_size = dist.get_world_size(group=group)
+    assert world_size >= 1, \
+        'gather/all_gather must be called from ranks within ' \
+        'the given group!'
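+    # Each rank first reports the byte length of its serialized payload; the
+    # sizes are gathered so every tensor can be zero-padded to the common
+    # maximum before the fixed-size collective below.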
+ local_size = torch.tensor( + [tensor.numel()], dtype=torch.int64, device=tensor.device) + size_list = [torch.zeros( + [1], dtype=torch.int64, device=tensor.device) + for _ in range(world_size)] + + # gather tensors and compute the maximum size + dist.all_gather(size_list, local_size, group=group) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # pad tensors to the same size + if local_size != max_size: + padding = torch.zeros( + (max_size - local_size, ), + dtype=torch.uint8, device=tensor.device) + tensor = torch.cat((tensor, padding), dim=0) + return size_list, tensor + +def generalized_all_gather(data, group=None): + if get_world_size(group) == 1: + return [data] + if group is None: + group = get_global_gloo_group() + + tensor = _serialize_to_tensor(data, group) + size_list, tensor = _pad_to_largest_tensor(tensor, group) + max_size = max(size_list) + + # receiving tensors from all ranks + tensor_list = [torch.empty( + (max_size, ), dtype=torch.uint8, device=tensor.device) + for _ in size_list] + dist.all_gather(tensor_list, tensor, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + return data_list + +def generalized_gather(data, dst=0, group=None): + world_size = get_world_size(group) + if world_size == 1: + return [data] + if group is None: + group = get_global_gloo_group() + rank = dist.get_rank() # global rank + + tensor = _serialize_to_tensor(data, group) + size_list, tensor = _pad_to_largest_tensor(tensor, group) + + # receiving tensors from all ranks to dst + if rank == dst: + max_size = max(size_list) + tensor_list = [torch.empty( + (max_size, ), dtype=torch.uint8, device=tensor.device) + for _ in size_list] + dist.gather(tensor, tensor_list, dst=dst, group=group) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + return data_list + else: + dist.gather(tensor, [], dst=dst, group=group) + return [] + +def scatter(data, scatter_list=None, src=0, group=None, **kwargs): + r"""NOTE: only supports CPU tensor communication. 
+    """
+    if get_world_size(group) > 1:
+        return dist.scatter(data, scatter_list, src, group, **kwargs)
+
+def reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=None, **kwargs):
+    if get_world_size(group) > 1:
+        return dist.reduce_scatter(output, input_list, op, group, **kwargs)
+
+def send(tensor, dst, group=None, **kwargs):
+    if get_world_size(group) > 1:
+        assert tensor.is_contiguous(), 'ops.send requires the tensor to be contiguous()'
+        return dist.send(tensor, dst, group, **kwargs)
+
+def recv(tensor, src=None, group=None, **kwargs):
+    if get_world_size(group) > 1:
+        assert tensor.is_contiguous(), 'ops.recv requires the tensor to be contiguous()'
+        return dist.recv(tensor, src, group, **kwargs)
+
+def isend(tensor, dst, group=None, **kwargs):
+    if get_world_size(group) > 1:
+        assert tensor.is_contiguous(), 'ops.isend requires the tensor to be contiguous()'
+        return dist.isend(tensor, dst, group, **kwargs)
+
+def irecv(tensor, src=None, group=None, **kwargs):
+    if get_world_size(group) > 1:
+        assert tensor.is_contiguous(), 'ops.irecv requires the tensor to be contiguous()'
+        return dist.irecv(tensor, src, group, **kwargs)
+
+def shared_random_seed(group=None):
+    seed = np.random.randint(2 ** 31)
+    all_seeds = generalized_all_gather(seed, group)
+    return all_seeds[0]
+
+#-------------------------------- Differentiable operations --------------------------------#
+
+def _all_gather(x):
+    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
+        return x
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    tensors = [torch.empty_like(x) for _ in range(world_size)]
+    tensors[rank] = x
+    dist.all_gather(tensors, x)
+    return torch.cat(tensors, dim=0).contiguous()
+
+def _all_reduce(x):
+    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
+        return x
+    dist.all_reduce(x)
+    return x
+
+def _split(x):
+    if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() == 1:
+        return x
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+    return x.chunk(world_size, dim=0)[rank].contiguous()
+
+class DiffAllGather(Function):
+    r"""Differentiable all-gather.
+    """
+    @staticmethod
+    def symbolic(graph, input):
+        return _all_gather(input)
+
+    @staticmethod
+    def forward(ctx, input):
+        return _all_gather(input)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _split(grad_output)
+
+class DiffAllReduce(Function):
+    r"""Differentiable all-reduce.
+    """
+    @staticmethod
+    def symbolic(graph, input):
+        return _all_reduce(input)
+
+    @staticmethod
+    def forward(ctx, input):
+        return _all_reduce(input)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output
+
+class DiffScatter(Function):
+    r"""Differentiable scatter.
+    """
+    @staticmethod
+    def symbolic(graph, input):
+        return _split(input)
+
+    @staticmethod
+    def forward(ctx, input):
+        return _split(input)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _all_gather(grad_output)
+
+class DiffCopy(Function):
+    r"""Differentiable copy that reduces all gradients during backward.
+ """ + @staticmethod + def symbolic(graph, input): + return input + + @staticmethod + def forward(ctx, input): + return input + + @staticmethod + def backward(ctx, grad_output): + return _all_reduce(grad_output) + +diff_all_gather = DiffAllGather.apply +diff_all_reduce = DiffAllReduce.apply +diff_scatter = DiffScatter.apply +diff_copy = DiffCopy.apply + +#-------------------------------- Distributed algorithms --------------------------------# + +@torch.no_grad() +def spherical_kmeans(feats, num_clusters, num_iters=10): + k, n, c = num_clusters, *feats.size() + ones = feats.new_ones(n, dtype=torch.long) + + # distributed settings + rank = get_rank() + world_size = get_world_size() + + # init clusters + rand_inds = torch.randperm(n)[:int(np.ceil(k / world_size))] + clusters = torch.cat(all_gather(feats[rand_inds]), dim=0)[:k] + + # variables + new_clusters = feats.new_zeros(k, c) + counts = feats.new_zeros(k, dtype=torch.long) + + # iterative Expectation-Maximization + for step in range(num_iters + 1): + # Expectation step + simmat = torch.mm(feats, clusters.t()) + scores, assigns = simmat.max(dim=1) + if step == num_iters: + break + + # Maximization step + new_clusters.zero_().scatter_add_(0, assigns.unsqueeze(1).repeat(1, c), feats) + all_reduce(new_clusters) + + counts.zero_() + counts.index_add_(0, assigns, ones) + all_reduce(counts) + + mask = (counts > 0) + clusters[mask] = new_clusters[mask] / counts[mask].view(-1, 1) + clusters = F.normalize(clusters, p=2, dim=1) + return clusters, assigns, scores + +@torch.no_grad() +def sinkhorn(Q, eps=0.5, num_iters=3): + # normalize Q + Q = torch.exp(Q / eps).t() + sum_Q = Q.sum() + all_reduce(sum_Q) + Q /= sum_Q + + # variables + n, m = Q.size() + u = Q.new_zeros(n) + r = Q.new_ones(n) / n + c = Q.new_ones(m) / (m * get_world_size()) + + # iterative update + cur_sum = Q.sum(dim=1) + all_reduce(cur_sum) + for i in range(num_iters): + u = cur_sum + Q *= (r / u).unsqueeze(1) + Q *= (c / Q.sum(dim=0)).unsqueeze(0) + cur_sum = Q.sum(dim=1) + all_reduce(cur_sum) + return (Q / Q.sum(dim=0, keepdim=True)).t().float() diff --git a/utils/logging.py b/utils/logging.py new file mode 100644 index 0000000..30b563d --- /dev/null +++ b/utils/logging.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. + +"""Logging.""" + +import builtins +import decimal +import functools +import logging +import os +import sys +from ..lib import simplejson +# from fvcore.common.file_io import PathManager + +from ..utils import distributed as du + + +def _suppress_print(): + """ + Suppresses printing from the current process. + """ + + def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False): + pass + + builtins.print = print_pass + + +# @functools.lru_cache(maxsize=None) +# def _cached_log_stream(filename): +# return PathManager.open(filename, "a") + + +def setup_logging(cfg, log_file): + """ + Sets up the logging for multiple processes. Only enable the logging for the + master process, and suppress logging for the non-master processes. + """ + if du.is_master_proc(): + # Enable logging for the master process. + logging.root.handlers = [] + else: + # Suppress logging for non-master processes. 
+ _suppress_print() + + logger = logging.getLogger() + logger.setLevel(logging.INFO) + logger.propagate = False + plain_formatter = logging.Formatter( + "[%(asctime)s][%(levelname)s] %(name)s: %(lineno)4d: %(message)s", + datefmt="%m/%d %H:%M:%S", + ) + + if du.is_master_proc(): + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + ch.setFormatter(plain_formatter) + logger.addHandler(ch) + + if log_file is not None and du.is_master_proc(du.get_world_size()): + filename = os.path.join(cfg.OUTPUT_DIR, log_file) + fh = logging.FileHandler(filename) + fh.setLevel(logging.DEBUG) + fh.setFormatter(plain_formatter) + logger.addHandler(fh) + + +def get_logger(name): + """ + Retrieve the logger with the specified name or, if name is None, return a + logger which is the root logger of the hierarchy. + Args: + name (string): name of the logger. + """ + return logging.getLogger(name) + + +def log_json_stats(stats): + """ + Logs json stats. + Args: + stats (dict): a dictionary of statistical information to log. + """ + stats = { + k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v + for k, v in stats.items() + } + json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) + logger = get_logger(__name__) + logger.info("{:s}".format(json_stats)) diff --git a/utils/mp4_to_gif.py b/utils/mp4_to_gif.py new file mode 100644 index 0000000..3a53df9 --- /dev/null +++ b/utils/mp4_to_gif.py @@ -0,0 +1,16 @@ +import os + + + +# source_mp4_dir = "outputs/UniAnimate_infer" +# target_gif_dir = "outputs/UniAnimate_infer_gif" + +source_mp4_dir = "outputs/UniAnimate_infer_long" +target_gif_dir = "outputs/UniAnimate_infer_long_gif" + +os.makedirs(target_gif_dir, exist_ok=True) +for video in os.listdir(source_mp4_dir): + video_dir = os.path.join(source_mp4_dir, video) + gif_dir = os.path.join(target_gif_dir, video.replace(".mp4", ".gif")) + cmd = f'ffmpeg -i {video_dir} {gif_dir}' + os.system(cmd) \ No newline at end of file diff --git a/utils/multi_port.py b/utils/multi_port.py new file mode 100644 index 0000000..a5fcbb7 --- /dev/null +++ b/utils/multi_port.py @@ -0,0 +1,9 @@ +import socket +from contextlib import closing + +def find_free_port(): + """ https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number """ + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(('', 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return str(s.getsockname()[1]) \ No newline at end of file diff --git a/utils/optim/__init__.py b/utils/optim/__init__.py new file mode 100644 index 0000000..0c67cda --- /dev/null +++ b/utils/optim/__init__.py @@ -0,0 +1,2 @@ +from .lr_scheduler import * +from .adafactor import * diff --git a/utils/optim/adafactor.py b/utils/optim/adafactor.py new file mode 100644 index 0000000..d38d9ac --- /dev/null +++ b/utils/optim/adafactor.py @@ -0,0 +1,230 @@ +import math +import torch +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + +__all__ = ['Adafactor'] + +class Adafactor(Optimizer): + """ + AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code: + https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py + Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that + this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and + `warmup_init` options. 
To use a manual (external) learning rate schedule you should set `scale_parameter=False` and + `relative_step=False`. + Arguments: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. + lr (`float`, *optional*): + The external learning rate. + eps (`Tuple[float, float]`, *optional*, defaults to (1e-30, 1e-3)): + Regularization constants for square gradient and parameter scale respectively + clip_threshold (`float`, *optional*, defaults 1.0): + Threshold of root mean square of final gradient update + decay_rate (`float`, *optional*, defaults to -0.8): + Coefficient used to compute running averages of square + beta1 (`float`, *optional*): + Coefficient used for computing running averages of gradient + weight_decay (`float`, *optional*, defaults to 0): + Weight decay (L2 penalty) + scale_parameter (`bool`, *optional*, defaults to `True`): + If True, learning rate is scaled by root mean square + relative_step (`bool`, *optional*, defaults to `True`): + If True, time-dependent learning rate is computed instead of external learning rate + warmup_init (`bool`, *optional*, defaults to `False`): + Time-dependent learning rate computation depends on whether warm-up initialization is being used + This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested. + Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3): + - Training without LR warmup or clip_threshold is not recommended. + - use scheduled LR warm-up to fixed LR + - use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235) + - Disable relative updates + - Use scale_parameter=False + - Additional optimizer operations like gradient clipping should not be used alongside Adafactor + Example: + ```python + Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3) + ``` + Others reported the following combination to work well: + ```python + Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) + ``` + When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`] + scheduler as following: + ```python + from transformers.optimization import Adafactor, AdafactorSchedule + optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None) + lr_scheduler = AdafactorSchedule(optimizer) + trainer = Trainer(..., optimizers=(optimizer, lr_scheduler)) + ``` + Usage: + ```python + # replace AdamW with Adafactor + optimizer = Adafactor( + model.parameters(), + lr=1e-3, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + relative_step=False, + scale_parameter=False, + warmup_init=False, + ) + ```""" + + def __init__( + self, + params, + lr=None, + eps=(1e-30, 1e-3), + clip_threshold=1.0, + decay_rate=-0.8, + beta1=None, + weight_decay=0.0, + scale_parameter=True, + relative_step=True, + warmup_init=False, + ): + r"""require_version("torch>=1.5.0") # add_ with alpha + """ + if lr is not None and relative_step: + raise ValueError("Cannot combine manual `lr` and `relative_step=True` options") + if warmup_init and not relative_step: + raise ValueError("`warmup_init=True` requires `relative_step=True`") + + defaults = dict( + lr=lr, + eps=eps, + clip_threshold=clip_threshold, + decay_rate=decay_rate, + beta1=beta1, + weight_decay=weight_decay, + scale_parameter=scale_parameter, + 
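+            # relative_step / warmup_init switch Adafactor to its internal
+            # time-dependent learning-rate schedule (see _get_lr below).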
relative_step=relative_step, + warmup_init=warmup_init, + ) + super().__init__(params, defaults) + + @staticmethod + def _get_lr(param_group, param_state): + rel_step_sz = param_group["lr"] + if param_group["relative_step"]: + min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2 + rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"])) + param_scale = 1.0 + if param_group["scale_parameter"]: + param_scale = max(param_group["eps"][1], param_state["RMS"]) + return param_scale * rel_step_sz + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group["beta1"] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + @staticmethod + def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col): + # copy from fairseq's adafactor implementation: + # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505 + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def step(self, closure=None): + """ + Performs a single optimization step + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p.data + if p.data.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + state["RMS"] = self._rms(p_data_fp32) + lr = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad**2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((self._rms(update) / 
group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) + + p_data_fp32.add_(-update) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + return loss diff --git a/utils/optim/lr_scheduler.py b/utils/optim/lr_scheduler.py new file mode 100644 index 0000000..68c6fb9 --- /dev/null +++ b/utils/optim/lr_scheduler.py @@ -0,0 +1,58 @@ +import math +from torch.optim.lr_scheduler import _LRScheduler + +__all__ = ['AnnealingLR'] + +class AnnealingLR(_LRScheduler): + + def __init__(self, optimizer, base_lr, warmup_steps, total_steps, decay_mode='cosine', min_lr=0.0, last_step=-1): + assert decay_mode in ['linear', 'cosine', 'none'] + self.optimizer = optimizer + self.base_lr = base_lr + self.warmup_steps = warmup_steps + self.total_steps = total_steps + self.decay_mode = decay_mode + self.min_lr = min_lr + self.current_step = last_step + 1 + self.step(self.current_step) + + def get_lr(self): + if self.warmup_steps > 0 and self.current_step <= self.warmup_steps: + return self.base_lr * self.current_step / self.warmup_steps + else: + ratio = (self.current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps) + ratio = min(1.0, max(0.0, ratio)) + if self.decay_mode == 'linear': + return self.base_lr * (1 - ratio) + elif self.decay_mode == 'cosine': + return self.base_lr * (math.cos(math.pi * ratio) + 1.0) / 2.0 + else: + return self.base_lr + + def step(self, current_step=None): + if current_step is None: + current_step = self.current_step + 1 + self.current_step = current_step + new_lr = max(self.min_lr, self.get_lr()) + if isinstance(self.optimizer, list): + for o in self.optimizer: + for group in o.param_groups: + group['lr'] = new_lr + else: + for group in self.optimizer.param_groups: + group['lr'] = new_lr + + def state_dict(self): + return { + 'base_lr': self.base_lr, + 'warmup_steps': self.warmup_steps, + 'total_steps': self.total_steps, + 'decay_mode': self.decay_mode, + 'current_step': self.current_step} + + def load_state_dict(self, state_dict): + self.base_lr = state_dict['base_lr'] + self.warmup_steps = state_dict['warmup_steps'] + self.total_steps = state_dict['total_steps'] + self.decay_mode = state_dict['decay_mode'] + self.current_step = state_dict['current_step'] diff --git a/utils/registry.py b/utils/registry.py new file mode 100644 index 0000000..b654ca9 --- /dev/null +++ b/utils/registry.py @@ -0,0 +1,167 @@ +# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved. + +# Registry class & build_from_config function partially modified from +# https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py +# Copyright 2018-2020 Open-MMLab. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import inspect +import warnings + + +def build_from_config(cfg, registry, **kwargs): + """ Default builder function. + + Args: + cfg (dict): A dict which contains parameters passes to target class or function. + Must contains key 'type', indicates the target class or function name. + registry (Registry): An registry to search target class or function. + kwargs (dict, optional): Other params not in config dict. + + Returns: + Target class object or object returned by invoking function. + + Raises: + TypeError: + KeyError: + Exception: + """ + if not isinstance(cfg, dict): + raise TypeError(f"config must be type dict, got {type(cfg)}") + if "type" not in cfg: + raise KeyError(f"config must contain key type, got {cfg}") + if not isinstance(registry, Registry): + raise TypeError(f"registry must be type Registry, got {type(registry)}") + + cfg = copy.deepcopy(cfg) + + req_type = cfg.pop("type") + req_type_entry = req_type + if isinstance(req_type, str): + req_type_entry = registry.get(req_type) + if req_type_entry is None: + try: + print(f"For Windows users, we explicitly import registry function {req_type} !!!") + from tools.inferences.inference_unianimate_entrance import inference_unianimate_entrance + from tools.inferences.inference_unianimate_long_entrance import inference_unianimate_long_entrance + # from tools.modules.diffusions.diffusion_ddim import DiffusionDDIM + # from tools.modules.diffusions.diffusion_ddim import DiffusionDDIMLong + # from tools.modules.autoencoder import AutoencoderKL + # from tools.modules.clip_embedder import FrozenOpenCLIPTextVisualEmbedder + # from tools.modules.unet.unet_unianimate import UNetSD_UniAnimate + + req_type_entry = eval(req_type) + except: + raise KeyError(f"{req_type} not found in {registry.name} registry") + + if kwargs is not None: + cfg.update(kwargs) + + if inspect.isclass(req_type_entry): + try: + return req_type_entry(**cfg) + except Exception as e: + raise Exception(f"Failed to init class {req_type_entry}, with {e}") + elif inspect.isfunction(req_type_entry): + try: + return req_type_entry(**cfg) + except Exception as e: + raise Exception(f"Failed to invoke function {req_type_entry}, with {e}") + else: + raise TypeError(f"type must be str or class, got {type(req_type_entry)}") + + +class Registry(object): + """ A registry maps key to classes or functions. + + Example: + >>> MODELS = Registry('MODELS') + >>> @MODELS.register_class() + >>> class ResNet(object): + >>> pass + >>> resnet = MODELS.build(dict(type="ResNet")) + >>> + >>> import torchvision + >>> @MODELS.register_function("InceptionV3") + >>> def get_inception_v3(pretrained=False, progress=True): + >>> return torchvision.models.inception_v3(pretrained=pretrained, progress=progress) + >>> inception_v3 = MODELS.build(dict(type='InceptionV3', pretrained=True)) + + Args: + name (str): Registry name. + build_func (func, None): Instance construct function. Default is build_from_config. + allow_types (tuple): Indicates how to construct the instance, by constructing class or invoking function. 
+ """ + + def __init__(self, name, build_func=None, allow_types=("class", "function")): + self.name = name + self.allow_types = allow_types + self.class_map = {} + self.func_map = {} + self.build_func = build_func or build_from_config + + def get(self, req_type): + return self.class_map.get(req_type) or self.func_map.get(req_type) + + def build(self, *args, **kwargs): + return self.build_func(*args, **kwargs, registry=self) + + def register_class(self, name=None): + def _register(cls): + if not inspect.isclass(cls): + raise TypeError(f"Module must be type class, got {type(cls)}") + if "class" not in self.allow_types: + raise TypeError(f"Register {self.name} only allows type {self.allow_types}, got class") + module_name = name or cls.__name__ + if module_name in self.class_map: + warnings.warn(f"Class {module_name} already registered by {self.class_map[module_name]}, " + f"will be replaced by {cls}") + self.class_map[module_name] = cls + return cls + + return _register + + def register_function(self, name=None): + def _register(func): + if not inspect.isfunction(func): + raise TypeError(f"Registry must be type function, got {type(func)}") + if "function" not in self.allow_types: + raise TypeError(f"Registry {self.name} only allows type {self.allow_types}, got function") + func_name = name or func.__name__ + if func_name in self.class_map: + warnings.warn(f"Function {func_name} already registered by {self.func_map[func_name]}, " + f"will be replaced by {func}") + self.func_map[func_name] = func + return func + + return _register + + def _list(self): + keys = sorted(list(self.class_map.keys()) + list(self.func_map.keys())) + descriptions = [] + for key in keys: + if key in self.class_map: + descriptions.append(f"{key}: {self.class_map[key]}") + else: + descriptions.append( + f"{key}: ") + return "\n".join(descriptions) + + def __repr__(self): + description = self._list() + description = '\n'.join(['\t' + s for s in description.split('\n')]) + return f"{self.__class__.__name__} [{self.name}], \n" + description + + diff --git a/utils/registry_class.py b/utils/registry_class.py new file mode 100644 index 0000000..3a11ad6 --- /dev/null +++ b/utils/registry_class.py @@ -0,0 +1,19 @@ +from .registry import Registry, build_from_config + +def build_func(cfg, registry, **kwargs): + """ + Except for config, if passing a list of dataset config, then return the concat type of it + """ + return build_from_config(cfg, registry, **kwargs) + +AUTO_ENCODER = Registry("AUTO_ENCODER", build_func=build_func) +DATASETS = Registry("DATASETS", build_func=build_func) +DIFFUSION = Registry("DIFFUSION", build_func=build_func) +DISTRIBUTION = Registry("DISTRIBUTION", build_func=build_func) +EMBEDDER = Registry("EMBEDDER", build_func=build_func) +ENGINE = Registry("ENGINE", build_func=build_func) +INFER_ENGINE = Registry("INFER_ENGINE", build_func=build_func) +MODEL = Registry("MODEL", build_func=build_func) +PRETRAIN = Registry("PRETRAIN", build_func=build_func) +VISUAL = Registry("VISUAL", build_func=build_func) +EMBEDMANAGER = Registry("EMBEDMANAGER", build_func=build_func) diff --git a/utils/seed.py b/utils/seed.py new file mode 100644 index 0000000..4b45aec --- /dev/null +++ b/utils/seed.py @@ -0,0 +1,11 @@ +import torch +import random +import numpy as np + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True \ No newline at end of file diff --git a/utils/transforms.py b/utils/transforms.py new 
file mode 100644 index 0000000..bb07567 --- /dev/null +++ b/utils/transforms.py @@ -0,0 +1,353 @@ +import torch +import torchvision.transforms.functional as F +import random +import math +import numpy as np +from PIL import Image, ImageFilter + +__all__ = ['Compose', 'Resize', 'Rescale', 'CenterCrop', 'CenterCropV2', 'CenterCropWide', 'RandomCrop', 'RandomCropV2', 'RandomHFlip',\ + 'GaussianBlur', 'ColorJitter', 'RandomGray', 'ToTensor', 'Normalize', "ResizeRandomCrop", "ExtractResizeRandomCrop", "ExtractResizeAssignCrop"] + + +class Compose(object): + + def __init__(self, transforms): + self.transforms = transforms + + def __getitem__(self, index): + if isinstance(index, slice): + return Compose(self.transforms[index]) + else: + return self.transforms[index] + + def __len__(self): + return len(self.transforms) + + def __call__(self, rgb): + for t in self.transforms: + rgb = t(rgb) + return rgb + +class Resize(object): + + def __init__(self, size=256): + if isinstance(size, int): + size = (size, size) + self.size = size + + def __call__(self, rgb): + if isinstance(rgb, list): + rgb = [u.resize(self.size, Image.BILINEAR) for u in rgb] + else: + rgb = rgb.resize(self.size, Image.BILINEAR) + return rgb + +class Rescale(object): + + def __init__(self, size=256, interpolation=Image.BILINEAR): + self.size = size + self.interpolation = interpolation + + def __call__(self, rgb): + w, h = rgb[0].size + scale = self.size / min(w, h) + out_w, out_h = int(round(w * scale)), int(round(h * scale)) + rgb = [u.resize((out_w, out_h), self.interpolation) for u in rgb] + return rgb + +class CenterCrop(object): + + def __init__(self, size=224): + self.size = size + + def __call__(self, rgb): + w, h = rgb[0].size + assert min(w, h) >= self.size + x1 = (w - self.size) // 2 + y1 = (h - self.size) // 2 + rgb = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in rgb] + return rgb + +class ResizeRandomCrop(object): + + def __init__(self, size=256, size_short=292): + self.size = size + # self.min_area = min_area + self.size_short = size_short + + def __call__(self, rgb): + + # consistent crop between rgb and m + while min(rgb[0].size) >= 2 * self.size_short: + rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb] + scale = self.size_short / min(rgb[0].size) + rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb] + out_w = self.size + out_h = self.size + w, h = rgb[0].size # (518, 292) + x1 = random.randint(0, w - out_w) + y1 = random.randint(0, h - out_h) + + rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] + # rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb] + # # center crop + # x1 = (img[0].width - self.size) // 2 + # y1 = (img[0].height - self.size) // 2 + # img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img] + return rgb + + + +class ExtractResizeRandomCrop(object): + + def __init__(self, size=256, size_short=292): + self.size = size + # self.min_area = min_area + self.size_short = size_short + + def __call__(self, rgb): + + # consistent crop between rgb and m + while min(rgb[0].size) >= 2 * self.size_short: + rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb] + scale = self.size_short / min(rgb[0].size) + rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb] + out_w = self.size + out_h = self.size + w, h = rgb[0].size # (518, 292) + x1 = random.randint(0, w - out_w) + y1 = random.randint(0, h - 
out_h) + + rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] + wh = [x1, y1, x1 + out_w, y1 + out_h] + return rgb, wh + + +class ExtractResizeAssignCrop(object): + + def __init__(self, size=256, size_short=292): + self.size = size + # self.min_area = min_area + self.size_short = size_short + + def __call__(self, rgb, wh): + + # consistent crop between rgb and m + while min(rgb[0].size) >= 2 * self.size_short: + rgb = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in rgb] + scale = self.size_short / min(rgb[0].size) + rgb = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in rgb] + + rgb = [u.crop(wh) for u in rgb] + rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb] + return rgb + +class CenterCropV2(object): + def __init__(self, size): + self.size = size + + def __call__(self, img): + # fast resize + while min(img[0].size) >= 2 * self.size: + img = [u.resize((u.width // 2, u.height // 2), resample=Image.BOX) for u in img] + scale = self.size / min(img[0].size) + img = [u.resize((round(scale * u.width), round(scale * u.height)), resample=Image.BICUBIC) for u in img] + + # center crop + x1 = (img[0].width - self.size) // 2 + y1 = (img[0].height - self.size) // 2 + img = [u.crop((x1, y1, x1 + self.size, y1 + self.size)) for u in img] + return img + + +class CenterCropWide(object): + def __init__(self, size, interpolation=Image.BOX): + self.size = size + self.interpolation = interpolation + + def __call__(self, img): + if isinstance(img, list): + scale = min(img[0].size[0]/self.size[0], img[0].size[1]/self.size[1]) + img = [u.resize((round(u.width // scale), round(u.height // scale)), resample=self.interpolation) for u in img] + + # center crop + x1 = (img[0].width - self.size[0]) // 2 + y1 = (img[0].height - self.size[1]) // 2 + img = [u.crop((x1, y1, x1 + self.size[0], y1 + self.size[1])) for u in img] + return img + else: + scale = min(img.size[0]/self.size[0], img.size[1]/self.size[1]) + img = img.resize((round(img.width // scale), round(img.height // scale)), resample=self.interpolation) + x1 = (img.width - self.size[0]) // 2 + y1 = (img.height - self.size[1]) // 2 + img = img.crop((x1, y1, x1 + self.size[0], y1 + self.size[1])) + return img + + + +class RandomCrop(object): + + def __init__(self, size=224, min_area=0.4): + self.size = size + self.min_area = min_area + + def __call__(self, rgb): + + # consistent crop between rgb and m + w, h = rgb[0].size + area = w * h + out_w, out_h = float('inf'), float('inf') + while out_w > w or out_h > h: + target_area = random.uniform(self.min_area, 1.0) * area + aspect_ratio = random.uniform(3. / 4., 4. / 3.) + out_w = int(round(math.sqrt(target_area * aspect_ratio))) + out_h = int(round(math.sqrt(target_area / aspect_ratio))) + x1 = random.randint(0, w - out_w) + y1 = random.randint(0, h - out_h) + + rgb = [u.crop((x1, y1, x1 + out_w, y1 + out_h)) for u in rgb] + rgb = [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb] + + return rgb + +class RandomCropV2(object): + + def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. 
/ 3.)): + if isinstance(size, (tuple, list)): + self.size = size + else: + self.size = (size, size) + self.min_area = min_area + self.ratio = ratio + + def _get_params(self, img): + width, height = img.size + area = height * width + + for _ in range(10): + target_area = random.uniform(self.min_area, 1.0) * area + log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = float(width) / float(height) + if (in_ratio < min(self.ratio)): + w = width + h = int(round(w / min(self.ratio))) + elif (in_ratio > max(self.ratio)): + h = height + w = int(round(h * max(self.ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + def __call__(self, rgb): + i, j, h, w = self._get_params(rgb[0]) + rgb = [F.resized_crop(u, i, j, h, w, self.size) for u in rgb] + return rgb + +class RandomHFlip(object): + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, rgb): + if random.random() < self.p: + rgb = [u.transpose(Image.FLIP_LEFT_RIGHT) for u in rgb] + return rgb + +class GaussianBlur(object): + + def __init__(self, sigmas=[0.1, 2.0], p=0.5): + self.sigmas = sigmas + self.p = p + + def __call__(self, rgb): + if random.random() < self.p: + sigma = random.uniform(*self.sigmas) + rgb = [u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb] + return rgb + +class ColorJitter(object): + + def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.5): + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + self.p = p + + def __call__(self, rgb): + if random.random() < self.p: + brightness, contrast, saturation, hue = self._random_params() + transforms = [ + lambda f: F.adjust_brightness(f, brightness), + lambda f: F.adjust_contrast(f, contrast), + lambda f: F.adjust_saturation(f, saturation), + lambda f: F.adjust_hue(f, hue)] + random.shuffle(transforms) + for t in transforms: + rgb = [t(u) for u in rgb] + + return rgb + + def _random_params(self): + brightness = random.uniform( + max(0, 1 - self.brightness), 1 + self.brightness) + contrast = random.uniform( + max(0, 1 - self.contrast), 1 + self.contrast) + saturation = random.uniform( + max(0, 1 - self.saturation), 1 + self.saturation) + hue = random.uniform(-self.hue, self.hue) + return brightness, contrast, saturation, hue + +class RandomGray(object): + + def __init__(self, p=0.2): + self.p = p + + def __call__(self, rgb): + if random.random() < self.p: + rgb = [u.convert('L').convert('RGB') for u in rgb] + return rgb + +class ToTensor(object): + + def __call__(self, rgb): + if isinstance(rgb, list): + rgb = torch.stack([F.to_tensor(u) for u in rgb], dim=0) + else: + rgb = F.to_tensor(rgb) + + return rgb + +class Normalize(object): + + def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): + self.mean = mean + self.std = std + + def __call__(self, rgb): + rgb = rgb.clone() + rgb.clamp_(0, 1) + if not isinstance(self.mean, torch.Tensor): + self.mean = rgb.new_tensor(self.mean).view(-1) + if not isinstance(self.std, torch.Tensor): + self.std = rgb.new_tensor(self.std).view(-1) + if rgb.dim() == 4: + rgb.sub_(self.mean.view(1, -1, 1, 
1)).div_(self.std.view(1, -1, 1, 1)) + elif rgb.dim() == 3: + rgb.sub_(self.mean.view(-1, 1, 1)).div_(self.std.view(-1, 1, 1)) + return rgb + diff --git a/utils/util.py b/utils/util.py new file mode 100644 index 0000000..b5c0a61 --- /dev/null +++ b/utils/util.py @@ -0,0 +1,16 @@ +import torch + +def to_device(batch, device, non_blocking=False): + if isinstance(batch, (list, tuple)): + return type(batch)([ + to_device(u, device, non_blocking) + for u in batch]) + elif isinstance(batch, dict): + return type(batch)([ + (k, to_device(v, device, non_blocking)) + for k, v in batch.items()]) + elif isinstance(batch, torch.Tensor) and batch.device != device: + batch = batch.to(device, non_blocking=non_blocking) + else: + return batch + return batch diff --git a/utils/video_op.py b/utils/video_op.py new file mode 100644 index 0000000..399e9f4 --- /dev/null +++ b/utils/video_op.py @@ -0,0 +1,359 @@ +import os +import os.path as osp +import sys +import cv2 +import glob +import math +import torch +import gzip +import copy +import time +import json +import pickle +import base64 +import imageio +import hashlib +import requests +import binascii +import zipfile +# import skvideo.io +import numpy as np +from io import BytesIO +import urllib.request +import torch.nn.functional as F +import torchvision.utils as tvutils +from multiprocessing.pool import ThreadPool as Pool +from einops import rearrange +from PIL import Image, ImageDraw, ImageFont + + +def gen_text_image(captions, text_size): + num_char = int(38 * (text_size / text_size)) + font_size = int(text_size / 20) + font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=font_size) + text_image_list = [] + for text in captions: + txt_img = Image.new("RGB", (text_size, text_size), color="white") + draw = ImageDraw.Draw(txt_img) + lines = "\n".join(text[start:start + num_char] for start in range(0, len(text), num_char)) + draw.text((0, 0), lines, fill="black", font=font) + txt_img = np.array(txt_img) + text_image_list.append(txt_img) + text_images = np.stack(text_image_list, axis=0) + text_images = torch.from_numpy(text_images) + return text_images + +@torch.no_grad() +def save_video_refimg_and_text( + local_path, + ref_frame, + gen_video, + captions, + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + text_size=256, + nrow=4, + save_fps=8, + retry=5): + ''' + gen_video: BxCxFxHxW + ''' + nrow = max(int(gen_video.size(0) / 2), 1) + vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw + vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw + + text_images = gen_text_image(captions, text_size) # Tensor 8x256x256x3 + text_images = text_images.unsqueeze(1) # Tensor 8x1x256x256x3 + text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1) # 8x16x256x256x3 + + ref_frame = ref_frame.unsqueeze(2) + ref_frame = ref_frame.mul_(vid_std).add_(vid_mean) + ref_frame = ref_frame.repeat_interleave(repeats=gen_video.size(2), dim=2) # 8x16x256x256x3 + ref_frame.clamp_(0, 1) + ref_frame = ref_frame * 255.0 + ref_frame = rearrange(ref_frame, 'b c f h w -> b f h w c') + + gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384 + gen_video.clamp_(0, 1) + gen_video = gen_video * 255.0 + + images = rearrange(gen_video, 'b c f h w -> b f h w c') + images = torch.cat([ref_frame, images, text_images], dim=3) + + images = rearrange(images, '(r j) f h w c -> f (r h) (j w) c', r=nrow) + images = [(img.numpy()).astype('uint8') for img in images] + + for _ in [None] * retry: + try: + if len(images) 
== 1: + local_path = local_path + '.png' + cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + else: + local_path = local_path + '.mp4' + frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path))) + os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True) + for fid, frame in enumerate(images): + tpth = os.path.join(frame_dir, '%04d.png' % (fid+1)) + cv2.imwrite(tpth, frame[:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}' + os.system(cmd); os.system(f'rm -rf {frame_dir}') + # os.system(f'rm -rf {local_path}') + exception = None + break + except Exception as e: + exception = e + continue + + +@torch.no_grad() +def save_i2vgen_video( + local_path, + image_id, + gen_video, + captions, + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + text_size=256, + retry=5, + save_fps = 8 +): + ''' + Save both the generated video and the input conditions. + ''' + vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw + vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw + + text_images = gen_text_image(captions, text_size) # Tensor 1x256x256x3 + text_images = text_images.unsqueeze(1) # Tensor 1x1x256x256x3 + text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1) # 1x16x256x256x3 + + image_id = image_id.unsqueeze(2) # B, C, F, H, W + image_id = image_id.repeat_interleave(repeats=gen_video.size(2), dim=2) # 1x3x32x256x448 + image_id = image_id.mul_(vid_std).add_(vid_mean) # 32x3x256x448 + image_id.clamp_(0, 1) + image_id = image_id * 255.0 + image_id = rearrange(image_id, 'b c f h w -> b f h w c') + + gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384 + gen_video.clamp_(0, 1) + gen_video = gen_video * 255.0 + + images = rearrange(gen_video, 'b c f h w -> b f h w c') + images = torch.cat([image_id, images, text_images], dim=3) + images = images[0] + images = [(img.numpy()).astype('uint8') for img in images] + + exception = None + for _ in [None] * retry: + try: + frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path))) + os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True) + for fid, frame in enumerate(images): + tpth = os.path.join(frame_dir, '%04d.png' % (fid+1)) + cv2.imwrite(tpth, frame[:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}' + os.system(cmd); os.system(f'rm -rf {frame_dir}') + break + except Exception as e: + exception = e + continue + + if exception is not None: + raise exception + + +@torch.no_grad() +def save_i2vgen_video_safe( + local_path, + gen_video, + captions, + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + text_size=256, + retry=5, + save_fps = 8 +): + ''' + Save only the generated video, do not save the related reference conditions, and at the same time perform anomaly detection on the last frame. 
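+    (The "anomaly detection" is a heuristic in the loop below: the last frame
+    is dropped when more than ~40% of its pixels fall in the narrow 117-137
+    gray band, a known failure mode of the generator.)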
+ ''' + vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw + vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw + + gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384 + gen_video.clamp_(0, 1) + gen_video = gen_video * 255.0 + + images = rearrange(gen_video, 'b c f h w -> b f h w c') + images = images[0] + images = [(img.numpy()).astype('uint8') for img in images] + num_image = len(images) + exception = None + for _ in [None] * retry: + try: + if num_image == 1: + local_path = local_path + '.png' + cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + else: + writer = imageio.get_writer(local_path, fps=save_fps, codec='libx264', quality=8) + for fid, frame in enumerate(images): + if fid == num_image-1: # Fix known bugs. + ratio = (np.sum((frame >= 117) & (frame <= 137)))/(frame.size) + if ratio > 0.4: continue + writer.append_data(frame) + writer.close() + break + except Exception as e: + exception = e + continue + + if exception is not None: + raise exception + + +@torch.no_grad() +def save_t2vhigen_video_safe( + local_path, + gen_video, + captions, + mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5], + text_size=256, + retry=5, + save_fps = 8 +): + ''' + Save only the generated video, do not save the related reference conditions, and at the same time perform anomaly detection on the last frame. + ''' + vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw + vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1) #ncfhw + + gen_video = gen_video.mul_(vid_std).add_(vid_mean) # 8x3x16x256x384 + gen_video.clamp_(0, 1) + gen_video = gen_video * 255.0 + + images = rearrange(gen_video, 'b c f h w -> b f h w c') + images = images[0] + images = [(img.numpy()).astype('uint8') for img in images] + num_image = len(images) + exception = None + for _ in [None] * retry: + try: + if num_image == 1: + local_path = local_path + '.png' + cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + else: + frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path))) + os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True) + for fid, frame in enumerate(images): + if fid == num_image-1: # Fix known bugs. 
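+                        # Skip the final frame if more than 40% of its pixels
+                        # lie in the 117-137 gray band (the "known bug" above).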
+ ratio = (np.sum((frame >= 117) & (frame <= 137)))/(frame.size) + if ratio > 0.4: continue + tpth = os.path.join(frame_dir, '%04d.png' % (fid+1)) + cv2.imwrite(tpth, frame[:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}' + os.system(cmd) + os.system(f'rm -rf {frame_dir}') + break + except Exception as e: + exception = e + continue + + if exception is not None: + raise exception + + + + +@torch.no_grad() +def save_video_multiple_conditions_not_gif_horizontal_3col(local_path, video_tensor, model_kwargs, source_imgs, + mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], nrow=8, retry=5, save_fps=8): + mean=torch.tensor(mean,device=video_tensor.device).view(1,-1,1,1,1)#ncfhw + std=torch.tensor(std,device=video_tensor.device).view(1,-1,1,1,1)#ncfhw + video_tensor = video_tensor.mul_(std).add_(mean) #### unnormalize back to [0,1] + video_tensor.clamp_(0, 1) + + b, c, n, h, w = video_tensor.shape + source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w)) + source_imgs = source_imgs.cpu() + + model_kwargs_channel3 = {} + for key, conditions in model_kwargs[0].items(): + + + if conditions.size(1) == 1: + conditions = torch.cat([conditions, conditions, conditions], dim=1) + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + if conditions.size(1) == 2: + conditions = torch.cat([conditions, conditions[:,:1,]], dim=1) + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + elif conditions.size(1) == 3: + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + elif conditions.size(1) == 4: # means it is a mask. + color = ((conditions[:, 0:3] + 1.)/2.) # .astype(np.float32) + alpha = conditions[:, 3:4] # .astype(np.float32) + conditions = color * alpha + 1.0 * (1.0 - alpha) + conditions = F.adaptive_avg_pool3d(conditions, (n, h, w)) + model_kwargs_channel3[key] = conditions.cpu() if conditions.is_cuda else conditions + + # filename = rand_name(suffix='.gif') + for _ in [None] * retry: + try: + vid_gif = rearrange(video_tensor, '(i j) c f h w -> c f (i h) (j w)', i = nrow) + + # cons_list = [rearrange(con, '(i j) c f h w -> c f (i h) (j w)', i = nrow) for _, con in model_kwargs_channel3.items()] + # vid_gif = torch.cat(cons_list + [vid_gif,], dim=3) #Uncomment this and previous line to compare output video with input pose frames + + vid_gif = vid_gif.permute(1,2,3,0) + + images = vid_gif * 255.0 + images = [(img.numpy()).astype('uint8') for img in images] + if len(images) == 1: + + local_path = local_path.replace('.mp4', '.png') + cv2.imwrite(local_path, images[0][:,:,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100]) + # bucket.put_object_from_file(oss_key, local_path) + else: + + outputs = [] + for image_name in images: + x = Image.fromarray(image_name) + outputs.append(x) + from pathlib import Path + save_fmt = Path(local_path).suffix + + if save_fmt == ".mp4": + with imageio.get_writer(local_path, fps=save_fps) as writer: + for img in outputs: + img_array = np.array(img) # Convert PIL Image to numpy array + writer.append_data(img_array) + + elif save_fmt == ".gif": + outputs[0].save( + fp=local_path, + format="GIF", + append_images=outputs[1:], + save_all=True, + duration=(1 / save_fps * 1000), + loop=0, + ) + else: + raise ValueError("Unsupported file type. 
Use .mp4 or .gif.") + + # fourcc = cv2.VideoWriter_fourcc(*'mp4v') + # fps = save_fps + # image = images[0] + # media_writer = cv2.VideoWriter(local_path, fourcc, fps, (image.shape[1],image.shape[0])) + # for image_name in images: + # im = image_name[:,:,::-1] + # media_writer.write(im) + # media_writer.release() + + + exception = None + break + except Exception as e: + exception = e + continue + if exception is not None: + print('save video to {} failed, error: {}'.format(local_path, exception), flush=True) + From 00a5b59d93eea836808d5e69695b1f1c1329e36f Mon Sep 17 00:00:00 2001 From: Isi <86603298+Isi-dev@users.noreply.github.com> Date: Thu, 1 Aug 2024 21:39:54 +0100 Subject: [PATCH 4/5] Add files via upload --- tools/__init__.py | 3 + tools/__pycache__/__init__.cpython-310.pyc | Bin 0 -> 256 bytes tools/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 200 bytes tools/datasets/__init__.py | 2 + .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 254 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 198 bytes .../__pycache__/image_dataset.cpython-310.pyc | Bin 0 -> 2897 bytes .../__pycache__/image_dataset.cpython-39.pyc | Bin 0 -> 2846 bytes .../__pycache__/video_dataset.cpython-310.pyc | Bin 0 -> 3464 bytes .../__pycache__/video_dataset.cpython-39.pyc | Bin 0 -> 3402 bytes tools/datasets/image_dataset.py | 86 + tools/datasets/video_dataset.py | 118 ++ tools/inferences/__init__.py | 2 + .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 293 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 237 bytes ...erence_unianimate_entrance.cpython-310.pyc | Bin 0 -> 12891 bytes ...ference_unianimate_entrance.cpython-39.pyc | Bin 0 -> 12529 bytes ...e_unianimate_long_entrance.cpython-310.pyc | Bin 0 -> 13012 bytes ...ce_unianimate_long_entrance.cpython-39.pyc | Bin 0 -> 13136 bytes .../inference_unianimate_entrance.py | 546 ++++++ .../inference_unianimate_long_entrance.py | 508 +++++ tools/modules/__init__.py | 7 + .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 438 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 382 bytes .../__pycache__/autoencoder.cpython-310.pyc | Bin 0 -> 16687 bytes .../__pycache__/autoencoder.cpython-39.pyc | Bin 0 -> 17098 bytes .../__pycache__/clip_embedder.cpython-310.pyc | Bin 0 -> 6728 bytes .../__pycache__/clip_embedder.cpython-39.pyc | Bin 0 -> 6645 bytes .../__pycache__/config.cpython-310.pyc | Bin 0 -> 3845 bytes .../modules/__pycache__/config.cpython-39.pyc | Bin 0 -> 3696 bytes .../embedding_manager.cpython-310.pyc | Bin 0 -> 5244 bytes .../embedding_manager.cpython-39.pyc | Bin 0 -> 5161 bytes tools/modules/autoencoder.py | 698 +++++++ tools/modules/clip_embedder.py | 241 +++ tools/modules/config.py | 206 ++ tools/modules/diffusions/__init__.py | 1 + .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 240 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 184 bytes .../diffusion_ddim.cpython-310.pyc | Bin 0 -> 24748 bytes .../__pycache__/diffusion_ddim.cpython-39.pyc | Bin 0 -> 28560 bytes .../__pycache__/losses.cpython-310.pyc | Bin 0 -> 1337 bytes .../__pycache__/losses.cpython-39.pyc | Bin 0 -> 1279 bytes .../__pycache__/schedules.cpython-310.pyc | Bin 0 -> 4726 bytes .../__pycache__/schedules.cpython-39.pyc | Bin 0 -> 4653 bytes tools/modules/diffusions/diffusion_ddim.py | 1121 +++++++++++ tools/modules/diffusions/diffusion_gauss.py | 498 +++++ tools/modules/diffusions/losses.py | 28 + tools/modules/diffusions/schedules.py | 166 ++ tools/modules/embedding_manager.py | 179 ++ tools/modules/unet/__init__.py | 2 + 
.../unet/__pycache__/__init__.cpython-310.pyc | Bin 0 -> 235 bytes .../unet/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 179 bytes .../unet_unianimate.cpython-310.pyc | Bin 0 -> 15790 bytes .../unet_unianimate.cpython-39.pyc | Bin 0 -> 15915 bytes .../unet/__pycache__/util.cpython-310.pyc | Bin 0 -> 42320 bytes .../unet/__pycache__/util.cpython-39.pyc | Bin 0 -> 44250 bytes tools/modules/unet/mha_flash.py | 103 + tools/modules/unet/unet_unianimate.py | 659 +++++++ tools/modules/unet/util.py | 1741 +++++++++++++++++ 59 files changed, 6915 insertions(+) create mode 100644 tools/__init__.py create mode 100644 tools/__pycache__/__init__.cpython-310.pyc create mode 100644 tools/__pycache__/__init__.cpython-39.pyc create mode 100644 tools/datasets/__init__.py create mode 100644 tools/datasets/__pycache__/__init__.cpython-310.pyc create mode 100644 tools/datasets/__pycache__/__init__.cpython-39.pyc create mode 100644 tools/datasets/__pycache__/image_dataset.cpython-310.pyc create mode 100644 tools/datasets/__pycache__/image_dataset.cpython-39.pyc create mode 100644 tools/datasets/__pycache__/video_dataset.cpython-310.pyc create mode 100644 tools/datasets/__pycache__/video_dataset.cpython-39.pyc create mode 100644 tools/datasets/image_dataset.py create mode 100644 tools/datasets/video_dataset.py create mode 100644 tools/inferences/__init__.py create mode 100644 tools/inferences/__pycache__/__init__.cpython-310.pyc create mode 100644 tools/inferences/__pycache__/__init__.cpython-39.pyc create mode 100644 tools/inferences/__pycache__/inference_unianimate_entrance.cpython-310.pyc create mode 100644 tools/inferences/__pycache__/inference_unianimate_entrance.cpython-39.pyc create mode 100644 tools/inferences/__pycache__/inference_unianimate_long_entrance.cpython-310.pyc create mode 100644 tools/inferences/__pycache__/inference_unianimate_long_entrance.cpython-39.pyc create mode 100644 tools/inferences/inference_unianimate_entrance.py create mode 100644 tools/inferences/inference_unianimate_long_entrance.py create mode 100644 tools/modules/__init__.py create mode 100644 tools/modules/__pycache__/__init__.cpython-310.pyc create mode 100644 tools/modules/__pycache__/__init__.cpython-39.pyc create mode 100644 tools/modules/__pycache__/autoencoder.cpython-310.pyc create mode 100644 tools/modules/__pycache__/autoencoder.cpython-39.pyc create mode 100644 tools/modules/__pycache__/clip_embedder.cpython-310.pyc create mode 100644 tools/modules/__pycache__/clip_embedder.cpython-39.pyc create mode 100644 tools/modules/__pycache__/config.cpython-310.pyc create mode 100644 tools/modules/__pycache__/config.cpython-39.pyc create mode 100644 tools/modules/__pycache__/embedding_manager.cpython-310.pyc create mode 100644 tools/modules/__pycache__/embedding_manager.cpython-39.pyc create mode 100644 tools/modules/autoencoder.py create mode 100644 tools/modules/clip_embedder.py create mode 100644 tools/modules/config.py create mode 100644 tools/modules/diffusions/__init__.py create mode 100644 tools/modules/diffusions/__pycache__/__init__.cpython-310.pyc create mode 100644 tools/modules/diffusions/__pycache__/__init__.cpython-39.pyc create mode 100644 tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-310.pyc create mode 100644 tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-39.pyc create mode 100644 tools/modules/diffusions/__pycache__/losses.cpython-310.pyc create mode 100644 tools/modules/diffusions/__pycache__/losses.cpython-39.pyc create mode 100644 
tools/modules/diffusions/__pycache__/schedules.cpython-310.pyc create mode 100644 tools/modules/diffusions/__pycache__/schedules.cpython-39.pyc create mode 100644 tools/modules/diffusions/diffusion_ddim.py create mode 100644 tools/modules/diffusions/diffusion_gauss.py create mode 100644 tools/modules/diffusions/losses.py create mode 100644 tools/modules/diffusions/schedules.py create mode 100644 tools/modules/embedding_manager.py create mode 100644 tools/modules/unet/__init__.py create mode 100644 tools/modules/unet/__pycache__/__init__.cpython-310.pyc create mode 100644 tools/modules/unet/__pycache__/__init__.cpython-39.pyc create mode 100644 tools/modules/unet/__pycache__/unet_unianimate.cpython-310.pyc create mode 100644 tools/modules/unet/__pycache__/unet_unianimate.cpython-39.pyc create mode 100644 tools/modules/unet/__pycache__/util.cpython-310.pyc create mode 100644 tools/modules/unet/__pycache__/util.cpython-39.pyc create mode 100644 tools/modules/unet/mha_flash.py create mode 100644 tools/modules/unet/unet_unianimate.py create mode 100644 tools/modules/unet/util.py diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000..33ef13c --- /dev/null +++ b/tools/__init__.py @@ -0,0 +1,3 @@ +from .datasets import * +from .modules import * +from .inferences import * diff --git a/tools/__pycache__/__init__.cpython-310.pyc b/tools/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74bb8278b4d5ad8bfbcfbf761c96156e10650dc3 GIT binary patch literal 256 zcmd1j<>g`kf|lOtX+A*uF^GcyfROD#&xOHM6b$xy@!R1GG6B|2Nh zgche36~|;2XJ+NcIOpf4Rfc-TmuKds!Q&Njz zg7KkwnT~mxxrrsIF(vu=ImI#Y@tJva4_;P0sy>* BM;8D9 literal 0 HcmV?d00001 diff --git a/tools/__pycache__/__init__.cpython-39.pyc b/tools/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f55644d5d92c88f526b3325d5918d3feaaf5e5e GIT binary patch literal 200 zcmYe~<>g`kf|lOtX+A*uF^GcyfROD#&xOHM6b$xy@!R1GG6X**lR zgche36~|;2XJ+Ncgyv;B=4Iw4mZZj%Q(0su{mF-ZUb literal 0 HcmV?d00001 diff --git a/tools/datasets/__init__.py b/tools/datasets/__init__.py new file mode 100644 index 0000000..f1b217f --- /dev/null +++ b/tools/datasets/__init__.py @@ -0,0 +1,2 @@ +from .image_dataset import * +from .video_dataset import * diff --git a/tools/datasets/__pycache__/__init__.cpython-310.pyc b/tools/datasets/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be3d985797983558a345dab5b5d3699f7a547167 GIT binary patch literal 254 zcmYjLy9&ZU5WGt)L`>&T#KI2{5yi^JM$^c_61|I@B)1`ZD)zS4euKYcYvnIkxjq68 z>A}i^W$ZYcLma_>dgM@+P~pI9NuuNH+#_ z>Qgd)8SSk$Slg?pLO%Y`b35s)giJ-slHH=3V38e@DIBcRcpxaT^64r!b)eQe@>qaF sZKz&C;_TRQr#H$Y=A?ejqyVh59-vnqZ_Po0dY7Bw`&Vi(HAP##0GRAY`2YX_ literal 0 HcmV?d00001 diff --git a/tools/datasets/__pycache__/__init__.cpython-39.pyc b/tools/datasets/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6726b131c60fac9c8a6d9607115635f43748edb GIT binary patch literal 198 zcmYe~<>g`kf|lOtX*NLmF^Gc@X literal 0 HcmV?d00001 diff --git a/tools/datasets/__pycache__/image_dataset.cpython-310.pyc b/tools/datasets/__pycache__/image_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77f0b69b97e1f5c03a02280cadce19a15302555f GIT binary patch literal 2897 zcmZuzTW{RP6`mOmughIavMo!AkqSs+bkla*=F%pqV8pS5#6X|~aSS611q8Fh-B9Lb zJ)|sYGYkyKMccRbAM8FA{u%uXMW5!WKwsj25cnbJIm2B?HcVp9TzTfqneTjuWw#p; 
[GIT binary patch data omitted: compiled tools/datasets/__pycache__/*.pyc caches]
zJ^vs|czTfe!&KzK<2XER*zWFSc{=cul!w`ArT;LAR+DHDU|PO~{=7)vhad5Ipq{ zQo%FjVYVlNK?r#fW@)^aM`@xO`|$s{vX=#eVI2CsA@s{dkYv4746?G{xpK1~=6){{ z86b|Tj8QaHW;zTL<%%%iag=~5JA9gl30EF0x*z&Dxpq^TSs3?}t9KZyRQ+_bbq_HQ zYZZ$K7ZDZ^E+Je-cn9GMfU=U|1xbi|26qpcvt|24D7907-U$Sc?Q`%nCyqW6J3bF*;ZGxN&W z;owd?F1HuamWv}F8-?|%#hbKvlNK9Wpl9_h|CvSQ5&0aXt#H3~6Ocbl59k2_m$kuR z9clI~aI4DtOG8R2(6L7Ee^o5O(Jz%$TIvN+9P;Jm<)wot|8gl5A{955zA6+EsKqDv zPrFtgkd+PNC|6A%sdL76<=%eU4TqYIl{oG9qofaBnDkO*Blm(_S$QhDaIB+Hq}dz9 zXnW$@U_@Eqh~ZO#POjV@%&MeI850RF=30a}`l(*iwnN+$#IS(3^zD$C3^;PS!^Q^aBao(GxyjiP1)cmuVU6~43rGS4DQszvE4j7fK)y1`yO4Ks4|R8^ zdI?q2Mcv)0&5*)_qXmM@lU*V{=0;JKPEku5d41gA2DU9g=E%ed0*CL6~^vA*??S)TMJ~@5=n9q>lcW0 z=OCKyGVu+ppe0){^C_^^%BRN{WDDlql8x~U=0M+n@d$!ZatT~y_VAD3B2(ZZ&3tY= z&+YL7L>u@M`%Q>E_4U+~X&^GBNaTERd=b~oaThLeOFPXaZh+I6#Wd7i7Jr8IUr3z% zo&1Vy98vD*y}B^EHYppsrub+6u58E)i@1w3;KWOE24*=6xo=?O+ zu@#ck(Ym&Lbz=Fdwrmu$;PkK#eE)x+&k0zan^>Kf^Kw>BgF~9&klwoZ4p8OF} zSlDU#d71}tsULo_H&Aw#3kcesW;fjr#kq!RmIvZgUR;>(M06@c*w;K1%G2VhL!B}N zY^J3v#K&m=CjgxaN>#+4LIjT{zl$3jU~K807!Ex;M=vt-nN2U_h-`HE zb$NST(_g!UVGO|i*Aj=}GRN3OkMLSd*V{~NL7ff}ZCW15>kCsA--k2w{S5l)dk{qD znb*yi9omk-OR8`YP*p`8p#iYIey3f*D?kF572j8$?+;SG2VufjHQ#@-7sO?Yz}uF< zebgH&e9WCq=xnf49VmWW-*8bXRr5f=Ze#;eJHSBds?hQB5ngHb0@7u^3?GSsY9#>TzRN>6C%Ft`jCw>VqI8 s#zB?|8%J%Fl|aw366f-^#VBW`+>M)fSUR>j!~u$&y+|W`ZDWr94}e``0ssI2 literal 0 HcmV?d00001 diff --git a/tools/datasets/__pycache__/video_dataset.cpython-39.pyc b/tools/datasets/__pycache__/video_dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ceb6e95f5e8a0b7d2189a9ea436324b30bf34686 GIT binary patch literal 3402 zcmZ`*OK%*<5$^7J&pt^}d`Pmw*o1*qkd30)NvtT0B~lU$C={WG1uv)!#@n;&A!lcn zJwu6H_ZR`lAPJBFiZ8h?b<^Jnf}8``r#>0T#XKj-!KYNsEG1GwS8@$ zcBVMJPg?d7tca|&sM5R1a1V0hqk-QKpkS@~nV$w36z$dP53b+8`QUyRd!PJlsC)=1 zZ36@ZcLO_Va+^B`6!zls$^n7BRCx_@gV%Wja+8~&drDP54tbDF^c;Pw5%PkCzomp6 z5r7x$=o^?-Q~-{e2&irCRu=`wUjqs{B$TBRKx&eBM4Et9>VQldfSR-bb!h_{(g8H3 z3uyI8Ud?PtCEI7KABGDp@p-DhL>kmU8>lOE-aYX`QhuabS5*JX8Fk@Z7L z$@a3W!aP+@U`zU%qM{3HfB7*e;rh-%N@oE2M3TqVoc@7G+Gl@Z+xmU-f(YD4iz(K= ztn@Go1~1^$Ib49pLs2w-_rc25jfZI<(v2_;H+b(6tx`4Z^)HzYD@ojm-u=Y?^Yd0zWklyvWD^LoK!A=c*Q7S4oE`!R7f7y; z<|P?3$)tA3HuT)c&D_fEMF4YSV{10fP|b@B6Dkt6U61=v=!_jEqfEPg7lcT?unIHUB6rh4NCH$eV62^c)09> zV&0EHnj$F)aHUQeoi8|}Q=OiJ|2;aIe~mjYhj#H)jz63NJzznfRUnpRV+b>c6zHSD zb<<5Y*1*RBVRDw~obA(9@{B%S&W+K3G6N_A^vF$Vj;+ktw4rQG$_|w6%ti0HSvils z&q43GS)C!d3l|IYm?K+6NUr4-Y3J4WLRK5sxrS}?5US=0=&ZD|bi;A{1tfmAlc-(-`M9U}hmjBd}!nykRw|J;5Kr7G9w$(A9$Ic}LJ(wT(- zxW&YC*h5n`VeM0(t(i@a&&VdMyeaGB87zUHe*Or83XJVNBHK$~C}*G1y-&eZrodD# zWwYZsZj9$4dcX|We?sV~FQ&GaPL_e(AVYF*FJ%kkb9fdUqRr-ct{0|so@-!AdOi)Z z3(u(d5>Dn!+|1s}F3S2oe`kLy}r%9C+qUe0-o>;nD?Tbf%Tt-()TgW zCPIFR@`5ohSk6k12Ib1=Z{q2eNw+roBeiq!pyA*az$^Y;335nTd0-9$hn&;%4 zJS(TcTp=p8oHg+SAk`5*UkaNQ80AMTFkJM@iv{0H`<YGUzD+2;UAAElI$(6}-UZuBt27*S+<7ch|jJ>-W{ntJ=NmzrN|M z+B9Fy^%8 z3$IeR6Y+Cw{{Z2a0E&95a1^lJj*fw{)AX_=^P|$ac=66anQ0~Z z;wpA6T-QtgB`$*ur$|2rDEt;f&*eg-tAh8!T&E_Rp*C|_mD;okvL^gxTNO&vbctPr z7Lzq$WF1Dh@XvzGq>Hr5Mi)+dz0+>~)cXp<0-k!|`7nZJ7|!TEK5pq+i-`vyE22h= zmKSk)JIeArxH->DpdCK*? zLBe++a(JrhdD}aFRJMpqIOJCdc#PsV2(K3O>X1!nz*-R;Y@@K_(wY(f>t3`Q_WJ^R zwWh@loaGjP(l^s2E;)J!g-Bjy!)s=UaZDLND0mLCcvvu1Y3m=|DcRpF;$j 0: + mid_frame = frame_list[0] + vit_frame = self.vit_transforms(mid_frame) + frame_tensor = self.transforms(frame_list) + video_data[:len(frame_list), ...] 
= frame_tensor + else: + vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) + except: + vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) + ref_frame = copy(video_data[0]) + + return ref_frame, vit_frame, video_data, caption + diff --git a/tools/datasets/video_dataset.py b/tools/datasets/video_dataset.py new file mode 100644 index 0000000..cdc45de --- /dev/null +++ b/tools/datasets/video_dataset.py @@ -0,0 +1,118 @@ +import os +import cv2 +import json +import torch +import random +import logging +import tempfile +import numpy as np +from copy import copy +from PIL import Image +from torch.utils.data import Dataset +from ...utils.registry_class import DATASETS + + +@DATASETS.register_class() +class VideoDataset(Dataset): + def __init__(self, + data_list, + data_dir_list, + max_words=1000, + resolution=(384, 256), + vit_resolution=(224, 224), + max_frames=16, + sample_fps=8, + transforms=None, + vit_transforms=None, + get_first_frame=False, + **kwargs): + + self.max_words = max_words + self.max_frames = max_frames + self.resolution = resolution + self.vit_resolution = vit_resolution + self.sample_fps = sample_fps + self.transforms = transforms + self.vit_transforms = vit_transforms + self.get_first_frame = get_first_frame + + image_list = [] + for item_path, data_dir in zip(data_list, data_dir_list): + lines = open(item_path, 'r').readlines() + lines = [[data_dir, item] for item in lines] + image_list.extend(lines) + self.image_list = image_list + + + def __getitem__(self, index): + data_dir, file_path = self.image_list[index] + video_key = file_path.split('|||')[0] + try: + ref_frame, vit_frame, video_data, caption = self._get_video_data(data_dir, file_path) + except Exception as e: + logging.info('{} get frames failed... 
with error: {}'.format(video_key, e)) + caption = '' + video_key = '' + ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0]) + vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) + video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0]) + return ref_frame, vit_frame, video_data, caption, video_key + + + def _get_video_data(self, data_dir, file_path): + video_key, caption = file_path.split('|||') + file_path = os.path.join(data_dir, video_key) + + for _ in range(5): + try: + capture = cv2.VideoCapture(file_path) + _fps = capture.get(cv2.CAP_PROP_FPS) + _total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT) + stride = round(_fps / self.sample_fps) + cover_frame_num = (stride * self.max_frames) + if _total_frame_num < cover_frame_num + 5: + start_frame = 0 + end_frame = _total_frame_num + else: + start_frame = random.randint(0, _total_frame_num-cover_frame_num-5) + end_frame = start_frame + cover_frame_num + + pointer, frame_list = 0, [] + while(True): + ret, frame = capture.read() + pointer +=1 + if (not ret) or (frame is None): break + if pointer < start_frame: continue + if pointer >= end_frame - 1: break + if (pointer - start_frame) % stride == 0: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = Image.fromarray(frame) + frame_list.append(frame) + break + except Exception as e: + logging.info('{} read video frame failed with error: {}'.format(video_key, e)) + continue + + video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0]) + if self.get_first_frame: + ref_idx = 0 + else: + ref_idx = int(len(frame_list)/2) + try: + if len(frame_list)>0: + mid_frame = copy(frame_list[ref_idx]) + vit_frame = self.vit_transforms(mid_frame) + frames = self.transforms(frame_list) + video_data[:len(frame_list), ...] 
= frames + else: + vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) + except: + vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) + ref_frame = copy(frames[ref_idx]) + + return ref_frame, vit_frame, video_data, caption + + def __len__(self): + return len(self.image_list) + + diff --git a/tools/inferences/__init__.py b/tools/inferences/__init__.py new file mode 100644 index 0000000..db0383b --- /dev/null +++ b/tools/inferences/__init__.py @@ -0,0 +1,2 @@ +from .inference_unianimate_entrance import * +from .inference_unianimate_long_entrance import * diff --git a/tools/inferences/__pycache__/__init__.cpython-310.pyc b/tools/inferences/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cac82efe421b0620827d06c2f12dc67c1a58cf6c GIT binary patch literal 293 zcmZusu}T9$5Zyf$f+U@-jon$~1A+u4m5q&Fw=k^PIa%482|Ig)^tRT1gTHjGmA_!+ z#BgHagL#jcH#}x|Haj(nm)Ez}tA6a8zo@ji){--&V1iBN;h)6?(HyU8b-_(NZ5wQ!AK78fNxvs-c3?wVWJ zMhd(Y*l;fXG2v*qbuWcGAPy+w5E_a%hf7>xP8v2IDGne~ECBP(ZkSwk9QuUI{Y!GJ IG%-{A4NU1-O#lD@ literal 0 HcmV?d00001 diff --git a/tools/inferences/__pycache__/__init__.cpython-39.pyc b/tools/inferences/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d0533cc269d526cf21afc6be4a7d3ed5b1b17fa GIT binary patch literal 237 zcmYe~<>g`kf|lOtX^BAkF^GcpPE-vln4~NrG!;FCqFM8u4W}e5i8I{ zF!9UG*(xTqIJKxaCbKv*D?cVQFVhiXW=u(behyF=W?gYie0*kJW=VX!UP0w84x8Nk Rl+v73JCNgwLB8N%1ON=)Lw^7O literal 0 HcmV?d00001 diff --git a/tools/inferences/__pycache__/inference_unianimate_entrance.cpython-310.pyc b/tools/inferences/__pycache__/inference_unianimate_entrance.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fc212a6bcade8839065fac9c0d1b842226d9c7b GIT binary patch literal 12891 zcmb7qYiwLecHVvUlg;L%kryOZzKy(F8`%#X0ax>e^?ovJ!@>eQ)os@~n5RPcB4KmJbPLQYZsD-}Bb;wap} zzx!WQMPZ7iFqLT))l%hGvo!hDtA-U(16Vx zwGm}k-POYSDJvyubX(o>o3_&O+hg^}Z?DydZ=}*+9k2$%Hks<6rl?AF$QqK6Xyr(C z*cuM8Bh^vsXqX?X9<#>7{Bi3zD8(vf^@MdIgiTZ@ttpw0SEj2o)(r9qmaLovd=~jG zma1f{bJm=McUMkTPg|#DK3zFeoww#?zNd1wdd@m0^Szbx)dg#zdcnFN%YBtMsu!(` z)l1f;>SgP)g!Nb6tiENvCG!K7E7iBHx2spJtJQa`cdFN{Yt`%4b>LK%sobdEux`k_ zR#~(b*&rMGOtEgVBWxJoTWo}lvZHLQsIg;goQ)T?LWCV>=4Yz4#7?jY)LC}Y?c2Lb zQDKw6r3Pt1V^eJUGj&I`?yxL7$!1aiU8X%yvU9l`nxZ7n&Y8)%TlK~kFBR84b2gtf z7tSr5KW#2nN^7~boOy@Wn~k@Vb9YO5r{+4$Y}S~=O>f;X7aO@eJ|X6``GLdTQoUv_ z%%3x7slr5vn#h8{R=sIfb6aMu?wL*30fmxl7D^SzbUw~I4bLpq%zVAtsFZTGykl;b zymi@2NMk-Z_hCq(zUJjnIfv?vEwf(etY+rCWOB|l&2`UfTseDob8~Y(C&xKo=f$&? 
[GIT binary patch data omitted: compiled tools/inferences/__pycache__/*.pyc caches]
diff --git a/tools/inferences/inference_unianimate_entrance.py
b/tools/inferences/inference_unianimate_entrance.py new file mode 100644 index 0000000..14ceb8f --- /dev/null +++ b/tools/inferences/inference_unianimate_entrance.py @@ -0,0 +1,546 @@ +''' +/* +*Copyright (c) 2021, Alibaba Group; +*Licensed under the Apache License, Version 2.0 (the "License"); +*you may not use this file except in compliance with the License. +*You may obtain a copy of the License at + +* http://www.apache.org/licenses/LICENSE-2.0 + +*Unless required by applicable law or agreed to in writing, software +*distributed under the License is distributed on an "AS IS" BASIS, +*WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +*See the License for the specific language governing permissions and +*limitations under the License. +*/ +''' + +import os +import re +import os.path as osp +import sys +sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-4])) +import json +import math +import torch +# import pynvml +import logging +import numpy as np +from PIL import Image +import torch.cuda.amp as amp +from importlib import reload +import torch.distributed as dist +import torch.multiprocessing as mp +import random +from einops import rearrange +import torchvision.transforms as T +from torch.nn.parallel import DistributedDataParallel + +from ...utils import transforms as data +from ..modules.config import cfg +from ...utils.seed import setup_seed +from ...utils.multi_port import find_free_port +from ...utils.assign_cfg import assign_signle_cfg +from ...utils.distributed import generalized_all_gather, all_reduce +from ...utils.video_op import save_i2vgen_video, save_t2vhigen_video_safe, save_video_multiple_conditions_not_gif_horizontal_3col +from ...tools.modules.autoencoder import get_first_stage_encoding +from ...utils.registry_class import INFER_ENGINE, MODEL, EMBEDDER, AUTO_ENCODER, DIFFUSION +from copy import copy +import cv2 + + +# @INFER_ENGINE.register_function() +def inference_unianimate_entrance(steps, useFirstFrame, reference_image, ref_pose, pose_sequence, frame_interval, max_frames, resolution, cfg_update, **kwargs): + for k, v in cfg_update.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + if not 'MASTER_ADDR' in os.environ: + os.environ['MASTER_ADDR']='localhost' + os.environ['MASTER_PORT']= find_free_port() + cfg.pmi_rank = int(os.getenv('RANK', 0)) + cfg.pmi_world_size = int(os.getenv('WORLD_SIZE', 1)) + + if cfg.debug: + cfg.gpus_per_machine = 1 + cfg.world_size = 1 + else: + cfg.gpus_per_machine = torch.cuda.device_count() + cfg.world_size = cfg.pmi_world_size * cfg.gpus_per_machine + + if cfg.world_size == 1: + return worker(0, steps, useFirstFrame, reference_image, ref_pose, pose_sequence, frame_interval, max_frames, resolution, cfg, cfg_update) + else: + return mp.spawn(worker, nprocs=cfg.gpus_per_machine, args=(cfg, cfg_update)) + return cfg + + +def make_masked_images(imgs, masks): + masked_imgs = [] + for i, mask in enumerate(masks): + # concatenation + masked_imgs.append(torch.cat([imgs[i] * (1 - mask), (1 - mask)], dim=1)) + return torch.stack(masked_imgs, dim=0) + +def load_video_frames(ref_image_tensor, ref_pose_tensor, pose_tensors, train_trans, vit_transforms, train_trans_pose, max_frames=32, frame_interval=1, resolution=[512, 768], get_first_frame=True, vit_resolution=[224, 224]): + for _ in range(5): + try: + num_poses = len(pose_tensors) + numpyFrames = [] + numpyPoses = [] + + # Convert tensors to numpy arrays and prepare lists + for i in range(num_poses): + frame = 
ref_image_tensor.squeeze(0).cpu().numpy() # Convert to numpy array + # if i == 0: + # print(f'ref image is ({frame})') + numpyFrames.append(frame) + + pose = pose_tensors[i].squeeze(0).cpu().numpy() # Convert to numpy array + numpyPoses.append(pose) + + # Convert reference pose tensor to numpy array + pose_ref = ref_pose_tensor.squeeze(0).cpu().numpy() # Convert to numpy array + + # Sample max_frames poses for video generation + stride = frame_interval + total_frame_num = len(numpyFrames) + cover_frame_num = (stride * (max_frames - 1) + 1) + + if total_frame_num < cover_frame_num: + print(f'_total_frame_num ({total_frame_num}) is smaller than cover_frame_num ({cover_frame_num}), the sampled frame interval is changed') + start_frame = 0 + end_frame = total_frame_num + stride = max((total_frame_num - 1) // (max_frames - 1), 1) + end_frame = stride * max_frames + else: + start_frame = 0 + end_frame = start_frame + cover_frame_num + + frame_list = [] + dwpose_list = [] + + print(f'end_frame is ({end_frame})') + + for i_index in range(start_frame, end_frame, stride): + if i_index < len(numpyFrames): # Check index within bounds + i_frame = numpyFrames[i_index] + i_dwpose = numpyPoses[i_index] + + # Convert numpy arrays to PIL images + # i_frame = np.clip(i_frame, 0, 1) + i_frame = (i_frame - i_frame.min()) / (i_frame.max() - i_frame.min()) #Trying this in place of clip + i_frame = Image.fromarray((i_frame * 255).astype(np.uint8)) + i_frame = i_frame.convert('RGB') + # i_dwpose = np.clip(i_dwpose, 0, 1) + i_dwpose = (i_dwpose - i_dwpose.min()) / (i_dwpose.max() - i_dwpose.min()) #Trying this in place of clip + i_dwpose = Image.fromarray((i_dwpose * 255).astype(np.uint8)) + i_dwpose = i_dwpose.convert('RGB') + + # if i_index == 0: + # print(f'i_frame is ({np.array(i_frame)})') + + frame_list.append(i_frame) + dwpose_list.append(i_dwpose) + + if frame_list: + # random_ref_frame = np.clip(numpyFrames[0], 0, 1) + random_ref_frame = (numpyFrames[0] - numpyFrames[0].min()) / (numpyFrames[0].max() - numpyFrames[0].min()) #Trying this in place of clip + random_ref_frame = Image.fromarray((random_ref_frame * 255).astype(np.uint8)) + if random_ref_frame.mode != 'RGB': + random_ref_frame = random_ref_frame.convert('RGB') + # random_ref_dwpose = np.clip(pose_ref, 0, 1) + random_ref_dwpose = (pose_ref - pose_ref.min()) / (pose_ref.max() - pose_ref.min()) #Trying this in place of clip + random_ref_dwpose = Image.fromarray((random_ref_dwpose * 255).astype(np.uint8)) + if random_ref_dwpose.mode != 'RGB': + random_ref_dwpose = random_ref_dwpose.convert('RGB') + + # Apply transforms + ref_frame = frame_list[0] + vit_frame = vit_transforms(ref_frame) + random_ref_frame_tmp = train_trans_pose(random_ref_frame) + random_ref_dwpose_tmp = train_trans_pose(random_ref_dwpose) + misc_data_tmp = torch.stack([train_trans_pose(ss) for ss in frame_list], dim=0) + video_data_tmp = torch.stack([train_trans(ss) for ss in frame_list], dim=0) + dwpose_data_tmp = torch.stack([train_trans_pose(ss) for ss in dwpose_list], dim=0) + + # Initialize tensors + video_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + misc_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + random_ref_frame_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + random_ref_dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + + # Copy data to tensors + video_data[:len(frame_list), ...] 
= video_data_tmp + misc_data[:len(frame_list), ...] = misc_data_tmp + dwpose_data[:len(frame_list), ...] = dwpose_data_tmp + random_ref_frame_data[:, ...] = random_ref_frame_tmp + random_ref_dwpose_data[:, ...] = random_ref_dwpose_tmp + + return vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data + + except Exception as e: + logging.info(f'Error reading video frame: {e}') + continue + + return None, None, None, None, None, None # Return default values if all attempts fail + + +def worker(gpu, steps, useFirstFrame, reference_image, ref_pose, pose_sequence, frame_interval, max_frames, resolution, cfg, cfg_update): + ''' + Inference worker for each gpu + ''' + for k, v in cfg_update.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + cfg.gpu = gpu + cfg.seed = int(cfg.seed) + cfg.rank = cfg.pmi_rank * cfg.gpus_per_machine + gpu + setup_seed(cfg.seed + cfg.rank) + + if not cfg.debug: + torch.cuda.set_device(gpu) + torch.backends.cudnn.benchmark = True + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + torch.backends.cudnn.benchmark = False + if not dist.is_initialized(): + dist.init_process_group(backend='gloo', world_size=cfg.world_size, rank=cfg.rank) + + # [Log] Save logging and make log dir + # log_dir = generalized_all_gather(cfg.log_dir)[0] + inf_name = osp.basename(cfg.cfg_file).split('.')[0] + # test_model = osp.basename(cfg.test_model).split('.')[0].split('_')[-1] + + cfg.log_dir = osp.join(cfg.log_dir, '%s' % (inf_name)) + os.makedirs(cfg.log_dir, exist_ok=True) + log_file = osp.join(cfg.log_dir, 'log_%02d.txt' % (cfg.rank)) + cfg.log_file = log_file + reload(logging) + logging.basicConfig( + level=logging.INFO, + format='[%(asctime)s] %(levelname)s: %(message)s', + handlers=[ + logging.FileHandler(filename=log_file), + logging.StreamHandler(stream=sys.stdout)]) + # logging.info(cfg) + logging.info(f"Running UniAnimate inference on gpu {gpu}") + + # [Diffusion] + diffusion = DIFFUSION.build(cfg.Diffusion) + + # [Data] Data Transform + train_trans = data.Compose([ + data.Resize(cfg.resolution), + data.ToTensor(), + data.Normalize(mean=cfg.mean, std=cfg.std) + ]) + + train_trans_pose = data.Compose([ + data.Resize(cfg.resolution), + data.ToTensor(), + ] + ) + + # Defines transformations for data to be fed into a Vision Transformer (ViT) model. 
+ vit_transforms = T.Compose([ + data.Resize(cfg.vit_resolution), + T.ToTensor(), + T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)]) + + # [Model] embedder + clip_encoder = EMBEDDER.build(cfg.embedder) + clip_encoder.model.to(gpu) + with torch.no_grad(): + _, _, zero_y = clip_encoder(text="") + + + # [Model] auotoencoder + autoencoder = AUTO_ENCODER.build(cfg.auto_encoder) + autoencoder.eval() # freeze + for param in autoencoder.parameters(): + param.requires_grad = False + autoencoder.cuda() + + # [Model] UNet + if "config" in cfg.UNet: + cfg.UNet["config"] = cfg + cfg.UNet["zero_y"] = zero_y + model = MODEL.build(cfg.UNet) + # Here comes the UniAnimate model + # inferences folder + current_directory = os.path.dirname(os.path.abspath(__file__)) + # tools folder + parent_directory = os.path.dirname(current_directory) + # uniAnimate folder + root_directory = os.path.dirname(parent_directory) + unifiedModel = os.path.join(root_directory, 'checkpoints/unianimate_16f_32f_non_ema_223000.pth ') + state_dict = torch.load(unifiedModel, map_location='cpu') + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + if 'step' in state_dict: + resume_step = state_dict['step'] + else: + resume_step = 0 + status = model.load_state_dict(state_dict, strict=True) + logging.info('Load model from {} with status {}'.format(unifiedModel, status)) + model = model.to(gpu) + model.eval() + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + print("Avoiding DistributedDataParallel to reduce memory usage") + model.to(torch.float16) + else: + model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model + torch.cuda.empty_cache() + + + # Where the input image and pose images come in + test_list = cfg.test_list_path + num_videos = len(test_list) + logging.info(f'There are {num_videos} videos. 
with {cfg.round} times') + # test_list = [item for item in test_list for _ in range(cfg.round)] + test_list = [item for _ in range(cfg.round) for item in test_list] + + # for idx, file_path in enumerate(test_list): + + # You can start inputs here for any user interface + # Inputs will be ref_image_key, pose_seq_key, frame_interval, max_frames, resolution + # cfg.frame_interval, ref_image_key, pose_seq_key = file_path[0], file_path[1], file_path[2] + + manual_seed = int(cfg.seed + cfg.rank) + setup_seed(manual_seed) + # logging.info(f"[{idx}]/[{len(test_list)}] Begin to sample {ref_image_key}, pose sequence from {pose_seq_key} init seed {manual_seed} ...") + + # initialize reference_image, pose_sequence, frame_interval, max_frames, resolution_x, + vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data = load_video_frames(reference_image, ref_pose, pose_sequence, train_trans, vit_transforms, train_trans_pose, max_frames, frame_interval, resolution) + misc_data = misc_data.unsqueeze(0).to(gpu) + vit_frame = vit_frame.unsqueeze(0).to(gpu) + dwpose_data = dwpose_data.unsqueeze(0).to(gpu) + random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu) + random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu) + + + + ### save for visualization + misc_backups = copy(misc_data) + frames_num = misc_data.shape[1] + misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w') + mv_data_video = [] + + + ### local image (first frame) + image_local = [] + if 'local_image' in cfg.video_compositions: + frames_num = misc_data.shape[1] + bs_vd_local = misc_data.shape[0] + # create a repeated version of the first frame across all frames and assign to image_local + image_local = misc_data[:,:1].clone().repeat(1,frames_num,1,1,1) + image_local_clone = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local) + image_local = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local) + if hasattr(cfg, "latent_local_image") and cfg.latent_local_image: + with torch.no_grad(): # Disable gradient calculation + temporal_length = frames_num + # The encoder compresses the input data into a lower-dimensional latent representation, often called a "latent vector" or "encoding." 
+ encoder_posterior = autoencoder.encode(video_data[:,0]) + local_image_data = get_first_stage_encoding(encoder_posterior).detach() #use without affecting the gradients of the original model + image_local = local_image_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40] + + + + ### encode the video_data + # bs_vd = misc_data.shape[0] + misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w') + # misc_data_list = torch.chunk(misc_data, misc_data.shape[0]//cfg.chunk_size,dim=0) + + + with torch.no_grad(): + + random_ref_frame = [] + if 'randomref' in cfg.video_compositions: + random_ref_frame_clone = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w') + if hasattr(cfg, "latent_random_ref") and cfg.latent_random_ref: + + temporal_length = random_ref_frame_data.shape[1] + encoder_posterior = autoencoder.encode(random_ref_frame_data[:,0].sub(0.5).div_(0.5)) + random_ref_frame_data = get_first_stage_encoding(encoder_posterior).detach() + random_ref_frame_data = random_ref_frame_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40] + + random_ref_frame = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w') + + + if 'dwpose' in cfg.video_compositions: + bs_vd_local = dwpose_data.shape[0] + dwpose_data_clone = rearrange(dwpose_data.clone(), 'b f c h w -> b c f h w', b = bs_vd_local) + if 'randomref_pose' in cfg.video_compositions: + dwpose_data = torch.cat([random_ref_dwpose_data[:,:1], dwpose_data], dim=1) + dwpose_data = rearrange(dwpose_data, 'b f c h w -> b c f h w', b = bs_vd_local) + + + y_visual = [] + if 'image' in cfg.video_compositions: + with torch.no_grad(): + vit_frame = vit_frame.squeeze(1) + y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1) # [60, 1024] + y_visual0 = y_visual.clone() + + # print(torch.get_default_dtype()) + + with amp.autocast(enabled=True): + # pynvml.nvmlInit() + # handle=pynvml.nvmlDeviceGetHandleByIndex(0) + # meminfo=pynvml.nvmlDeviceGetMemoryInfo(handle) + cur_seed = torch.initial_seed() + # logging.info(f"Current seed {cur_seed} ...") + + print(f"Number of frames to denoise: {frames_num}") + noise = torch.randn([1, 4, frames_num, int(cfg.resolution[1]/cfg.scale), int(cfg.resolution[0]/cfg.scale)]) + noise = noise.to(gpu) + # print(f"noise: {noise.shape}") + + + if hasattr(cfg.Diffusion, "noise_strength"): + b, c, f, _, _= noise.shape + offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device) + noise = noise + cfg.Diffusion.noise_strength * offset_noise + # print(f"offset_noise dtype: {offset_noise.dtype}") + # print(f' offset_noise is ({offset_noise})') + + + + # add a noise prior + noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 949), noise=noise) + + # construct model inputs (CFG) + full_model_kwargs=[{ + 'y': None, + "local_image": None if len(image_local) == 0 else image_local[:], + 'image': None if len(y_visual) == 0 else y_visual0[:], + 'dwpose': None if len(dwpose_data) == 0 else dwpose_data[:], + 'randomref': None if len(random_ref_frame) == 0 else random_ref_frame[:], + }, + { + 'y': None, + "local_image": None, + 'image': None, + 'randomref': None, + 'dwpose': None, + }] + + # for visualization + full_model_kwargs_vis =[{ + 'y': None, + "local_image": None if len(image_local) == 0 else image_local_clone[:], + 'image': None, + 'dwpose': None if len(dwpose_data_clone) == 0 else dwpose_data_clone[:], + 'randomref': None if len(random_ref_frame) == 0 else random_ref_frame_clone[:, :3], + }, + { + 'y': None, + "local_image": None, + 'image': 
None, + 'randomref': None, + 'dwpose': None, + }] + + + partial_keys = [ + ['image', 'randomref', "dwpose"], + ] + + if useFirstFrame: + partial_keys = [ + ['image', 'local_image', "dwpose"], + ] + print('Using First Frame Conditioning!') + + + for partial_keys_one in partial_keys: + model_kwargs_one = prepare_model_kwargs(partial_keys = partial_keys_one, + full_model_kwargs = full_model_kwargs, + use_fps_condition = cfg.use_fps_condition) + model_kwargs_one_vis = prepare_model_kwargs(partial_keys = partial_keys_one, + full_model_kwargs = full_model_kwargs_vis, + use_fps_condition = cfg.use_fps_condition) + noise_one = noise + + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + clip_encoder.cpu() # add this line + autoencoder.cpu() # add this line + torch.cuda.empty_cache() # add this line + + # print(f' noise_one is ({noise_one})') + + + video_data = diffusion.ddim_sample_loop( + noise=noise_one, + model=model.eval(), + model_kwargs=model_kwargs_one, + guide_scale=cfg.guide_scale, + ddim_timesteps=steps, + eta=0.0) + + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + # if run forward of autoencoder or clip_encoder second times, load them again + clip_encoder.cuda() + autoencoder.cuda() + video_data = 1. / cfg.scale_factor * video_data + video_data = rearrange(video_data, 'b c f h w -> (b f) c h w') + chunk_size = min(cfg.decoder_bs, video_data.shape[0]) + video_data_list = torch.chunk(video_data, video_data.shape[0]//chunk_size, dim=0) + decode_data = [] + for vd_data in video_data_list: + gen_frames = autoencoder.decode(vd_data) + decode_data.append(gen_frames) + video_data = torch.cat(decode_data, dim=0) + video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = cfg.batch_size).float() + + # Check sth + + # print(f' video_data is of shape ({video_data.shape})') + # print(f' video_data is ({video_data})') + + del model_kwargs_one_vis[0][list(model_kwargs_one_vis[0].keys())[0]] + del model_kwargs_one_vis[1][list(model_kwargs_one_vis[1].keys())[0]] + + video_data = extract_image_tensors(video_data.cpu(), cfg.mean, cfg.std) + + # synchronize to finish some processes + if not cfg.debug: + torch.cuda.synchronize() + dist.barrier() + + return video_data + +@torch.no_grad() +def extract_image_tensors(video_tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): + # Unnormalize the video tensor + mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1) # ncfhw + std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1) # ncfhw + video_tensor = video_tensor.mul_(std).add_(mean) # unnormalize back to [0,1] + video_tensor.clamp_(0, 1) + + images = rearrange(video_tensor, 'b c f h w -> b f h w c') + images = images.squeeze(0) + images_t = [] + for img in images: + img_array = np.array(img) # Convert PIL Image to numpy array + img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float() # Convert to tensor and CHW format + img_tensor = img_tensor.permute(0, 2, 3, 1) + images_t.append(img_tensor) + + logging.info('Images data extracted!') + images_t = torch.cat(images_t, dim=0) + return images_t + +def prepare_model_kwargs(partial_keys, full_model_kwargs, use_fps_condition=False): + if use_fps_condition is True: + partial_keys.append('fps') + partial_model_kwargs = [{}, {}] + for partial_key in partial_keys: + partial_model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key] + partial_model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key] + return partial_model_kwargs \ No newline at end of file diff --git 
a/tools/inferences/inference_unianimate_long_entrance.py b/tools/inferences/inference_unianimate_long_entrance.py new file mode 100644 index 0000000..5e841f5 --- /dev/null +++ b/tools/inferences/inference_unianimate_long_entrance.py @@ -0,0 +1,508 @@ +''' +/* +*Copyright (c) 2021, Alibaba Group; +*Licensed under the Apache License, Version 2.0 (the "License"); +*you may not use this file except in compliance with the License. +*You may obtain a copy of the License at + +* http://www.apache.org/licenses/LICENSE-2.0 + +*Unless required by applicable law or agreed to in writing, software +*distributed under the License is distributed on an "AS IS" BASIS, +*WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +*See the License for the specific language governing permissions and +*limitations under the License. +*/ +''' + +import os +import re +import os.path as osp +import sys +sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-4])) +import json +import math +import torch +# import pynvml +import logging +import cv2 +import numpy as np +from PIL import Image +from tqdm import tqdm +import torch.cuda.amp as amp +from importlib import reload +import torch.distributed as dist +import torch.multiprocessing as mp +import random +from einops import rearrange +import torchvision.transforms as T +import torchvision.transforms.functional as TF +from torch.nn.parallel import DistributedDataParallel + +from ...utils import transforms as data +from ..modules.config import cfg +from ...utils.seed import setup_seed +from ...utils.multi_port import find_free_port +from ...utils.assign_cfg import assign_signle_cfg +from ...utils.distributed import generalized_all_gather, all_reduce +from ...utils.video_op import save_i2vgen_video, save_t2vhigen_video_safe, save_video_multiple_conditions_not_gif_horizontal_3col +from ...tools.modules.autoencoder import get_first_stage_encoding +from ...utils.registry_class import INFER_ENGINE, MODEL, EMBEDDER, AUTO_ENCODER, DIFFUSION +from copy import copy +import cv2 + + +@INFER_ENGINE.register_function() +def inference_unianimate_long_entrance(cfg_update, **kwargs): + for k, v in cfg_update.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + if not 'MASTER_ADDR' in os.environ: + os.environ['MASTER_ADDR']='localhost' + os.environ['MASTER_PORT']= find_free_port() + cfg.pmi_rank = int(os.getenv('RANK', 0)) + cfg.pmi_world_size = int(os.getenv('WORLD_SIZE', 1)) + + if cfg.debug: + cfg.gpus_per_machine = 1 + cfg.world_size = 1 + else: + cfg.gpus_per_machine = torch.cuda.device_count() + cfg.world_size = cfg.pmi_world_size * cfg.gpus_per_machine + + if cfg.world_size == 1: + worker(0, cfg, cfg_update) + else: + mp.spawn(worker, nprocs=cfg.gpus_per_machine, args=(cfg, cfg_update)) + return cfg + + +def make_masked_images(imgs, masks): + masked_imgs = [] + for i, mask in enumerate(masks): + # concatenation + masked_imgs.append(torch.cat([imgs[i] * (1 - mask), (1 - mask)], dim=1)) + return torch.stack(masked_imgs, dim=0) + +def load_video_frames(ref_image_path, pose_file_path, train_trans, vit_transforms, train_trans_pose, max_frames=32, frame_interval = 1, resolution=[512, 768], get_first_frame=True, vit_resolution=[224, 224]): + + for _ in range(5): + try: + dwpose_all = {} + frames_all = {} + for ii_index in sorted(os.listdir(pose_file_path)): + if ii_index != "ref_pose.jpg": + dwpose_all[ii_index] = Image.open(pose_file_path+"/"+ii_index) + frames_all[ii_index] = 
Image.fromarray(cv2.cvtColor(cv2.imread(ref_image_path),cv2.COLOR_BGR2RGB)) + # frames_all[ii_index] = Image.open(ref_image_path) + + pose_ref = Image.open(os.path.join(pose_file_path, "ref_pose.jpg")) + first_eq_ref = False + + # sample max_frames poses for video generation + stride = frame_interval + _total_frame_num = len(frames_all) + if max_frames == "None": + max_frames = (_total_frame_num-1)//frame_interval + 1 + cover_frame_num = (stride * (max_frames-1)+1) + if _total_frame_num < cover_frame_num: + print('_total_frame_num is smaller than cover_frame_num, the sampled frame interval is changed') + start_frame = 0 # we set start_frame = 0 because the pose alignment is performed on the first frame + end_frame = _total_frame_num + stride = max((_total_frame_num-1//(max_frames-1)),1) + end_frame = stride*max_frames + else: + start_frame = 0 # we set start_frame = 0 because the pose alignment is performed on the first frame + end_frame = start_frame + cover_frame_num + + frame_list = [] + dwpose_list = [] + random_ref_frame = frames_all[list(frames_all.keys())[0]] + if random_ref_frame.mode != 'RGB': + random_ref_frame = random_ref_frame.convert('RGB') + random_ref_dwpose = pose_ref + if random_ref_dwpose.mode != 'RGB': + random_ref_dwpose = random_ref_dwpose.convert('RGB') + for i_index in range(start_frame, end_frame, stride): + if i_index == start_frame and first_eq_ref: + i_key = list(frames_all.keys())[i_index] + i_frame = frames_all[i_key] + + if i_frame.mode != 'RGB': + i_frame = i_frame.convert('RGB') + i_dwpose = frames_pose_ref + if i_dwpose.mode != 'RGB': + i_dwpose = i_dwpose.convert('RGB') + frame_list.append(i_frame) + dwpose_list.append(i_dwpose) + else: + # added + if first_eq_ref: + i_index = i_index - stride + + i_key = list(frames_all.keys())[i_index] + i_frame = frames_all[i_key] + if i_frame.mode != 'RGB': + i_frame = i_frame.convert('RGB') + i_dwpose = dwpose_all[i_key] + if i_dwpose.mode != 'RGB': + i_dwpose = i_dwpose.convert('RGB') + frame_list.append(i_frame) + dwpose_list.append(i_dwpose) + have_frames = len(frame_list)>0 + middle_indix = 0 + if have_frames: + ref_frame = frame_list[middle_indix] + vit_frame = vit_transforms(ref_frame) + random_ref_frame_tmp = train_trans_pose(random_ref_frame) + random_ref_dwpose_tmp = train_trans_pose(random_ref_dwpose) + misc_data_tmp = torch.stack([train_trans_pose(ss) for ss in frame_list], dim=0) + video_data_tmp = torch.stack([train_trans(ss) for ss in frame_list], dim=0) + dwpose_data_tmp = torch.stack([train_trans_pose(ss) for ss in dwpose_list], dim=0) + + video_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + misc_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + random_ref_frame_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) # [32, 3, 512, 768] + random_ref_dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + if have_frames: + video_data[:len(frame_list), ...] = video_data_tmp + misc_data[:len(frame_list), ...] = misc_data_tmp + dwpose_data[:len(frame_list), ...] = dwpose_data_tmp + random_ref_frame_data[:,...] = random_ref_frame_tmp + random_ref_dwpose_data[:,...] 
= random_ref_dwpose_tmp + + break + + except Exception as e: + logging.info('{} read video frame failed with error: {}'.format(pose_file_path, e)) + continue + + return vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data, max_frames + + + +def worker(gpu, cfg, cfg_update): + ''' + Inference worker for each gpu + ''' + for k, v in cfg_update.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + cfg.gpu = gpu + cfg.seed = int(cfg.seed) + cfg.rank = cfg.pmi_rank * cfg.gpus_per_machine + gpu + setup_seed(cfg.seed + cfg.rank) + + if not cfg.debug: + torch.cuda.set_device(gpu) + torch.backends.cudnn.benchmark = True + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + torch.backends.cudnn.benchmark = False + dist.init_process_group(backend='nccl', world_size=cfg.world_size, rank=cfg.rank) + + # [Log] Save logging and make log dir + log_dir = generalized_all_gather(cfg.log_dir)[0] + inf_name = osp.basename(cfg.cfg_file).split('.')[0] + test_model = osp.basename(cfg.test_model).split('.')[0].split('_')[-1] + + cfg.log_dir = osp.join(cfg.log_dir, '%s' % (inf_name)) + os.makedirs(cfg.log_dir, exist_ok=True) + log_file = osp.join(cfg.log_dir, 'log_%02d.txt' % (cfg.rank)) + cfg.log_file = log_file + reload(logging) + logging.basicConfig( + level=logging.INFO, + format='[%(asctime)s] %(levelname)s: %(message)s', + handlers=[ + logging.FileHandler(filename=log_file), + logging.StreamHandler(stream=sys.stdout)]) + logging.info(cfg) + logging.info(f"Running UniAnimate inference on gpu {gpu}") + + # [Diffusion] + diffusion = DIFFUSION.build(cfg.Diffusion) + + # [Data] Data Transform + train_trans = data.Compose([ + data.Resize(cfg.resolution), + data.ToTensor(), + data.Normalize(mean=cfg.mean, std=cfg.std) + ]) + + train_trans_pose = data.Compose([ + data.Resize(cfg.resolution), + data.ToTensor(), + ] + ) + + vit_transforms = T.Compose([ + data.Resize(cfg.vit_resolution), + T.ToTensor(), + T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)]) + + # [Model] embedder + clip_encoder = EMBEDDER.build(cfg.embedder) + clip_encoder.model.to(gpu) + with torch.no_grad(): + _, _, zero_y = clip_encoder(text="") + + + # [Model] auotoencoder + autoencoder = AUTO_ENCODER.build(cfg.auto_encoder) + autoencoder.eval() # freeze + for param in autoencoder.parameters(): + param.requires_grad = False + autoencoder.cuda() + + # [Model] UNet + if "config" in cfg.UNet: + cfg.UNet["config"] = cfg + cfg.UNet["zero_y"] = zero_y + model = MODEL.build(cfg.UNet) + state_dict = torch.load(cfg.test_model, map_location='cpu') + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + if 'step' in state_dict: + resume_step = state_dict['step'] + else: + resume_step = 0 + status = model.load_state_dict(state_dict, strict=True) + logging.info('Load model from {} with status {}'.format(cfg.test_model, status)) + model = model.to(gpu) + model.eval() + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + model.to(torch.float16) + else: + model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model + torch.cuda.empty_cache() + + + + test_list = cfg.test_list_path + num_videos = len(test_list) + logging.info(f'There are {num_videos} videos. 
with {cfg.round} times') + test_list = [item for _ in range(cfg.round) for item in test_list] + + for idx, file_path in enumerate(test_list): + cfg.frame_interval, ref_image_key, pose_seq_key = file_path[0], file_path[1], file_path[2] + + manual_seed = int(cfg.seed + cfg.rank + idx//num_videos) + setup_seed(manual_seed) + logging.info(f"[{idx}]/[{len(test_list)}] Begin to sample {ref_image_key}, pose sequence from {pose_seq_key} init seed {manual_seed} ...") + + + vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data, max_frames = load_video_frames(ref_image_key, pose_seq_key, train_trans, vit_transforms, train_trans_pose, max_frames=cfg.max_frames, frame_interval =cfg.frame_interval, resolution=cfg.resolution) + cfg.max_frames_new = max_frames + misc_data = misc_data.unsqueeze(0).to(gpu) + vit_frame = vit_frame.unsqueeze(0).to(gpu) + dwpose_data = dwpose_data.unsqueeze(0).to(gpu) + random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu) + random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu) + + ### save for visualization + misc_backups = copy(misc_data) + frames_num = misc_data.shape[1] + misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w') + mv_data_video = [] + + + ### local image (first frame) + image_local = [] + if 'local_image' in cfg.video_compositions: + frames_num = misc_data.shape[1] + bs_vd_local = misc_data.shape[0] + image_local = misc_data[:,:1].clone().repeat(1,frames_num,1,1,1) + image_local_clone = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local) + image_local = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local) + if hasattr(cfg, "latent_local_image") and cfg.latent_local_image: + with torch.no_grad(): + temporal_length = frames_num + encoder_posterior = autoencoder.encode(video_data[:,0]) + local_image_data = get_first_stage_encoding(encoder_posterior).detach() + image_local = local_image_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40] + + + + ### encode the video_data + bs_vd = misc_data.shape[0] + misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w') + misc_data_list = torch.chunk(misc_data, misc_data.shape[0]//cfg.chunk_size,dim=0) + + + with torch.no_grad(): + + random_ref_frame = [] + if 'randomref' in cfg.video_compositions: + random_ref_frame_clone = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w') + if hasattr(cfg, "latent_random_ref") and cfg.latent_random_ref: + + temporal_length = random_ref_frame_data.shape[1] + encoder_posterior = autoencoder.encode(random_ref_frame_data[:,0].sub(0.5).div_(0.5)) + random_ref_frame_data = get_first_stage_encoding(encoder_posterior).detach() + random_ref_frame_data = random_ref_frame_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40] + + random_ref_frame = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w') + + + if 'dwpose' in cfg.video_compositions: + bs_vd_local = dwpose_data.shape[0] + dwpose_data_clone = rearrange(dwpose_data.clone(), 'b f c h w -> b c f h w', b = bs_vd_local) + if 'randomref_pose' in cfg.video_compositions: + dwpose_data = torch.cat([random_ref_dwpose_data[:,:1], dwpose_data], dim=1) + dwpose_data = rearrange(dwpose_data, 'b f c h w -> b c f h w', b = bs_vd_local) + + + y_visual = [] + if 'image' in cfg.video_compositions: + with torch.no_grad(): + vit_frame = vit_frame.squeeze(1) + y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1) # [60, 1024] + y_visual0 = y_visual.clone() + + + with amp.autocast(enabled=True): + # 
pynvml.nvmlInit() + # handle=pynvml.nvmlDeviceGetHandleByIndex(0) + # meminfo=pynvml.nvmlDeviceGetMemoryInfo(handle) + cur_seed = torch.initial_seed() + logging.info(f"Current seed {cur_seed} ..., cfg.max_frames_new: {cfg.max_frames_new} ....") + + noise = torch.randn([1, 4, cfg.max_frames_new, int(cfg.resolution[1]/cfg.scale), int(cfg.resolution[0]/cfg.scale)]) + noise = noise.to(gpu) + + # add a noise prior + noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 939), noise=noise) + + if hasattr(cfg.Diffusion, "noise_strength"): + b, c, f, _, _= noise.shape + offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device) + noise = noise + cfg.Diffusion.noise_strength * offset_noise + + # construct model inputs (CFG) + full_model_kwargs=[{ + 'y': None, + "local_image": None if len(image_local) == 0 else image_local[:], + 'image': None if len(y_visual) == 0 else y_visual0[:], + 'dwpose': None if len(dwpose_data) == 0 else dwpose_data[:], + 'randomref': None if len(random_ref_frame) == 0 else random_ref_frame[:], + }, + { + 'y': None, + "local_image": None, + 'image': None, + 'randomref': None, + 'dwpose': None, + }] + + # for visualization + full_model_kwargs_vis =[{ + 'y': None, + "local_image": None if len(image_local) == 0 else image_local_clone[:], + 'image': None, + 'dwpose': None if len(dwpose_data_clone) == 0 else dwpose_data_clone[:], + 'randomref': None if len(random_ref_frame) == 0 else random_ref_frame_clone[:, :3], + }, + { + 'y': None, + "local_image": None, + 'image': None, + 'randomref': None, + 'dwpose': None, + }] + + + partial_keys = [ + ['image', 'randomref', "dwpose"], + ] + if hasattr(cfg, "partial_keys") and cfg.partial_keys: + partial_keys = cfg.partial_keys + + for partial_keys_one in partial_keys: + model_kwargs_one = prepare_model_kwargs(partial_keys = partial_keys_one, + full_model_kwargs = full_model_kwargs, + use_fps_condition = cfg.use_fps_condition) + model_kwargs_one_vis = prepare_model_kwargs(partial_keys = partial_keys_one, + full_model_kwargs = full_model_kwargs_vis, + use_fps_condition = cfg.use_fps_condition) + noise_one = noise + + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + clip_encoder.cpu() # add this line + autoencoder.cpu() # add this line + torch.cuda.empty_cache() # add this line + + video_data = diffusion.ddim_sample_loop( + noise=noise_one, + context_size=cfg.context_size, + context_stride=cfg.context_stride, + context_overlap=cfg.context_overlap, + model=model.eval(), + model_kwargs=model_kwargs_one, + guide_scale=cfg.guide_scale, + ddim_timesteps=cfg.ddim_timesteps, + eta=0.0, + context_batch_size=getattr(cfg, "context_batch_size", 1) + ) + + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + # if run forward of autoencoder or clip_encoder second times, load them again + clip_encoder.cuda() + autoencoder.cuda() + + + video_data = 1. 
/ cfg.scale_factor * video_data # [1, 4, h, w] + video_data = rearrange(video_data, 'b c f h w -> (b f) c h w') + chunk_size = min(cfg.decoder_bs, video_data.shape[0]) + video_data_list = torch.chunk(video_data, video_data.shape[0]//chunk_size, dim=0) + decode_data = [] + for vd_data in video_data_list: + gen_frames = autoencoder.decode(vd_data) + decode_data.append(gen_frames) + video_data = torch.cat(decode_data, dim=0) + video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = cfg.batch_size).float() + + text_size = cfg.resolution[-1] + cap_name = re.sub(r'[^\w\s]', '', ref_image_key.split("/")[-1].split('.')[0]) # .replace(' ', '_') + name = f'seed_{cur_seed}' + for ii in partial_keys_one: + name = name + "_" + ii + file_name = f'rank_{cfg.world_size:02d}_{cfg.rank:02d}_{idx:02d}_{name}_{cap_name}_{cfg.resolution[1]}x{cfg.resolution[0]}.mp4' + local_path = os.path.join(cfg.log_dir, f'{file_name}') + os.makedirs(os.path.dirname(local_path), exist_ok=True) + captions = "human" + del model_kwargs_one_vis[0][list(model_kwargs_one_vis[0].keys())[0]] + del model_kwargs_one_vis[1][list(model_kwargs_one_vis[1].keys())[0]] + + save_video_multiple_conditions_not_gif_horizontal_3col(local_path, video_data.cpu(), model_kwargs_one_vis, misc_backups, + cfg.mean, cfg.std, nrow=1, save_fps=cfg.save_fps) + + # try: + # save_t2vhigen_video_safe(local_path, video_data.cpu(), captions, cfg.mean, cfg.std, text_size) + # logging.info('Save video to dir %s:' % (local_path)) + # except Exception as e: + # logging.info(f'Step: save text or video error with {e}') + + logging.info('Congratulations! The inference is completed!') + # synchronize to finish some processes + if not cfg.debug: + torch.cuda.synchronize() + dist.barrier() + +def prepare_model_kwargs(partial_keys, full_model_kwargs, use_fps_condition=False): + + if use_fps_condition is True: + partial_keys.append('fps') + + partial_model_kwargs = [{}, {}] + for partial_key in partial_keys: + partial_model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key] + partial_model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key] + + return partial_model_kwargs diff --git a/tools/modules/__init__.py b/tools/modules/__init__.py new file mode 100644 index 0000000..db82a43 --- /dev/null +++ b/tools/modules/__init__.py @@ -0,0 +1,7 @@ +from .clip_embedder import FrozenOpenCLIPEmbedder +from .autoencoder import DiagonalGaussianDistribution, AutoencoderKL +from .clip_embedder import * +from .autoencoder import * +from .unet import * +from .diffusions import * +from .embedding_manager import * diff --git a/tools/modules/__pycache__/__init__.cpython-310.pyc b/tools/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..550286961e2e242d1f42613892457173415c90da GIT binary patch literal 438 zcmYjNOHRWu5VeyOqS7GPu;Kz?!2u9LR4t#13Xq_iD3ayYPBqA%$aYkfHCxu4fh%Rp ziYu_fra>^4-}8H(8IP>9tdD%WyuMXEjL=6A|0D9^+;4mBfgy$klDNP)oFfs*xQJy^ zBtJS9=rQ=Qh&t|dyw&lx=i~w9y-)uZ%5i&eY4kH!cQsd|>E!-G&N*Y;LmbTd6)PY(ONfljk^Nvg!3Icq==%k!kZfpq_o^(0VvRcn3WCe=pl_j+{ zj?M-Dy(kxrby^bTXSBa2nJE}5kks+gX)SCiHEV=VAOsXRLXPWoo1Sjmv+zrH*TVqq J;j~)={0%uweAoa0 literal 0 HcmV?d00001 diff --git a/tools/modules/__pycache__/__init__.cpython-39.pyc b/tools/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4c514b48df2ae693a08da97bb7c20f6dc69e63a GIT binary patch literal 382 zcmYjMyH3L}6t(jxQE4PtnD~J(u&^P7s9Ij40wm~SMT*?Ui7xUZ@?(IRm6>ngm$EYP 
zm_{rtmX1tSG{Y2^XDx26=a+(iT`RCc`v-1$QkZtThP)vfq-9K1%d(;H-&s#CzdCk>8$BYrc${|9-}r$|6~ z?Lj2*=fDw_E$%JBMT)zZ<^W4s_Ax$U3|vRK6<*;B*2C6T>kCwPcn9QYlN8t_!g3Z{ zwrZ4wrO2<4gjK@O{3U^UH4I^dJu{-?ZutU9{}L4!sUW{3348KYDt<)88z@w561lfy zSob7)*J6HyrqJyqn<&VLvag`<-AM*(^4NdZpYk8~Ykt*t{EF}Sww~Q;8d#BK2D3Va zI-RDcZ)~P*ax+m+RSwko1xc5eD(gP_vqh(KYHl-6df90icLKRTRU2Ob literal 0 HcmV?d00001 diff --git a/tools/modules/autoencoder.py b/tools/modules/autoencoder.py new file mode 100644 index 0000000..756d188 --- /dev/null +++ b/tools/modules/autoencoder.py @@ -0,0 +1,698 @@ +import os +import torch +import logging +import collections +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from ...utils.registry_class import AUTO_ENCODER,DISTRIBUTION + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +@torch.no_grad() +def get_first_stage_encoding(encoder_posterior, scale_factor=0.18215): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return scale_factor * z + + +@AUTO_ENCODER.register_class() +class AutoencoderKL(nn.Module): + def __init__(self, + ddconfig, + embed_dim, + pretrained=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False, + use_vid_decoder=False, + **kwargs): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = ema_decay is not None + + if pretrained is not None: + # modules + current_directory = os.path.dirname(os.path.abspath(__file__)) + # tools + parent_directory = os.path.dirname(current_directory) + # uniAnimate + root_directory = os.path.dirname(parent_directory) + pretrained = os.path.join(root_directory, 'checkpoints/v2-1_512-ema-pruned.ckpt') + self.init_from_ckpt(pretrained, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + sd_new = collections.OrderedDict() + for k in keys: + if k.find('first_stage_model') >= 0: + k_new = k.split('first_stage_model.')[-1] + sd_new[k_new] = sd[k] + self.load_state_dict(sd_new, strict=True) + logging.info(f"Restored from {path}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def encode_firsr_stage(self, x, scale_factor=1.0): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + z = get_first_stage_encoding(posterior, scale_factor) + return z + + def encode_ms(self, x): + hs = 
self.encoder(x, True) + h = hs[-1] + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + hs[-1] = h + return hs + + def decode(self, z, **kwargs): + z = self.post_quant_conv(z) + dec = self.decoder(z, **kwargs) + return dec + + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) + log["reconstructions_ema"] = xrec_ema + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
+ return x + + +@AUTO_ENCODER.register_class() +class AutoencoderVideo(AutoencoderKL): + def __init__(self, + ddconfig, + embed_dim, + pretrained=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + use_vid_decoder=True, + learn_logvar=False, + **kwargs): + use_vid_decoder = True + super().__init__(ddconfig, embed_dim, pretrained, ignore_keys, image_key, colorize_nlabels, monitor, ema_decay, learn_logvar, use_vid_decoder, **kwargs) + + def decode(self, z, **kwargs): + # z = self.post_quant_conv(z) + dec = self.decoder(z, **kwargs) + return dec + + def encode(self, x): + h = self.encoder(x) + # moments = self.quant_conv(h) + moments = h + posterior = DiagonalGaussianDistribution(moments) + return posterior + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x + + + +@DISTRIBUTION.register_class() +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +# -------------------------------modules-------------------------------- + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, 
+ kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = 
with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, return_feat=False): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if return_feat: + hs[-1] = h + return hs + else: + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + 
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels, curr_res, curr_res) + # logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z, **kwargs): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + + + diff --git a/tools/modules/clip_embedder.py b/tools/modules/clip_embedder.py new file mode 100644 index 0000000..cc711ce --- /dev/null +++ b/tools/modules/clip_embedder.py @@ -0,0 +1,241 @@ +import os +import torch +import logging +import open_clip +import numpy as np +import torch.nn as nn +import torchvision.transforms as T + +from ...utils.registry_class import EMBEDDER + + +@EMBEDDER.register_class() +class FrozenOpenCLIPEmbedder(nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + #"pooled", + "last", + "penultimate" + ] + def __init__(self, pretrained, arch="ViT-H-14", device="cuda", max_length=77, + freeze=True, layer="last"): 
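+        # NOTE: the `pretrained` argument passed in is effectively ignored; the path is
+        # rebuilt below to point at checkpoints/open_clip_pytorch_model.bin inside this repo.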
+ super().__init__() + assert layer in self.LAYERS + + # modules + current_directory = os.path.dirname(os.path.abspath(__file__)) + # tools + parent_directory = os.path.dirname(current_directory) + # uniAnimate + root_directory = os.path.dirname(parent_directory) + pretrained = os.path.join(root_directory, 'checkpoints/open_clip_pytorch_model.bin') + + + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +@EMBEDDER.register_class() +class FrozenOpenCLIPVisualEmbedder(nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + #"pooled", + "last", + "penultimate" + ] + def __init__(self, pretrained, vit_resolution=(224, 224), arch="ViT-H-14", device="cuda", max_length=77, + freeze=True, layer="last"): + super().__init__() + assert layer in self.LAYERS + + # modules + current_directory = os.path.dirname(os.path.abspath(__file__)) + # tools + parent_directory = os.path.dirname(current_directory) + # uniAnimate + root_directory = os.path.dirname(parent_directory) + pretrained = os.path.join(root_directory, 'checkpoints/open_clip_pytorch_model.bin') + + + model, _, preprocess = open_clip.create_model_and_transforms( + arch, device=torch.device('cpu'), pretrained=pretrained) + + del model.transformer + self.model = model + data_white = np.ones((vit_resolution[0], vit_resolution[1], 3), dtype=np.uint8)*255 + self.white_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0) + + self.device = device + self.max_length = max_length # 77 + if freeze: + self.freeze() + self.layer = layer # 'penultimate' + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, image): + # tokens = open_clip.tokenize(text) + z = self.model.encode_image(image.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, 
attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + + +@EMBEDDER.register_class() +class FrozenOpenCLIPTextVisualEmbedder(nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + #"pooled", + "last", + "penultimate" + ] + def __init__(self, pretrained, arch="ViT-H-14", device="cuda", max_length=77, + freeze=True, layer="last", **kwargs): + super().__init__() + assert layer in self.LAYERS + + # modules + current_directory = os.path.dirname(os.path.abspath(__file__)) + # tools + parent_directory = os.path.dirname(current_directory) + # uniAnimate + root_directory = os.path.dirname(parent_directory) + pretrained = os.path.join(root_directory, 'checkpoints/open_clip_pytorch_model.bin') + + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained) + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + + def forward(self, image=None, text=None): + + xi = self.model.encode_image(image.to(self.device)) if image is not None else None + tokens = open_clip.tokenize(text) + xt, x = self.encode_with_transformer(tokens.to(self.device)) + return xi, xt, x + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + xt = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection + return xt, x + + + def encode_image(self, image): + return self.model.visual(image) + + def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + + return self(text) \ No newline at end of file diff --git a/tools/modules/config.py b/tools/modules/config.py new file mode 100644 index 0000000..9a8cc40 --- /dev/null +++ b/tools/modules/config.py @@ -0,0 +1,206 @@ +import torch +import logging +import os.path as osp +from datetime import datetime +from easydict import EasyDict +import os + +cfg = EasyDict(__name__='Config: VideoLDM Decoder') + +# -------------------------------distributed training-------------------------- +pmi_world_size = int(os.getenv('WORLD_SIZE', 1)) +gpus_per_machine = torch.cuda.device_count() +world_size = pmi_world_size * gpus_per_machine +# 
----------------------------------------------------------------------------- + + +# ---------------------------Dataset Parameter--------------------------------- +cfg.mean = [0.5, 0.5, 0.5] +cfg.std = [0.5, 0.5, 0.5] +cfg.max_words = 1000 +cfg.num_workers = 8 +cfg.prefetch_factor = 2 + +# PlaceHolder +cfg.resolution = [448, 256] +cfg.vit_out_dim = 1024 +cfg.vit_resolution = 336 +cfg.depth_clamp = 10.0 +cfg.misc_size = 384 +cfg.depth_std = 20.0 + +cfg.save_fps = 8 + +cfg.frame_lens = [32, 32, 32, 1] +cfg.sample_fps = [4, ] +cfg.vid_dataset = { + 'type': 'VideoBaseDataset', + 'data_list': [], + 'max_words': cfg.max_words, + 'resolution': cfg.resolution} +cfg.img_dataset = { + 'type': 'ImageBaseDataset', + 'data_list': ['laion_400m',], + 'max_words': cfg.max_words, + 'resolution': cfg.resolution} + +cfg.batch_sizes = { + str(1):256, + str(4):4, + str(8):4, + str(16):4} +# ----------------------------------------------------------------------------- + + +# ---------------------------Mode Parameters----------------------------------- +# Diffusion +cfg.Diffusion = { + 'type': 'DiffusionDDIM', + 'schedule': 'cosine', # cosine + 'schedule_param': { + 'num_timesteps': 1000, + 'cosine_s': 0.008, + 'zero_terminal_snr': True, + }, + 'mean_type': 'v', # [v, eps] + 'loss_type': 'mse', + 'var_type': 'fixed_small', + 'rescale_timesteps': False, + 'noise_strength': 0.1, + 'ddim_timesteps': 50 +} +cfg.ddim_timesteps = 50 # official: 250 +cfg.use_div_loss = False +# classifier-free guidance +cfg.p_zero = 0.9 +cfg.guide_scale = 3.0 + +# clip vision encoder +cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073] +cfg.vit_std = [0.26862954, 0.26130258, 0.27577711] + +# sketch +cfg.sketch_mean = [0.485, 0.456, 0.406] +cfg.sketch_std = [0.229, 0.224, 0.225] +# cfg.misc_size = 256 +cfg.depth_std = 20.0 +cfg.depth_clamp = 10.0 +cfg.hist_sigma = 10.0 + +# Model +cfg.scale_factor = 0.18215 +cfg.use_checkpoint = True +cfg.use_sharded_ddp = False +cfg.use_fsdp = False +cfg.use_fp16 = True +cfg.temporal_attention = True + +cfg.UNet = { + 'type': 'UNetSD', + 'in_dim': 4, + 'dim': 320, + 'y_dim': cfg.vit_out_dim, + 'context_dim': 1024, + 'out_dim': 8, + 'dim_mult': [1, 2, 4, 4], + 'num_heads': 8, + 'head_dim': 64, + 'num_res_blocks': 2, + 'attn_scales': [1 / 1, 1 / 2, 1 / 4], + 'dropout': 0.1, + 'temporal_attention': cfg.temporal_attention, + 'temporal_attn_times': 1, + 'use_checkpoint': cfg.use_checkpoint, + 'use_fps_condition': False, + 'use_sim_mask': False +} + +# auotoencoder from stabel diffusion +cfg.guidances = [] +cfg.auto_encoder = { + 'type': 'AutoencoderKL', + 'ddconfig': { + 'double_z': True, + 'z_channels': 4, + 'resolution': 256, + 'in_channels': 3, + 'out_ch': 3, + 'ch': 128, + 'ch_mult': [1, 2, 4, 4], + 'num_res_blocks': 2, + 'attn_resolutions': [], + 'dropout': 0.0, + 'video_kernel_size': [3, 1, 1] + }, + 'embed_dim': 4, + 'pretrained': 'models/v2-1_512-ema-pruned.ckpt' +} +# clip embedder +cfg.embedder = { + 'type': 'FrozenOpenCLIPEmbedder', + 'layer': 'penultimate', + 'pretrained': 'models/open_clip_pytorch_model.bin' +} +# ----------------------------------------------------------------------------- + +# ---------------------------Training Settings--------------------------------- +# training and optimizer +cfg.ema_decay = 0.9999 +cfg.num_steps = 600000 +cfg.lr = 5e-5 +cfg.weight_decay = 0.0 +cfg.betas = (0.9, 0.999) +cfg.eps = 1.0e-8 +cfg.chunk_size = 16 +cfg.decoder_bs = 8 +cfg.alpha = 0.7 +cfg.save_ckp_interval = 1000 + +# scheduler +cfg.warmup_steps = 10 +cfg.decay_mode = 'cosine' + +# acceleration 
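+# EMA weights are only kept when training on more than one GPU (see the world_size check below).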
+cfg.use_ema = True
+if world_size<2:
+    cfg.use_ema = False
+cfg.load_from = None
+# -----------------------------------------------------------------------------
+
+
+# ----------------------------Pretrain Settings---------------------------------
+cfg.Pretrain = {
+    'type': 'pretrain_specific_strategies',
+    'fix_weight': False,
+    'grad_scale': 0.2,
+    'resume_checkpoint': 'models/jiuniu_0267000.pth',
+    'sd_keys_path': 'models/stable_diffusion_image_key_temporal_attention_x1.json',
+}
+# -----------------------------------------------------------------------------
+
+
+# -----------------------------Visual-------------------------------------------
+# Visual videos
+cfg.viz_interval = 1000
+cfg.visual_train = {
+    'type': 'VisualTrainTextImageToVideo',
+}
+cfg.visual_inference = {
+    'type': 'VisualGeneratedVideos',
+}
+cfg.inference_list_path = ''
+
+# logging
+cfg.log_interval = 100
+
+### Default log_dir
+cfg.log_dir = 'outputs/'
+# -----------------------------------------------------------------------------
+
+
+# ---------------------------Others--------------------------------------------
+# seed
+cfg.seed = 8888
+cfg.negative_prompt = 'Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms'
+# -----------------------------------------------------------------------------
+
diff --git a/tools/modules/diffusions/__init__.py b/tools/modules/diffusions/__init__.py
new file mode 100644
index 0000000..c025248
--- /dev/null
+++ b/tools/modules/diffusions/__init__.py
@@ -0,0 +1 @@
+from .diffusion_ddim import *
diff --git a/tools/modules/diffusions/__pycache__/__init__.cpython-310.pyc b/tools/modules/diffusions/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49d69df26e7312c8db0882fdae951239163708e2
Binary files /dev/null and b/tools/modules/diffusions/__pycache__/__init__.cpython-310.pyc differ
diff --git a/tools/modules/diffusions/__pycache__/__init__.cpython-39.pyc b/tools/modules/diffusions/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1daaa1f05f3cc9e86879defe5259534f6c18524b
Binary files /dev/null and b/tools/modules/diffusions/__pycache__/__init__.cpython-39.pyc differ
diff --git a/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-310.pyc b/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1604c918835946075a50b53c4a0f03930ef4fd51
Binary files /dev/null and b/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-310.pyc differ
diff --git a/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-39.pyc b/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c86b6eb08497c3fd8d2fd8f2730db5686bde3edf
Binary files /dev/null and b/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-39.pyc differ
diff --git a/tools/modules/diffusions/__pycache__/losses.cpython-310.pyc b/tools/modules/diffusions/__pycache__/losses.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6267e7431b9170ab88bb6da345c62f816785f120
Binary files /dev/null and b/tools/modules/diffusions/__pycache__/losses.cpython-310.pyc differ
diff --git a/tools/modules/diffusions/__pycache__/schedules.cpython-310.pyc b/tools/modules/diffusions/__pycache__/schedules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c052b4cc3a5d411f353da7e7449b4e694e9c84a1
Binary files /dev/null and b/tools/modules/diffusions/__pycache__/schedules.cpython-310.pyc differ
diff --git a/tools/modules/diffusions/__pycache__/schedules.cpython-39.pyc b/tools/modules/diffusions/__pycache__/schedules.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e87214ce4b29cf9cd9d02825ba07ca975b137a1
Binary files /dev/null and b/tools/modules/diffusions/__pycache__/schedules.cpython-39.pyc differ
diff --git a/tools/modules/diffusions/diffusion_ddim.py b/tools/modules/diffusions/diffusion_ddim.py
new file mode 100644
index 0000000..43a17d3
--- /dev/null
+++ b/tools/modules/diffusions/diffusion_ddim.py
@@ -0,0 +1,1121 @@
+import torch
+import math
+
+from ....utils.registry_class import DIFFUSION
+from .schedules import beta_schedule, sigma_schedule
+from .losses import kl_divergence, discretized_gaussian_log_likelihood
+# from .dpm_solver import NoiseScheduleVP, model_wrapper_guided_diffusion, model_wrapper, DPM_Solver
+from typing import Callable, List, Optional
+import numpy as np
+
+def _i(tensor, t, x):
+    r"""Index tensor using t and format the output according to x.
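+    Gathers tensor[t] (one value per batch element), reshapes it to
+    (B, 1, ..., 1) to match x.ndim, and moves it to x's device/dtype so it
+    broadcasts against x.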
+ """ + if tensor.device != x.device: + tensor = tensor.to(x.device) + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + return tensor[t].view(shape).to(x) + +@DIFFUSION.register_class() +class DiffusionDDIMSR(object): + def __init__(self, reverse_diffusion, forward_diffusion, **kwargs): + from .diffusion_gauss import GaussianDiffusion + self.reverse_diffusion = GaussianDiffusion(sigmas=sigma_schedule(reverse_diffusion.schedule, **reverse_diffusion.schedule_param), + prediction_type=reverse_diffusion.mean_type) + self.forward_diffusion = GaussianDiffusion(sigmas=sigma_schedule(forward_diffusion.schedule, **forward_diffusion.schedule_param), + prediction_type=forward_diffusion.mean_type) + + +@DIFFUSION.register_class() +class DiffusionDPM(object): + def __init__(self, forward_diffusion, **kwargs): + from .diffusion_gauss import GaussianDiffusion + self.forward_diffusion = GaussianDiffusion(sigmas=sigma_schedule(forward_diffusion.schedule, **forward_diffusion.schedule_param), + prediction_type=forward_diffusion.mean_type) + + +@DIFFUSION.register_class() +class DiffusionDDIM(object): + def __init__(self, + schedule='linear_sd', + schedule_param={}, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + epsilon = 1e-12, + rescale_timesteps=False, + noise_strength=0.0, + **kwargs): + + assert mean_type in ['x0', 'x_{t-1}', 'eps', 'v'] + assert var_type in ['learned', 'learned_range', 'fixed_large', 'fixed_small'] + assert loss_type in ['mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1','charbonnier'] + + betas = beta_schedule(schedule, **schedule_param) + assert min(betas) > 0 and max(betas) <= 1 + + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type # eps + self.var_type = var_type # 'fixed_small' + self.loss_type = loss_type # mse + self.epsilon = epsilon # 1e-12 + self.rescale_timesteps = rescale_timesteps # False + self.noise_strength = noise_strength # 0.0 + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat([alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat([self.alphas_cumprod[1:], alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - self.alphas_cumprod) + + + def sample_loss(self, x0, noise=None): + if noise is None: + noise = torch.randn_like(x0) + if self.noise_strength > 0: + b, c, f, _, _= x0.shape + offset_noise = torch.randn(b, c, f, 1, 1, device=x0.device) + noise = noise + self.noise_strength * offset_noise + return noise + + + def q_sample(self, x0, t, noise=None): + r"""Sample from q(x_t | x_0). 
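+        Uses the closed form
+            x_t = sqrt(alphas_cumprod[t]) * x0 + sqrt(1 - alphas_cumprod[t]) * noise.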
+ """ + # noise = torch.randn_like(x0) if noise is None else noise + noise = self.sample_loss(x0, noise) + return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + \ + _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise + + def q_mean_variance(self, x0, t): + r"""Distribution of q(x_t | x_0). + """ + mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + var = _i(1.0 - self.alphas_cumprod, t, x0) + log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) + return mu, var, log_var + + def q_posterior_mean_variance(self, x0, xt, t): + r"""Distribution of q(x_{t-1} | x_t, x_0). + """ + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def p_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None): + r"""Sample from p(x_{t-1} | x_t). + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + # predict distribution of p(x_{t-1} | x_t) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + + # random sample (with optional conditional function) + noise = torch.randn_like(xt) + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) # no noise when t == 0 + if condition_fn is not None: + grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + mu = mu.float() + var * grad.float() + xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise + return xt_1, x0 + + @torch.no_grad() + def p_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None): + r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). + """ + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale) + return xt + + def p_mean_variance(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None): + r"""Distribution of p(x_{t-1} | x_t). 
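+        The network output is first converted to an x0 estimate according to
+        self.mean_type ('x0', 'x_{t-1}', 'eps' or 'v'); e.g. for 'v' prediction,
+            x0 = sqrt(alphas_cumprod[t]) * x_t - sqrt(1 - alphas_cumprod[t]) * v.
+        The returned mean/variance are then those of q(x_{t-1} | x_t, x0).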
+ """ + # predict distribution + if guide_scale is None: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) + u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) + dim = y_out.size(1) if self.var_type.startswith('fixed') else y_out.size(1) // 2 + out = torch.cat([ + u_out[:, :dim] + guide_scale * (y_out[:, :dim] - u_out[:, :dim]), + y_out[:, dim:]], dim=1) # guide_scale=9.0 + + # compute variance + if self.var_type == 'learned': + out, log_var = out.chunk(2, dim=1) + var = torch.exp(log_var) + elif self.var_type == 'learned_range': + out, fraction = out.chunk(2, dim=1) + min_log_var = _i(self.posterior_log_variance_clipped, t, xt) + max_log_var = _i(torch.log(self.betas), t, xt) + fraction = (fraction + 1) / 2.0 + log_var = fraction * max_log_var + (1 - fraction) * min_log_var + var = torch.exp(log_var) + elif self.var_type == 'fixed_large': + var = _i(torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, xt) + log_var = torch.log(var) + elif self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'x_{t-1}': + mu = out # x_{t-1} + x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \ + _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt + elif self.mean_type == 'x0': + x0 = out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'eps': + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'v': + x0 = _i(self.sqrt_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 # e.g., 0.995 + s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + @torch.no_grad() + def ddim_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0): + r"""Sample from p(x_{t-1} | x_t) using DDIM. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). 
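+        - one call advances by stride = num_timesteps // ddim_timesteps, using
+              sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
+              x_prev  = sqrt(a_prev) * x0 + sqrt(1 - a_prev - sigma_t^2) * eps + sigma_t * z
+          with a_t = alphas_cumprod[t]; eta = 0 gives the deterministic DDIM update.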
+        """
+        stride = self.num_timesteps // ddim_timesteps
+
+        # predict distribution of p(x_{t-1} | x_t)
+        _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
+        if condition_fn is not None:
+            # x0 -> eps
+            alpha = _i(self.alphas_cumprod, t, xt)
+            eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
+                  _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
+            eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs)
+
+            # eps -> x0
+            x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \
+                 _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps
+
+        # derive variables
+        eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
+              _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
+        alphas = _i(self.alphas_cumprod, t, xt)
+        alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt)
+        sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+
+        # random sample
+        noise = torch.randn_like(xt)
+        direction = torch.sqrt(1 - alphas_prev - sigmas ** 2) * eps
+        mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1)))
+        xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise
+        return xt_1, x0
+
+    @torch.no_grad()
+    def ddim_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0):
+        # prepare input
+        b = noise.size(0)
+        xt = noise
+
+        # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps)
+        steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps)).clamp(0, self.num_timesteps - 1).flip(0)
+        from tqdm import tqdm
+        for step in tqdm(steps):
+            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
+            xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, ddim_timesteps, eta)
+            # from ipdb import set_trace; set_trace()
+        return xt
+
+    @torch.no_grad()
+    def ddim_reverse_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20):
+        r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).
+        """
+        stride = self.num_timesteps // ddim_timesteps
+
+        # predict distribution of p(x_{t-1} | x_t)
+        _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale)
+
+        # derive variables
+        eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \
+              _i(self.sqrt_recipm1_alphas_cumprod, t, xt)
+        alphas_next = _i(
+            torch.cat([self.alphas_cumprod, self.alphas_cumprod.new_zeros([1])]),
+            (t + stride).clamp(0, self.num_timesteps), xt)
+
+        # reverse sample
+        mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps
+        return mu, x0
+
+    @torch.no_grad()
+    def ddim_reverse_sample_loop(self, x0, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20):
+        # prepare input
+        b = x0.size(0)
+        xt = x0
+
+        # reconstruction steps
+        steps = torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps)
+        for step in steps:
+            t = torch.full((b, ), step, dtype=torch.long, device=xt.device)
+            xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, percentile, guide_scale, ddim_timesteps)
+        return xt
+
+    @torch.no_grad()
+    def plms_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20, eps_cache=None):
+        r"""Sample from p(x_{t-1} | x_t) using PLMS.
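+        Cached eps predictions are combined with pseudo linear multistep
+        (Adams-Bashforth) coefficients: (3*e0 - e1)/2, (23*e0 - 16*e1 + 5*e2)/12
+        or (55*e0 - 59*e1 + 37*e2 - 9*e3)/24, depending on how many are cached;
+        the very first step falls back to an improved-Euler style average.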
+ - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + stride = self.num_timesteps // plms_timesteps + + # function for compute eps + def compute_eps(xt, t): + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + + # condition + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive eps + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + return eps + + # function for compute x_0 and x_{t-1} + def compute_x0(eps, t): + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # deterministic sample + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + direction = torch.sqrt(1 - alphas_prev) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + return xt_1, x0 + + # PLMS sample + eps = compute_eps(xt, t) + if len(eps_cache) == 0: + # 2nd order pseudo improved Euler + xt_1, x0 = compute_x0(eps, t) + eps_next = compute_eps(xt_1, (t - stride).clamp(0)) + eps_prime = (eps + eps_next) / 2.0 + elif len(eps_cache) == 1: + # 2nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (3 * eps - eps_cache[-1]) / 2.0 + elif len(eps_cache) == 2: + # 3nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (23 * eps - 16 * eps_cache[-1] + 5 * eps_cache[-2]) / 12.0 + elif len(eps_cache) >= 3: + # 4nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] - 9 * eps_cache[-3]) / 24.0 + xt_1, x0 = compute_x0(eps_prime, t) + return xt_1, x0, eps + + @torch.no_grad() + def plms_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // plms_timesteps)).clamp(0, self.num_timesteps - 1).flip(0) + eps_cache = [] + for step in steps: + # PLMS sampling step + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, plms_timesteps, eps_cache) + + # update eps cache + eps_cache.append(eps) + if len(eps_cache) >= 4: + eps_cache.pop(0) + return xt + + def loss(self, x0, t, model, model_kwargs={}, noise=None, weight = None, use_div_loss= False, loss_mask=None): + + # noise = torch.randn_like(x0) if noise is None else noise # [80, 4, 8, 32, 32] + noise = self.sample_loss(x0, noise) + + xt = self.q_sample(x0, t, noise=noise) + + # compute loss + if self.loss_type in ['kl', 'rescaled_kl']: + loss, _ = self.variational_lower_bound(x0, xt, t, model, model_kwargs) + if self.loss_type == 'rescaled_kl': + loss = loss * self.num_timesteps + elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: # self.loss_type: mse + out = model(xt, self._scale_timesteps(t), 
**model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: # self.var_type: 'fixed_small' + out, var = out.chunk(2, dim=1) + frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + # target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type] + target = { + 'eps': noise, + 'x0': x0, + 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0], + 'v':_i(self.sqrt_alphas_cumprod, t, xt) * noise - _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * x0}[self.mean_type] + if loss_mask is not None: + loss_mask = loss_mask[:, :, 0, ...].unsqueeze(2) # just use one channel (all channels are same) + loss_mask = loss_mask.permute(0, 2, 1, 3, 4) # b,c,f,h,w + # use masked diffusion + loss = (out * loss_mask - target * loss_mask).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1) + else: + loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1) + if weight is not None: + loss = loss*weight + + # div loss + if use_div_loss and self.mean_type == 'eps' and x0.shape[2]>1: + + # derive x0 + x0_ = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out + + # # derive xt_1, set eta=0 as ddim + # alphas_prev = _i(self.alphas_cumprod, (t - 1).clamp(0), xt) + # direction = torch.sqrt(1 - alphas_prev) * out + # xt_1 = torch.sqrt(alphas_prev) * x0_ + direction + + # ncfhw, std on f + div_loss = 0.001/(x0_.std(dim=2).flatten(1).mean(dim=1)+1e-4) + # print(div_loss,loss) + loss = loss+div_loss + + # total loss + loss = loss + loss_vlb + elif self.loss_type in ['charbonnier']: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: + out, var = out.chunk(2, dim=1) + frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type] + loss = torch.sqrt((out - target)**2 + self.epsilon) + if weight is not None: + loss = loss*weight + loss = loss.flatten(1).mean(dim=1) + + # total loss + loss = loss + loss_vlb + return loss + + def variational_lower_bound(self, x0, xt, t, model, model_kwargs={}, clamp=None, percentile=None): + # compute groundtruth and predicted distributions + mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) + mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile) + + # compute KL loss + kl = kl_divergence(mu1, log_var1, mu2, log_var2) + kl = kl.flatten(1).mean(dim=1) / math.log(2.0) + + # compute discretized NLL loss (for p(x0 | x1) only) + nll = -discretized_gaussian_log_likelihood(x0, mean=mu2, log_scale=0.5 * log_var2) + nll = nll.flatten(1).mean(dim=1) / math.log(2.0) + + # NLL for p(x0 | x1) and KL otherwise + vlb = torch.where(t == 0, nll, kl) + return vlb, x0 + + @torch.no_grad() + def variational_lower_bound_loop(self, x0, model, 
model_kwargs={}, clamp=None, percentile=None): + r"""Compute the entire variational lower bound, measured in bits-per-dim. + """ + # prepare input and output + b = x0.size(0) + metrics = {'vlb': [], 'mse': [], 'x0_mse': []} + + # loop + for step in torch.arange(self.num_timesteps).flip(0): + # compute VLB + t = torch.full((b, ), step, dtype=torch.long, device=x0.device) + # noise = torch.randn_like(x0) + noise = self.sample_loss(x0) + xt = self.q_sample(x0, t, noise) + vlb, pred_x0 = self.variational_lower_bound(x0, xt, t, model, model_kwargs, clamp, percentile) + + # predict eps from x0 + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + + # collect metrics + metrics['vlb'].append(vlb) + metrics['x0_mse'].append((pred_x0 - x0).square().flatten(1).mean(dim=1)) + metrics['mse'].append((eps - noise).square().flatten(1).mean(dim=1)) + metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} + + # compute the prior KL term for VLB, measured in bits-per-dim + mu, _, log_var = self.q_mean_variance(x0, t) + kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), torch.zeros_like(log_var)) + kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) + + # update metrics + metrics['prior_bits_per_dim'] = kl_prior + metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior + return metrics + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t + #return t.float() + + + + + + +@DIFFUSION.register_class() +class DiffusionDDIMLong(object): + def __init__(self, + schedule='linear_sd', + schedule_param={}, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + epsilon = 1e-12, + rescale_timesteps=False, + noise_strength=0.0, + **kwargs): + + assert mean_type in ['x0', 'x_{t-1}', 'eps', 'v'] + assert var_type in ['learned', 'learned_range', 'fixed_large', 'fixed_small'] + assert loss_type in ['mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1','charbonnier'] + + betas = beta_schedule(schedule, **schedule_param) + assert min(betas) > 0 and max(betas) <= 1 + + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type # v + self.var_type = var_type # 'fixed_small' + self.loss_type = loss_type # mse + self.epsilon = epsilon # 1e-12 + self.rescale_timesteps = rescale_timesteps # False + self.noise_strength = noise_strength + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat([alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat([self.alphas_cumprod[1:], alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = 
(1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - self.alphas_cumprod) + + + def sample_loss(self, x0, noise=None): + if noise is None: + noise = torch.randn_like(x0) + if self.noise_strength > 0: + b, c, f, _, _= x0.shape + offset_noise = torch.randn(b, c, f, 1, 1, device=x0.device) + noise = noise + self.noise_strength * offset_noise + return noise + + + def q_sample(self, x0, t, noise=None): + r"""Sample from q(x_t | x_0). + """ + # noise = torch.randn_like(x0) if noise is None else noise + noise = self.sample_loss(x0, noise) + return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + \ + _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise + + def q_mean_variance(self, x0, t): + r"""Distribution of q(x_t | x_0). + """ + mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + var = _i(1.0 - self.alphas_cumprod, t, x0) + log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) + return mu, var, log_var + + def q_posterior_mean_variance(self, x0, xt, t): + r"""Distribution of q(x_{t-1} | x_t, x_0). + """ + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def p_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None): + r"""Sample from p(x_{t-1} | x_t). + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + # predict distribution of p(x_{t-1} | x_t) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + + # random sample (with optional conditional function) + noise = torch.randn_like(xt) + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) # no noise when t == 0 + if condition_fn is not None: + grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + mu = mu.float() + var * grad.float() + xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise + return xt_1, x0 + + @torch.no_grad() + def p_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None): + r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). + """ + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale) + return xt + + def p_mean_variance(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, context_size=32, context_stride=1, context_overlap=4, context_batch_size=1): + r"""Distribution of p(x_{t-1} | x_t). 
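+        The noisy latent is denoised in overlapping temporal windows produced by context_scheduler
+        (length context_size, overlap context_overlap); per-window predictions are accumulated into
+        per-frame buffers, divided by a usage counter, and classifier-free guidance is applied to the
+        averaged conditional / unconditional outputs.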
+ """ + noise = xt + context_queue = list( + context_scheduler( + 0, + 31, + noise.shape[2], + context_size=context_size, + context_stride=1, + context_overlap=4, + ) + ) + context_step = min( + context_stride, int(np.ceil(np.log2(noise.shape[2] / context_size))) + 1 + ) + # replace the final segment to improve temporal consistency + num_frames = noise.shape[2] + context_queue[-1] = [ + e % num_frames + for e in range(num_frames - context_size * context_step, num_frames, context_step) + ] + + import math + # context_batch_size = 1 + num_context_batches = math.ceil(len(context_queue) / context_batch_size) + global_context = [] + for i in range(num_context_batches): + global_context.append( + context_queue[ + i * context_batch_size : (i + 1) * context_batch_size + ] + ) + noise_pred = torch.zeros_like(noise) + noise_pred_uncond = torch.zeros_like(noise) + counter = torch.zeros( + (1, 1, xt.shape[2], 1, 1), + device=xt.device, + dtype=xt.dtype, + ) + + for i_index, context in enumerate(global_context): + + + latent_model_input = torch.cat([xt[:, :, c] for c in context]) + bs_context = len(context) + + model_kwargs_new = [{ + 'y': None, + "local_image": None if not model_kwargs[0].__contains__('local_image') else torch.cat([model_kwargs[0]["local_image"][:, :, c] for c in context]), + 'image': None if not model_kwargs[0].__contains__('image') else model_kwargs[0]["image"].repeat(bs_context, 1, 1), + 'dwpose': None if not model_kwargs[0].__contains__('dwpose') else torch.cat([model_kwargs[0]["dwpose"][:, :, [0]+[ii+1 for ii in c]] for c in context]), + 'randomref': None if not model_kwargs[0].__contains__('randomref') else torch.cat([model_kwargs[0]["randomref"][:, :, c] for c in context]), + }, + { + 'y': None, + "local_image": None, + 'image': None, + 'randomref': None, + 'dwpose': None, + }] + + if guide_scale is None: + out = model(latent_model_input, self._scale_timesteps(t), **model_kwargs) + for j, c in enumerate(context): + noise_pred[:, :, c] = noise_pred[:, :, c] + out + counter[:, :, c] = counter[:, :, c] + 1 + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + # assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(latent_model_input, self._scale_timesteps(t).repeat(bs_context), **model_kwargs_new[0]) + u_out = model(latent_model_input, self._scale_timesteps(t).repeat(bs_context), **model_kwargs_new[1]) + dim = y_out.size(1) if self.var_type.startswith('fixed') else y_out.size(1) // 2 + for j, c in enumerate(context): + noise_pred[:, :, c] = noise_pred[:, :, c] + y_out[j:j+1] + noise_pred_uncond[:, :, c] = noise_pred_uncond[:, :, c] + u_out[j:j+1] + counter[:, :, c] = counter[:, :, c] + 1 + + noise_pred = noise_pred / counter + noise_pred_uncond = noise_pred_uncond / counter + out = torch.cat([ + noise_pred_uncond[:, :dim] + guide_scale * (noise_pred[:, :dim] - noise_pred_uncond[:, :dim]), + noise_pred[:, dim:]], dim=1) # guide_scale=2.5 + + + # compute variance + if self.var_type == 'learned': + out, log_var = out.chunk(2, dim=1) + var = torch.exp(log_var) + elif self.var_type == 'learned_range': + out, fraction = out.chunk(2, dim=1) + min_log_var = _i(self.posterior_log_variance_clipped, t, xt) + max_log_var = _i(torch.log(self.betas), t, xt) + fraction = (fraction + 1) / 2.0 + log_var = fraction * max_log_var + (1 - fraction) * min_log_var + var = torch.exp(log_var) + elif self.var_type == 'fixed_large': + var = _i(torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, xt) + 
log_var = torch.log(var) + elif self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'x_{t-1}': + mu = out # x_{t-1} + x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \ + _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt + elif self.mean_type == 'x0': + x0 = out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'eps': + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'v': + x0 = _i(self.sqrt_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 # e.g., 0.995 + s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + @torch.no_grad() + def ddim_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0, context_size=32, context_stride=1, context_overlap=4, context_batch_size=1): + r"""Sample from p(x_{t-1} | x_t) using DDIM. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale, context_size, context_stride, context_overlap, context_batch_size) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas ** 2) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise + return xt_1, x0 + + @torch.no_grad() + def ddim_sample_loop(self, noise, context_size, context_stride, context_overlap, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0, context_batch_size=1): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process (TODO: clamp is inaccurate! 
Consider replacing the stride by explicit prev/next steps) + steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps)).clamp(0, self.num_timesteps - 1).flip(0) + from tqdm import tqdm + + for step in tqdm(steps): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, ddim_timesteps, eta, context_size=context_size, context_stride=context_stride, context_overlap=context_overlap, context_batch_size=context_batch_size) + return xt + + @torch.no_grad() + def ddim_reverse_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20): + r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas_next = _i( + torch.cat([self.alphas_cumprod, self.alphas_cumprod.new_zeros([1])]), + (t + stride).clamp(0, self.num_timesteps), xt) + + # reverse sample + mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop(self, x0, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20): + # prepare input + b = x0.size(0) + xt = x0 + + # reconstruction steps + steps = torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, percentile, guide_scale, ddim_timesteps) + return xt + + @torch.no_grad() + def plms_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20): + r"""Sample from p(x_{t-1} | x_t) using PLMS. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). 
+ """ + stride = self.num_timesteps // plms_timesteps + + # function for compute eps + def compute_eps(xt, t): + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + + # condition + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive eps + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + return eps + + # function for compute x_0 and x_{t-1} + def compute_x0(eps, t): + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # deterministic sample + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + direction = torch.sqrt(1 - alphas_prev) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + return xt_1, x0 + + # PLMS sample + eps = compute_eps(xt, t) + if len(eps_cache) == 0: + # 2nd order pseudo improved Euler + xt_1, x0 = compute_x0(eps, t) + eps_next = compute_eps(xt_1, (t - stride).clamp(0)) + eps_prime = (eps + eps_next) / 2.0 + elif len(eps_cache) == 1: + # 2nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (3 * eps - eps_cache[-1]) / 2.0 + elif len(eps_cache) == 2: + # 3nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (23 * eps - 16 * eps_cache[-1] + 5 * eps_cache[-2]) / 12.0 + elif len(eps_cache) >= 3: + # 4nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] - 9 * eps_cache[-3]) / 24.0 + xt_1, x0 = compute_x0(eps_prime, t) + return xt_1, x0, eps + + @torch.no_grad() + def plms_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // plms_timesteps)).clamp(0, self.num_timesteps - 1).flip(0) + eps_cache = [] + for step in steps: + # PLMS sampling step + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, plms_timesteps, eps_cache) + + # update eps cache + eps_cache.append(eps) + if len(eps_cache) >= 4: + eps_cache.pop(0) + return xt + + def loss(self, x0, t, model, model_kwargs={}, noise=None, weight = None, use_div_loss= False, loss_mask=None): + + # noise = torch.randn_like(x0) if noise is None else noise # [80, 4, 8, 32, 32] + noise = self.sample_loss(x0, noise) + + xt = self.q_sample(x0, t, noise=noise) + + # compute loss + if self.loss_type in ['kl', 'rescaled_kl']: + loss, _ = self.variational_lower_bound(x0, xt, t, model, model_kwargs) + if self.loss_type == 'rescaled_kl': + loss = loss * self.num_timesteps + elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: # self.loss_type: mse + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: # self.var_type: 'fixed_small' + 
out, var = out.chunk(2, dim=1) + frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + # target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type] + target = { + 'eps': noise, + 'x0': x0, + 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0], + 'v':_i(self.sqrt_alphas_cumprod, t, xt) * noise - _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * x0}[self.mean_type] + if loss_mask is not None: + loss_mask = loss_mask[:, :, 0, ...].unsqueeze(2) # just use one channel (all channels are same) + loss_mask = loss_mask.permute(0, 2, 1, 3, 4) # b,c,f,h,w + # use masked diffusion + loss = (out * loss_mask - target * loss_mask).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1) + else: + loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1) + if weight is not None: + loss = loss*weight + + # div loss + if use_div_loss and self.mean_type == 'eps' and x0.shape[2]>1: + + # derive x0 + x0_ = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out + + + # ncfhw, std on f + div_loss = 0.001/(x0_.std(dim=2).flatten(1).mean(dim=1)+1e-4) + # print(div_loss,loss) + loss = loss+div_loss + + # total loss + loss = loss + loss_vlb + elif self.loss_type in ['charbonnier']: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: + out, var = out.chunk(2, dim=1) + frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type] + loss = torch.sqrt((out - target)**2 + self.epsilon) + if weight is not None: + loss = loss*weight + loss = loss.flatten(1).mean(dim=1) + + # total loss + loss = loss + loss_vlb + return loss + + def variational_lower_bound(self, x0, xt, t, model, model_kwargs={}, clamp=None, percentile=None): + # compute groundtruth and predicted distributions + mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) + mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile) + + # compute KL loss + kl = kl_divergence(mu1, log_var1, mu2, log_var2) + kl = kl.flatten(1).mean(dim=1) / math.log(2.0) + + # compute discretized NLL loss (for p(x0 | x1) only) + nll = -discretized_gaussian_log_likelihood(x0, mean=mu2, log_scale=0.5 * log_var2) + nll = nll.flatten(1).mean(dim=1) / math.log(2.0) + + # NLL for p(x0 | x1) and KL otherwise + vlb = torch.where(t == 0, nll, kl) + return vlb, x0 + + @torch.no_grad() + def variational_lower_bound_loop(self, x0, model, model_kwargs={}, clamp=None, percentile=None): + r"""Compute the entire variational lower bound, measured in bits-per-dim. 
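+        Each per-step KL / NLL term is averaged over dimensions and divided by log(2); the prior KL
+        against N(0, I) is added to obtain 'total_bits_per_dim'.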
+ """ + # prepare input and output + b = x0.size(0) + metrics = {'vlb': [], 'mse': [], 'x0_mse': []} + + # loop + for step in torch.arange(self.num_timesteps).flip(0): + # compute VLB + t = torch.full((b, ), step, dtype=torch.long, device=x0.device) + # noise = torch.randn_like(x0) + noise = self.sample_loss(x0) + xt = self.q_sample(x0, t, noise) + vlb, pred_x0 = self.variational_lower_bound(x0, xt, t, model, model_kwargs, clamp, percentile) + + # predict eps from x0 + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + + # collect metrics + metrics['vlb'].append(vlb) + metrics['x0_mse'].append((pred_x0 - x0).square().flatten(1).mean(dim=1)) + metrics['mse'].append((eps - noise).square().flatten(1).mean(dim=1)) + metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} + + # compute the prior KL term for VLB, measured in bits-per-dim + mu, _, log_var = self.q_mean_variance(x0, t) + kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), torch.zeros_like(log_var)) + kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) + + # update metrics + metrics['prior_bits_per_dim'] = kl_prior + metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior + return metrics + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t + #return t.float() + + + +def ordered_halving(val): + bin_str = f"{val:064b}" + bin_flip = bin_str[::-1] + as_int = int(bin_flip, 2) + + return as_int / (1 << 64) + + +def context_scheduler( + step: int = ..., + num_steps: Optional[int] = None, + num_frames: int = ..., + context_size: Optional[int] = None, + context_stride: int = 3, + context_overlap: int = 4, + closed_loop: bool = False, +): + if num_frames <= context_size: + yield list(range(num_frames)) + return + + context_stride = min( + context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1 + ) + + for context_step in 1 << np.arange(context_stride): + pad = int(round(num_frames * ordered_halving(step))) + for j in range( + int(ordered_halving(step) * context_step) + pad, + num_frames + pad + (0 if closed_loop else -context_overlap), + (context_size * context_step - context_overlap), + ): + + yield [ + e % num_frames + for e in range(j, j + context_size * context_step, context_step) + ] + diff --git a/tools/modules/diffusions/diffusion_gauss.py b/tools/modules/diffusions/diffusion_gauss.py new file mode 100644 index 0000000..430ab3d --- /dev/null +++ b/tools/modules/diffusions/diffusion_gauss.py @@ -0,0 +1,498 @@ +""" +GaussianDiffusion wraps operators for denoising diffusion models, including the +diffusion and denoising processes, as well as the loss evaluation. +""" +import torch +import torchsde +import random +from tqdm.auto import trange + + +__all__ = ['GaussianDiffusion'] + + +def _i(tensor, t, x): + """ + Index tensor using t and format the output according to x. + """ + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + return tensor[t.to(tensor.device)].view(shape).to(x.device) + + +class BatchedBrownianTree: + """ + A wrapper around torchsde.BrownianTree that enables batches of entropy. 
+ """ + def __init__(self, x, t0, t1, seed=None, **kwargs): + t0, t1, self.sign = self.sort(t0, t1) + w0 = kwargs.get('w0', torch.zeros_like(x)) + if seed is None: + seed = torch.randint(0, 2 ** 63 - 1, []).item() + self.batched = True + try: + assert len(seed) == x.shape[0] + w0 = w0[0] + except TypeError: + seed = [seed] + self.batched = False + self.trees = [torchsde.BrownianTree( + t0, w0, t1, entropy=s, **kwargs + ) for s in seed] + + @staticmethod + def sort(a, b): + return (a, b, 1) if a < b else (b, a, -1) + + def __call__(self, t0, t1): + t0, t1, sign = self.sort(t0, t1) + w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + return w if self.batched else w[0] + + +class BrownianTreeNoiseSampler: + """ + A noise sampler backed by a torchsde.BrownianTree. + + Args: + x (Tensor): The tensor whose shape, device and dtype to use to generate + random samples. + sigma_min (float): The low end of the valid interval. + sigma_max (float): The high end of the valid interval. + seed (int or List[int]): The random seed. If a list of seeds is + supplied instead of a single integer, then the noise sampler will + use one BrownianTree per batch item, each with its own seed. + transform (callable): A function that maps sigma to the sampler's + internal timestep. + """ + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + self.transform = transform + t0 = self.transform(torch.as_tensor(sigma_min)) + t1 = self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed) + + def __call__(self, sigma, sigma_next): + t0 = self.transform(torch.as_tensor(sigma)) + t1 = self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() + + +def get_scalings(sigma): + c_out = -sigma + c_in = 1 / (sigma ** 2 + 1. ** 2) ** 0.5 + return c_out, c_in + + +@torch.no_grad() +def sample_dpmpp_2m_sde( + noise, + model, + sigmas, + eta=1., + s_noise=1., + solver_type='midpoint', + show_progress=True +): + """ + DPM-Solver++ (2M) SDE. 
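+    Integrates in negative-log-sigma time (t = -log(sigma)); eta scales the injected Brownian noise,
+    and the previous denoised output supplies a second-order correction ('heun' or 'midpoint').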
+ """ + assert solver_type in {'heun', 'midpoint'} + + x = noise * sigmas[0] + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[sigmas < float('inf')].max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) + old_denoised = None + h_last = None + + for i in trange(len(sigmas) - 1, disable=not show_progress): + if sigmas[i] == float('inf'): + # Euler method + denoised = model(noise, sigmas[i]) + x = denoised + sigmas[i + 1] * noise + else: + _, c_in = get_scalings(sigmas[i]) + denoised = model(x * c_in, sigmas[i]) + if sigmas[i + 1] == 0: + # Denoising step + x = denoised + else: + # DPM-Solver++(2M) SDE + t, s = -sigmas[i].log(), -sigmas[i + 1].log() + h = s - t + eta_h = eta * h + + x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \ + (-h - eta_h).expm1().neg() * denoised + + if old_denoised is not None: + r = h_last / h + if solver_type == 'heun': + x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \ + (1 / r) * (denoised - old_denoised) + elif solver_type == 'midpoint': + x = x + 0.5 * (-h - eta_h).expm1().neg() * \ + (1 / r) * (denoised - old_denoised) + + x = x + noise_sampler( + sigmas[i], + sigmas[i + 1] + ) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + + old_denoised = denoised + h_last = h + return x + + +class GaussianDiffusion(object): + + def __init__(self, sigmas, prediction_type='eps'): + assert prediction_type in {'x0', 'eps', 'v'} + self.sigmas = sigmas.float() # noise coefficients + self.alphas = torch.sqrt(1 - sigmas ** 2).float() # signal coefficients + self.num_timesteps = len(sigmas) + self.prediction_type = prediction_type + + def diffuse(self, x0, t, noise=None): + """ + Add Gaussian noise to signal x0 according to: + q(x_t | x_0) = N(x_t | alpha_t x_0, sigma_t^2 I). + """ + noise = torch.randn_like(x0) if noise is None else noise + xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise + return xt + + def denoise( + self, + xt, + t, + s, + model, + model_kwargs={}, + guide_scale=None, + guide_rescale=None, + clamp=None, + percentile=None + ): + """ + Apply one step of denoising from the posterior distribution q(x_s | x_t, x0). + Since x0 is not available, estimate the denoising results using the learned + distribution p(x_s | x_t, \hat{x}_0 == f(x_t)). + """ + s = t - 1 if s is None else s + + # hyperparams + sigmas = _i(self.sigmas, t, xt) + alphas = _i(self.alphas, t, xt) + alphas_s = _i(self.alphas, s.clamp(0), xt) + alphas_s[s < 0] = 1. 
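+        # coefficients at the earlier step s (alphas_s ** 2 + sigmas_s ** 2 == 1); together with
+        # coef1 / coef2 below they parameterize the posterior q(x_s | x_t, x0) for one denoising step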
+ sigmas_s = torch.sqrt(1 - alphas_s ** 2) + + # precompute variables + betas = 1 - (alphas / alphas_s) ** 2 + coef1 = betas * alphas_s / sigmas ** 2 + coef2 = (alphas * sigmas_s ** 2) / (alphas_s * sigmas ** 2) + var = betas * (sigmas_s / sigmas) ** 2 + log_var = torch.log(var).clamp_(-20, 20) + + # prediction + if guide_scale is None: + assert isinstance(model_kwargs, dict) + out = model(xt, t=t, **model_kwargs) + else: + # classifier-free guidance (arXiv:2207.12598) + # model_kwargs[0]: conditional kwargs + # model_kwargs[1]: non-conditional kwargs + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(xt, t=t, **model_kwargs[0]) + if guide_scale == 1.: + out = y_out + else: + u_out = model(xt, t=t, **model_kwargs[1]) + out = u_out + guide_scale * (y_out - u_out) + + # rescale the output according to arXiv:2305.08891 + if guide_rescale is not None: + assert guide_rescale >= 0 and guide_rescale <= 1 + ratio = (y_out.flatten(1).std(dim=1) / ( + out.flatten(1).std(dim=1) + 1e-12 + )).view((-1, ) + (1, ) * (y_out.ndim - 1)) + out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0 + + # compute x0 + if self.prediction_type == 'x0': + x0 = out + elif self.prediction_type == 'eps': + x0 = (xt - sigmas * out) / alphas + elif self.prediction_type == 'v': + x0 = alphas * xt - sigmas * out + else: + raise NotImplementedError( + f'prediction_type {self.prediction_type} not implemented' + ) + + # restrict the range of x0 + if percentile is not None: + # NOTE: percentile should only be used when data is within range [-1, 1] + assert percentile > 0 and percentile <= 1 + s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1) + s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1)) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + + # recompute eps using the restricted x0 + eps = (xt - alphas * x0) / sigmas + + # compute mu (mean of posterior distribution) using the restricted x0 + mu = coef1 * x0 + coef2 * xt + return mu, var, log_var, x0, eps + + @torch.no_grad() + def sample( + self, + noise, + model, + model_kwargs={}, + condition_fn=None, + guide_scale=None, + guide_rescale=None, + clamp=None, + percentile=None, + solver='euler_a', + steps=20, + t_max=None, + t_min=None, + discretization=None, + discard_penultimate_step=None, + return_intermediate=None, + show_progress=False, + seed=-1, + **kwargs + ): + # sanity check + assert isinstance(steps, (int, torch.LongTensor)) + assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1) + assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1) + assert discretization in (None, 'leading', 'linspace', 'trailing') + assert discard_penultimate_step in (None, True, False) + assert return_intermediate in (None, 'x0', 'xt') + + # function of diffusion solver + solver_fn = { + # 'heun': sample_heun, + 'dpmpp_2m_sde': sample_dpmpp_2m_sde + }[solver] + + # options + schedule = 'karras' if 'karras' in solver else None + discretization = discretization or 'linspace' + seed = seed if seed >= 0 else random.randint(0, 2 ** 31) + if isinstance(steps, torch.LongTensor): + discard_penultimate_step = False + if discard_penultimate_step is None: + discard_penultimate_step = True if solver in ( + 'dpm2', + 'dpm2_ancestral', + 'dpmpp_2m_sde', + 'dpm2_karras', + 'dpm2_ancestral_karras', + 'dpmpp_2m_sde_karras' + ) else False + + # function for denoising xt to get x0 + intermediates = [] + def model_fn(xt, sigma): + # denoising + t = 
self._sigma_to_t(sigma).repeat(len(xt)).round().long() + x0 = self.denoise( + xt, t, None, model, model_kwargs, guide_scale, guide_rescale, clamp, + percentile + )[-2] + + # collect intermediate outputs + if return_intermediate == 'xt': + intermediates.append(xt) + elif return_intermediate == 'x0': + intermediates.append(x0) + return x0 + + # get timesteps + if isinstance(steps, int): + steps += 1 if discard_penultimate_step else 0 + t_max = self.num_timesteps - 1 if t_max is None else t_max + t_min = 0 if t_min is None else t_min + + # discretize timesteps + if discretization == 'leading': + steps = torch.arange( + t_min, t_max + 1, (t_max - t_min + 1) / steps + ).flip(0) + elif discretization == 'linspace': + steps = torch.linspace(t_max, t_min, steps) + elif discretization == 'trailing': + steps = torch.arange(t_max, t_min - 1, -((t_max - t_min + 1) / steps)) + else: + raise NotImplementedError( + f'{discretization} discretization not implemented' + ) + steps = steps.clamp_(t_min, t_max) + steps = torch.as_tensor(steps, dtype=torch.float32, device=noise.device) + + # get sigmas + sigmas = self._t_to_sigma(steps) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + if schedule == 'karras': + if sigmas[0] == float('inf'): + sigmas = karras_schedule( + n=len(steps) - 1, + sigma_min=sigmas[sigmas > 0].min().item(), + sigma_max=sigmas[sigmas < float('inf')].max().item(), + rho=7. + ).to(sigmas) + sigmas = torch.cat([ + sigmas.new_tensor([float('inf')]), sigmas, sigmas.new_zeros([1]) + ]) + else: + sigmas = karras_schedule( + n=len(steps), + sigma_min=sigmas[sigmas > 0].min().item(), + sigma_max=sigmas.max().item(), + rho=7. + ).to(sigmas) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + if discard_penultimate_step: + sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) + + # sampling + x0 = solver_fn( + noise, + model_fn, + sigmas, + show_progress=show_progress, + **kwargs + ) + return (x0, intermediates) if return_intermediate is not None else x0 + + @torch.no_grad() + def ddim_reverse_sample( + self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + guide_rescale=None, + ddim_timesteps=20, + reverse_steps=600 + ): + r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). + """ + stride = reverse_steps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0, eps = self.denoise( + xt, t, None, model, model_kwargs, guide_scale, guide_rescale, clamp, + percentile + ) + # derive variables + s = (t + stride).clamp(0, reverse_steps-1) + # hyperparams + sigmas = _i(self.sigmas, t, xt) + alphas = _i(self.alphas, t, xt) + alphas_s = _i(self.alphas, s.clamp(0), xt) + alphas_s[s < 0] = 1. 
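+        # deterministic DDIM inversion: re-noise the predicted x0 with eps at the later step s = t + stride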
+ sigmas_s = torch.sqrt(1 - alphas_s ** 2) + + # reverse sample + mu = alphas_s * x0 + sigmas_s * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop( + self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + guide_rescale=None, + ddim_timesteps=20, + reverse_steps=600 + ): + # prepare input + b = x0.size(0) + xt = x0 + + # reconstruction steps + steps = torch.arange(0, reverse_steps, reverse_steps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, percentile, guide_scale, guide_rescale, ddim_timesteps, reverse_steps) + return xt + + def _sigma_to_t(self, sigma): + if sigma == float('inf'): + t = torch.full_like(sigma, len(self.sigmas) - 1) + else: + log_sigmas = torch.sqrt( + self.sigmas ** 2 / (1 - self.sigmas ** 2) + ).log().to(sigma) + log_sigma = sigma.log() + dists = log_sigma - log_sigmas[:, None] + low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp( + max=log_sigmas.shape[0] - 2 + ) + high_idx = low_idx + 1 + low, high = log_sigmas[low_idx], log_sigmas[high_idx] + w = (low - log_sigma) / (low - high) + w = w.clamp(0, 1) + t = (1 - w) * low_idx + w * high_idx + t = t.view(sigma.shape) + if t.ndim == 0: + t = t.unsqueeze(0) + return t + + def _t_to_sigma(self, t): + t = t.float() + low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() + log_sigmas = torch.sqrt(self.sigmas ** 2 / (1 - self.sigmas ** 2)).log().to(t) + log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx] + log_sigma[torch.isnan(log_sigma) | torch.isinf(log_sigma)] = float('inf') + return log_sigma.exp() + + def prev_step(self, model_out, t, xt, inference_steps=50): + prev_t = t - self.num_timesteps // inference_steps + + sigmas = _i(self.sigmas, t, xt) + alphas = _i(self.alphas, t, xt) + alphas_prev = _i(self.alphas, prev_t.clamp(0), xt) + alphas_prev[prev_t < 0] = 1. + sigmas_prev = torch.sqrt(1 - alphas_prev ** 2) + + x0 = alphas * xt - sigmas * model_out + eps = (xt - alphas * x0) / sigmas + prev_sample = alphas_prev * x0 + sigmas_prev * eps + return prev_sample + + def next_step(self, model_out, t, xt, inference_steps=50): + t, next_t = min(t - self.num_timesteps // inference_steps, 999), t + + sigmas = _i(self.sigmas, t, xt) + alphas = _i(self.alphas, t, xt) + alphas_next = _i(self.alphas, next_t.clamp(0), xt) + alphas_next[next_t < 0] = 1. + sigmas_next = torch.sqrt(1 - alphas_next ** 2) + + x0 = alphas * xt - sigmas * model_out + eps = (xt - alphas * x0) / sigmas + next_sample = alphas_next * x0 + sigmas_next * eps + return next_sample + + def get_noise_pred_single(self, xt, t, model, model_kwargs): + assert isinstance(model_kwargs, dict) + out = model(xt, t=t, **model_kwargs) + return out + + diff --git a/tools/modules/diffusions/losses.py b/tools/modules/diffusions/losses.py new file mode 100644 index 0000000..d3188d8 --- /dev/null +++ b/tools/modules/diffusions/losses.py @@ -0,0 +1,28 @@ +import torch +import math + +__all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood'] + +def kl_divergence(mu1, logvar1, mu2, logvar2): + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mu1 - mu2) ** 2) * torch.exp(-logvar2)) + +def standard_normal_cdf(x): + r"""A fast approximation of the cumulative distribution function of the standard normal. 
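+    Uses the tanh approximation 0.5 * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3))).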
+ """ + return 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + +def discretized_gaussian_log_likelihood(x0, mean, log_scale): + assert x0.shape == mean.shape == log_scale.shape + cx = x0 - mean + inv_stdv = torch.exp(-log_scale) + cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0)) + cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0)) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + x0 < -0.999, + log_cdf_plus, + torch.where(x0 > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12)))) + assert log_probs.shape == x0.shape + return log_probs diff --git a/tools/modules/diffusions/schedules.py b/tools/modules/diffusions/schedules.py new file mode 100644 index 0000000..4e15870 --- /dev/null +++ b/tools/modules/diffusions/schedules.py @@ -0,0 +1,166 @@ +import math +import torch + + +def beta_schedule(schedule='cosine', + num_timesteps=1000, + zero_terminal_snr=False, + **kwargs): + # compute betas + betas = { + # 'logsnr_cosine_interp': logsnr_cosine_interp_schedule, + 'linear': linear_schedule, + 'linear_sd': linear_sd_schedule, + 'quadratic': quadratic_schedule, + 'cosine': cosine_schedule + }[schedule](num_timesteps, **kwargs) + + if zero_terminal_snr and abs(betas.max() - 1.0) > 0.0001: + betas = rescale_zero_terminal_snr(betas) + + return betas + + +def sigma_schedule(schedule='cosine', + num_timesteps=1000, + zero_terminal_snr=False, + **kwargs): + # compute betas + betas = { + 'logsnr_cosine_interp': logsnr_cosine_interp_schedule, + 'linear': linear_schedule, + 'linear_sd': linear_sd_schedule, + 'quadratic': quadratic_schedule, + 'cosine': cosine_schedule + }[schedule](num_timesteps, **kwargs) + if schedule == 'logsnr_cosine_interp': + sigma = betas + else: + sigma = betas_to_sigmas(betas) + if zero_terminal_snr and abs(sigma.max() - 1.0) > 0.0001: + sigma = rescale_zero_terminal_snr(sigma) + + return sigma + + +def linear_schedule(num_timesteps, init_beta, last_beta, **kwargs): + scale = 1000.0 / num_timesteps + init_beta = init_beta or scale * 0.0001 + ast_beta = last_beta or scale * 0.02 + return torch.linspace(init_beta, last_beta, num_timesteps, dtype=torch.float64) + +def logsnr_cosine_interp_schedule( + num_timesteps, + scale_min=2, + scale_max=4, + logsnr_min=-15, + logsnr_max=15, + **kwargs): + return logsnrs_to_sigmas( + _logsnr_cosine_interp(num_timesteps, logsnr_min, logsnr_max, scale_min, scale_max)) + +def linear_sd_schedule(num_timesteps, init_beta, last_beta, **kwargs): + return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2 + + +def quadratic_schedule(num_timesteps, init_beta, last_beta, **kwargs): + init_beta = init_beta or 0.0015 + last_beta = last_beta or 0.0195 + return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2 + + +def cosine_schedule(num_timesteps, cosine_s=0.008, **kwargs): + betas = [] + for step in range(num_timesteps): + t1 = step / num_timesteps + t2 = (step + 1) / num_timesteps + fn = lambda u: math.cos((u + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2 + betas.append(min(1.0 - fn(t2) / fn(t1), 0.999)) + return torch.tensor(betas, dtype=torch.float64) + + +# def cosine_schedule(n, cosine_s=0.008, **kwargs): +# ramp = torch.linspace(0, 1, n + 1) +# square_alphas = torch.cos((ramp + cosine_s) / (1 + cosine_s) * torch.pi / 2) ** 2 +# betas = (1 - 
square_alphas[1:] / square_alphas[:-1]).clamp(max=0.999) +# return betas_to_sigmas(betas) + + +def betas_to_sigmas(betas): + return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0)) + + +def sigmas_to_betas(sigmas): + square_alphas = 1 - sigmas**2 + betas = 1 - torch.cat( + [square_alphas[:1], square_alphas[1:] / square_alphas[:-1]]) + return betas + + + +def sigmas_to_logsnrs(sigmas): + square_sigmas = sigmas**2 + return torch.log(square_sigmas / (1 - square_sigmas)) + + +def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15): + t_min = math.atan(math.exp(-0.5 * logsnr_min)) + t_max = math.atan(math.exp(-0.5 * logsnr_max)) + t = torch.linspace(1, 0, n) + logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min))) + return logsnrs + + +def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2): + logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max) + logsnrs += 2 * math.log(1 / scale) + return logsnrs + +def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0): + ramp = torch.linspace(1, 0, n) + min_inv_rho = sigma_min**(1 / rho) + max_inv_rho = sigma_max**(1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho + sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2)) + return sigmas + +def _logsnr_cosine_interp(n, + logsnr_min=-15, + logsnr_max=15, + scale_min=2, + scale_max=4): + t = torch.linspace(1, 0, n) + logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min) + logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max) + logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max + return logsnrs + + +def logsnrs_to_sigmas(logsnrs): + return torch.sqrt(torch.sigmoid(-logsnrs)) + + +def rescale_zero_terminal_snr(betas): + """ + Rescale Schedule to Zero Terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1 - betas + alphas_bar = alphas.cumprod(0) + alphas_bar_sqrt = alphas_bar.sqrt() + + # Store old values. 8 alphas_bar_sqrt_0 = a + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + # Shift so last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + # Scale so first timestep is back to old value. 
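+    # (after the shift and scale, alphas_bar_sqrt[-1] == 0 while alphas_bar_sqrt[0] keeps its
+    # original value, i.e. the terminal SNR becomes exactly zero)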
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt ** 2 + alphas = alphas_bar[1:] / alphas_bar[:-1] + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + return betas + diff --git a/tools/modules/embedding_manager.py b/tools/modules/embedding_manager.py new file mode 100644 index 0000000..763f3dd --- /dev/null +++ b/tools/modules/embedding_manager.py @@ -0,0 +1,179 @@ +import torch +from torch import nn +import torch.nn.functional as F +import open_clip + +from functools import partial +from ...utils.registry_class import EMBEDMANAGER + +DEFAULT_PLACEHOLDER_TOKEN = ["*"] + +PROGRESSIVE_SCALE = 2000 + +per_img_token_list = [ + 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', +] + +def get_clip_token_for_string(string): + tokens = open_clip.tokenize(string) + + return tokens[0, 1] + +def get_embedding_for_clip_token(embedder, token): + return embedder(token.unsqueeze(0))[0] + + +@EMBEDMANAGER.register_class() +class EmbeddingManager(nn.Module): + def __init__( + self, + embedder, + placeholder_strings=None, + initializer_words=None, + per_image_tokens=False, + num_vectors_per_token=1, + progressive_words=False, + temporal_prompt_length=1, + token_dim=1024, + **kwargs + ): + super().__init__() + + self.string_to_token_dict = {} + + self.string_to_param_dict = nn.ParameterDict() + + self.initial_embeddings = nn.ParameterDict() # These should not be optimized + + self.progressive_words = progressive_words + self.progressive_counter = 0 + + self.max_vectors_per_token = num_vectors_per_token + + get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.model.token_embedding.cpu()) + + if per_image_tokens: + placeholder_strings.extend(per_img_token_list) + + for idx, placeholder_string in enumerate(placeholder_strings): + + token = get_clip_token_for_string(placeholder_string) + + if initializer_words and idx < len(initializer_words): + init_word_token = get_clip_token_for_string(initializer_words[idx]) + + with torch.no_grad(): + init_word_embedding = get_embedding_for_tkn(init_word_token) + + token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True) + self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False) + else: + token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True)) + + self.string_to_token_dict[placeholder_string] = token + self.string_to_param_dict[placeholder_string] = token_params + + + def forward( + self, + tokenized_text, + embedded_text, + ): + b, n, device = *tokenized_text.shape, tokenized_text.device + + for placeholder_string, placeholder_token in self.string_to_token_dict.items(): + + placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) + + if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement + placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) + embedded_text[placeholder_idx] = placeholder_embedding + else: # otherwise, need to insert and keep track of changing indices + if self.progressive_words: + self.progressive_counter += 1 + max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE + else: + max_step_tokens = self.max_vectors_per_token + + 
num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens) + + placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device)) + + if placeholder_rows.nelement() == 0: + continue + + sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True) + sorted_rows = placeholder_rows[sort_idx] + + for idx in range(len(sorted_rows)): + row = sorted_rows[idx] + col = sorted_cols[idx] + + new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n] + new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n] + + embedded_text[row] = new_embed_row + tokenized_text[row] = new_token_row + + return embedded_text + + def forward_with_text_img( + self, + tokenized_text, + embedded_text, + embedded_img, + ): + device = tokenized_text.device + for placeholder_string, placeholder_token in self.string_to_token_dict.items(): + placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) + placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) + embedded_text[placeholder_idx] = embedded_text[placeholder_idx] + embedded_img + placeholder_embedding + return embedded_text + + def forward_with_text( + self, + tokenized_text, + embedded_text + ): + device = tokenized_text.device + for placeholder_string, placeholder_token in self.string_to_token_dict.items(): + placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) + placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) + embedded_text[placeholder_idx] = embedded_text[placeholder_idx] + placeholder_embedding + return embedded_text + + def save(self, ckpt_path): + torch.save({"string_to_token": self.string_to_token_dict, + "string_to_param": self.string_to_param_dict}, ckpt_path) + + def load(self, ckpt_path): + ckpt = torch.load(ckpt_path, map_location='cpu') + + string_to_token = ckpt["string_to_token"] + string_to_param = ckpt["string_to_param"] + for string, token in string_to_token.items(): + self.string_to_token_dict[string] = token + for string, param in string_to_param.items(): + self.string_to_param_dict[string] = param + + def get_embedding_norms_squared(self): + all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim + param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders + + return param_norm_squared + + def embedding_parameters(self): + return self.string_to_param_dict.parameters() + + def embedding_to_coarse_loss(self): + + loss = 0. 
+        num_embeddings = len(self.initial_embeddings)
+
+        for key in self.initial_embeddings:
+            optimized = self.string_to_param_dict[key]
+            coarse = self.initial_embeddings[key].clone().to(optimized.device)
+
+            loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings
+
+        return loss
\ No newline at end of file
diff --git a/tools/modules/unet/__init__.py b/tools/modules/unet/__init__.py
new file mode 100644
index 0000000..3d755e9
--- /dev/null
+++ b/tools/modules/unet/__init__.py
@@ -0,0 +1,2 @@
+from .unet_unianimate import *
+
diff --git a/tools/modules/unet/mha_flash.py b/tools/modules/unet/mha_flash.py
new file mode 100644
index 0000000..5edfe0e
--- /dev/null
+++ b/tools/modules/unet/mha_flash.py
@@ -0,0 +1,103 @@
+import torch
+import torch.nn as nn
+import torch.cuda.amp as amp
+import torch.nn.functional as F
+import math
+import os
+import time
+import numpy as np
+import random
+
+# from flash_attn.flash_attention import FlashAttention
+class FlashAttentionBlock(nn.Module):
+
+    def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None, batch_size=4):
+        # consider head_dim first, then num_heads
+        num_heads = dim // head_dim if head_dim else num_heads
+        head_dim = dim // num_heads
+        assert num_heads * head_dim == dim
+        super(FlashAttentionBlock, self).__init__()
+        self.dim = dim
+        self.context_dim = context_dim
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.scale = math.pow(head_dim, -0.25)
+
+        # layers
+        self.norm = nn.GroupNorm(32, dim)
+        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
+        if context_dim is not None:
+            self.context_kv = nn.Linear(context_dim, dim * 2)
+        self.proj = nn.Conv2d(dim, dim, 1)
+
+        if self.head_dim <= 128 and (self.head_dim % 8) == 0:
+            new_scale = math.pow(head_dim, -0.5)
+            self.flash_attn = FlashAttention(softmax_scale=None, attention_dropout=0.0)
+
+        # zero out the last layer params
+        nn.init.zeros_(self.proj.weight)
+        # self.apply(self._init_weight)
+
+
+    def _init_weight(self, module):
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=0.15)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Conv2d):
+            module.weight.data.normal_(mean=0.0, std=0.15)
+            if module.bias is not None:
+                module.bias.data.zero_()
+
+    def forward(self, x, context=None):
+        r"""x:   [B, C, H, W].
+            context: [B, L, C] or None.
+ """ + identity = x + b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) + if context is not None: + ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1) + k = torch.cat([ck, k], dim=-1) + v = torch.cat([cv, v], dim=-1) + cq = torch.zeros([b, n, d, 4], dtype=q.dtype, device=q.device) + q = torch.cat([q, cq], dim=-1) + + qkv = torch.cat([q,k,v], dim=1) + origin_dtype = qkv.dtype + qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n, d).half().contiguous() + out, _ = self.flash_attn(qkv) + out.to(origin_dtype) + + if context is not None: + out = out[:, :-4, :, :] + out = out.permute(0, 2, 3, 1).reshape(b, c, h, w) + + # output + x = self.proj(out) + return x + identity + +if __name__ == '__main__': + batch_size = 8 + flash_net = FlashAttentionBlock(dim=1280, context_dim=512, num_heads=None, head_dim=64, batch_size=batch_size).cuda() + + x = torch.randn([batch_size, 1280, 32, 32], dtype=torch.float32).cuda() + context = torch.randn([batch_size, 4, 512], dtype=torch.float32).cuda() + # context = None + flash_net.eval() + + with amp.autocast(enabled=True): + # warm up + for i in range(5): + y = flash_net(x, context) + torch.cuda.synchronize() + s1 = time.time() + for i in range(10): + y = flash_net(x, context) + torch.cuda.synchronize() + s2 = time.time() + + print(f'Average cost time {(s2-s1)*1000/10} ms') \ No newline at end of file diff --git a/tools/modules/unet/unet_unianimate.py b/tools/modules/unet/unet_unianimate.py new file mode 100644 index 0000000..097b52f --- /dev/null +++ b/tools/modules/unet/unet_unianimate.py @@ -0,0 +1,659 @@ +import math +import torch +# import xformers +# import xformers.ops +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F +from ....lib.rotary_embedding_torch import RotaryEmbedding +from fairscale.nn.checkpoint import checkpoint_wrapper + +from .util import * +# from .mha_flash import FlashAttentionBlock +from ....utils.registry_class import MODEL + + +USE_TEMPORAL_TRANSFORMER = True + + + +class PreNormattention(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + x + +class PreNormattention_qkv(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, q, k, v, **kwargs): + return self.fn(self.norm(q), self.norm(k), self.norm(v), **kwargs) + q + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + b, n, _, h = *x.shape, self.heads + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + + attn = self.attend(dots) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Attention_qkv(nn.Module): + 
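+    # same attention computation as Attention above, but with separate query/key/value
+    # projections so keys and values can come from a different tensor (and batch size)
+    # than the queries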
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.to_q = nn.Linear(dim, inner_dim, bias = False) + self.to_k = nn.Linear(dim, inner_dim, bias = False) + self.to_v = nn.Linear(dim, inner_dim, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, q, k, v): + b, n, _, h = *q.shape, self.heads + bk = k.shape[0] + + q = self.to_q(q) + k = self.to_k(k) + v = self.to_v(v) + q = rearrange(q, 'b n (h d) -> b h n d', h = h) + k = rearrange(k, 'b n (h d) -> b h n d', b=bk, h = h) + v = rearrange(v, 'b n (h d) -> b h n d', b=bk, h = h) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + + attn = self.attend(dots) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class PostNormattention(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.norm(self.fn(x, **kwargs) + x) + + + + +class Transformer_v2(nn.Module): + def __init__(self, heads=8, dim=2048, dim_head_k=256, dim_head_v=256, dropout_atte = 0.05, mlp_dim=2048, dropout_ffn = 0.05, depth=1): + super().__init__() + self.layers = nn.ModuleList([]) + self.depth = depth + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNormattention(dim, Attention(dim, heads = heads, dim_head = dim_head_k, dropout = dropout_atte)), + FeedForward(dim, mlp_dim, dropout = dropout_ffn), + ])) + def forward(self, x): + for attn, ff in self.layers[:1]: + x = attn(x) + x = ff(x) + x + if self.depth > 1: + for attn, ff in self.layers[1:]: + x = attn(x) + x = ff(x) + x + return x + + +class DropPath(nn.Module): + r"""DropPath but without rescaling and supports optional all-zero and/or all-keep. 
+ """ + def __init__(self, p): + super(DropPath, self).__init__() + self.p = p + + def forward(self, *args, zero=None, keep=None): + if not self.training: + return args[0] if len(args) == 1 else args + + # params + x = args[0] + b = x.size(0) + n = (torch.rand(b) < self.p).sum() + + # non-zero and non-keep mask + mask = x.new_ones(b, dtype=torch.bool) + if keep is not None: + mask[keep] = False + if zero is not None: + mask[zero] = False + + # drop-path index + index = torch.where(mask)[0] + index = index[torch.randperm(len(index))[:n]] + if zero is not None: + index = torch.cat([index, torch.where(zero)[0]], dim=0) + + # drop-path multiplier + multiplier = x.new_ones(b) + multiplier[index] = 0.0 + output = tuple(u * self.broadcast(multiplier, u) for u in args) + return output[0] if len(args) == 1 else output + + def broadcast(self, src, dst): + assert src.size(0) == dst.size(0) + shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1) + return src.view(shape) + + + + +@MODEL.register_class() +class UNetSD_UniAnimate(nn.Module): + + def __init__(self, + config=None, + in_dim=4, + dim=512, + y_dim=512, + context_dim=1024, + hist_dim = 156, + concat_dim = 8, + out_dim=6, + dim_mult=[1, 2, 3, 4], + num_heads=None, + head_dim=64, + num_res_blocks=3, + attn_scales=[1 / 2, 1 / 4, 1 / 8], + use_scale_shift_norm=True, + dropout=0.1, + temporal_attn_times=1, + temporal_attention = True, + use_checkpoint=False, + use_image_dataset=False, + use_fps_condition= False, + use_sim_mask = False, + misc_dropout = 0.5, + training=True, + inpainting=True, + p_all_zero=0.1, + p_all_keep=0.1, + zero_y = None, + black_image_feature = None, + adapter_transformer_layers = 1, + num_tokens=4, + **kwargs + ): + embed_dim = dim * 4 + num_heads=num_heads if num_heads else dim//32 + super(UNetSD_UniAnimate, self).__init__() + self.zero_y = zero_y + self.black_image_feature = black_image_feature + self.cfg = config + self.in_dim = in_dim + self.dim = dim + self.y_dim = y_dim + self.context_dim = context_dim + self.num_tokens = num_tokens + self.hist_dim = hist_dim + self.concat_dim = concat_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + + self.num_heads = num_heads + + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.use_scale_shift_norm = use_scale_shift_norm + self.temporal_attn_times = temporal_attn_times + self.temporal_attention = temporal_attention + self.use_checkpoint = use_checkpoint + self.use_image_dataset = use_image_dataset + self.use_fps_condition = use_fps_condition + self.use_sim_mask = use_sim_mask + self.training=training + self.inpainting = inpainting + self.video_compositions = self.cfg.video_compositions + self.misc_dropout = misc_dropout + self.p_all_zero = p_all_zero + self.p_all_keep = p_all_keep + + use_linear_in_temporal = False + transformer_depth = 1 + disabled_sa = False + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + self.resolution = config.resolution + + + # embeddings + self.time_embed = nn.Sequential( + nn.Linear(dim, embed_dim), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + if 'image' in self.video_compositions: + self.pre_image_condition = nn.Sequential( + nn.Linear(self.context_dim, self.context_dim), + nn.SiLU(), + nn.Linear(self.context_dim, self.context_dim*self.num_tokens)) + + + if 'local_image' in self.video_compositions: + self.local_image_embedding = nn.Sequential( + 
nn.Conv2d(3, concat_dim * 4, 3, padding=1), + nn.SiLU(), + nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)), + nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1)) + self.local_image_embedding_after = Transformer_v2(heads=2, dim=concat_dim, dim_head_k=concat_dim, dim_head_v=concat_dim, dropout_atte = 0.05, mlp_dim=concat_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers) + + if 'dwpose' in self.video_compositions: + self.dwpose_embedding = nn.Sequential( + nn.Conv2d(3, concat_dim * 4, 3, padding=1), + nn.SiLU(), + nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)), + nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1)) + self.dwpose_embedding_after = Transformer_v2(heads=2, dim=concat_dim, dim_head_k=concat_dim, dim_head_v=concat_dim, dropout_atte = 0.05, mlp_dim=concat_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers) + + if 'randomref_pose' in self.video_compositions: + randomref_dim = 4 + self.randomref_pose2_embedding = nn.Sequential( + nn.Conv2d(3, concat_dim * 4, 3, padding=1), + nn.SiLU(), + nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)), + nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim+randomref_dim, 3, stride=2, padding=1)) + self.randomref_pose2_embedding_after = Transformer_v2(heads=2, dim=concat_dim+randomref_dim, dim_head_k=concat_dim+randomref_dim, dim_head_v=concat_dim+randomref_dim, dropout_atte = 0.05, mlp_dim=concat_dim+randomref_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers) + + if 'randomref' in self.video_compositions: + randomref_dim = 4 + self.randomref_embedding2 = nn.Sequential( + nn.Conv2d(randomref_dim, concat_dim * 4, 3, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=1, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim+randomref_dim, 3, stride=1, padding=1)) + self.randomref_embedding_after2 = Transformer_v2(heads=2, dim=concat_dim+randomref_dim, dim_head_k=concat_dim+randomref_dim, dim_head_v=concat_dim+randomref_dim, dropout_atte = 0.05, mlp_dim=concat_dim+randomref_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers) + + ### Condition Dropout + self.misc_dropout = DropPath(misc_dropout) + + + if temporal_attention and not USE_TEMPORAL_TRANSFORMER: + self.rotary_emb = RotaryEmbedding(min(32, head_dim)) + self.time_rel_pos_bias = RelativePositionBias(heads = num_heads, max_distance = 32) # realistically will not be able to generate that many frames of video... yet + + if self.use_fps_condition: + self.fps_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + nn.init.zeros_(self.fps_embedding[-1].weight) + nn.init.zeros_(self.fps_embedding[-1].bias) + + # encoder + self.input_blocks = nn.ModuleList() + self.pre_image = nn.Sequential() + init_block = nn.ModuleList([nn.Conv2d(self.in_dim + concat_dim, dim, 3, padding=1)]) + + #### need an initial temporal attention? 
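+        # mix information across frames right after the first convolution: a TemporalTransformer
+        # when USE_TEMPORAL_TRANSFORMER is set, otherwise stacked temporal attention blocks
+        # with rotary position embeddings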
+ if temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + init_block.append(TemporalTransformer(dim, num_heads, head_dim, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset)) + else: + init_block.append(TemporalAttentionMultiBlock(dim, num_heads, head_dim, rotary_emb=self.rotary_emb, temporal_attn_times=temporal_attn_times, use_image_dataset=use_image_dataset)) + + self.input_blocks.append(init_block) + shortcut_dims.append(dim) + for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + + block = nn.ModuleList([ResBlock(in_dim, embed_dim, dropout, out_channels=out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,)]) + + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim, + disable_self_attn=False, use_linear=True + ) + ) + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + block.append(TemporalTransformer(out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset)) + else: + block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times)) + in_dim = out_dim + self.input_blocks.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + downsample = Downsample( + out_dim, True, dims=2, out_channels=out_dim + ) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.input_blocks.append(downsample) + + # middle + self.middle_block = nn.ModuleList([ + ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,), + SpatialTransformer( + out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim, + disable_self_attn=False, use_linear=True + )]) + + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + self.middle_block.append( + TemporalTransformer( + out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset, + ) + ) + else: + self.middle_block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times)) + + self.middle_block.append(ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False)) + + + # decoder + self.output_blocks = nn.ModuleList() + for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + + block = nn.ModuleList([ResBlock(in_dim + shortcut_dims.pop(), embed_dim, dropout, out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset, )]) + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=1024, + disable_self_attn=False, use_linear=True + ) + ) + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + block.append( + TemporalTransformer( + out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, 
multiply_zero=use_image_dataset + ) + ) + else: + block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb =self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + upsample = Upsample(out_dim, True, dims=2.0, out_channels=out_dim) + scale *= 2.0 + block.append(upsample) + self.output_blocks.append(block) + + # head + self.out = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.out[-1].weight) + + def forward(self, + x, + t, + y = None, + depth = None, + image = None, + motion = None, + local_image = None, + single_sketch = None, + masked = None, + canny = None, + sketch = None, + dwpose = None, + randomref = None, + histogram = None, + fps = None, + video_mask = None, + focus_present_mask = None, + prob_focus_present = 0., # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time) + mask_last_frame_num = 0 # mask last frame num + ): + + + assert self.inpainting or masked is None, 'inpainting is not supported' + + batch, c, f, h, w= x.shape + frames = f + device = x.device + self.batch = batch + + #### image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored + if mask_last_frame_num > 0: + focus_present_mask = None + video_mask[-mask_last_frame_num:] = False + else: + focus_present_mask = default(focus_present_mask, lambda: prob_mask_like((batch,), prob_focus_present, device = device)) + + if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER: + time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device = x.device) + else: + time_rel_pos_bias = None + + + # all-zero and all-keep masks + zero = torch.zeros(batch, dtype=torch.bool).to(x.device) + keep = torch.zeros(batch, dtype=torch.bool).to(x.device) + if self.training: + nzero = (torch.rand(batch) < self.p_all_zero).sum() + nkeep = (torch.rand(batch) < self.p_all_keep).sum() + index = torch.randperm(batch) + zero[index[0:nzero]] = True + keep[index[nzero:nzero + nkeep]] = True + assert not (zero & keep).any() + misc_dropout = partial(self.misc_dropout, zero = zero, keep = keep) + + + concat = x.new_zeros(batch, self.concat_dim, f, h, w) + + + # local_image_embedding (first frame) + if local_image is not None: + local_image = rearrange(local_image, 'b c f h w -> (b f) c h w') + local_image = self.local_image_embedding(local_image) + + h = local_image.shape[2] + local_image = self.local_image_embedding_after(rearrange(local_image, '(b f) c h w -> (b h w) f c', b = batch)) + local_image = rearrange(local_image, '(b h w) f c -> b c f h w', b = batch, h = h) + + concat = concat + misc_dropout(local_image) + + if dwpose is not None: + if 'randomref_pose' in self.video_compositions: + dwpose_random_ref = dwpose[:,:,:1].clone() + dwpose = dwpose[:,:,1:] + dwpose = rearrange(dwpose, 'b c f h w -> (b f) c h w') + dwpose = self.dwpose_embedding(dwpose) + + h = dwpose.shape[2] + dwpose = self.dwpose_embedding_after(rearrange(dwpose, '(b f) c h w -> (b h w) f c', b = batch)) + dwpose = rearrange(dwpose, '(b h w) f c -> b c f h w', b = batch, h = h) + concat = concat + misc_dropout(dwpose) + + randomref_b = x.new_zeros(batch, self.concat_dim+4, 1, h, w) + if randomref is not None: + randomref = rearrange(randomref[:,:,:1,], 'b c f h w -> 
(b f) c h w') + randomref = self.randomref_embedding2(randomref) + + h = randomref.shape[2] + randomref = self.randomref_embedding_after2(rearrange(randomref, '(b f) c h w -> (b h w) f c', b = batch)) + if 'randomref_pose' in self.video_compositions: + dwpose_random_ref = rearrange(dwpose_random_ref, 'b c f h w -> (b f) c h w') + dwpose_random_ref = self.randomref_pose2_embedding(dwpose_random_ref) + dwpose_random_ref = self.randomref_pose2_embedding_after(rearrange(dwpose_random_ref, '(b f) c h w -> (b h w) f c', b = batch)) + randomref = randomref + dwpose_random_ref + + randomref_a = rearrange(randomref, '(b h w) f c -> b c f h w', b = batch, h = h) + randomref_b = randomref_b + randomref_a + + + x = torch.cat([randomref_b, torch.cat([x, concat], dim=1)], dim=2) + x = rearrange(x, 'b c f h w -> (b f) c h w') + x = self.pre_image(x) + x = rearrange(x, '(b f) c h w -> b c f h w', b = batch) + + # embeddings + if self.use_fps_condition and fps is not None: + e = self.time_embed(sinusoidal_embedding(t, self.dim)) + self.fps_embedding(sinusoidal_embedding(fps, self.dim)) + else: + e = self.time_embed(sinusoidal_embedding(t, self.dim)) + + context = x.new_zeros(batch, 0, self.context_dim) + + + if image is not None: + y_context = self.zero_y.repeat(batch, 1, 1) + context = torch.cat([context, y_context], dim=1) + + image_context = misc_dropout(self.pre_image_condition(image).view(-1, self.num_tokens, self.context_dim)) # torch.cat([y[:,:-1,:], self.pre_image_condition(y[:,-1:,:]) ], dim=1) + context = torch.cat([context, image_context], dim=1) + else: + y_context = self.zero_y.repeat(batch, 1, 1) + context = torch.cat([context, y_context], dim=1) + image_context = torch.zeros_like(self.zero_y.repeat(batch, 1, 1))[:,:self.num_tokens] + context = torch.cat([context, image_context], dim=1) + + # repeat f times for spatial e and context + e = e.repeat_interleave(repeats=f+1, dim=0) + context = context.repeat_interleave(repeats=f+1, dim=0) + + + + ## always in shape (b f) c h w, except for temporal layer + x = rearrange(x, 'b c f h w -> (b f) c h w') + # encoder + xs = [] + for block in self.input_blocks: + x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask) + xs.append(x) + + # middle + for block in self.middle_block: + x = self._forward_single(block, x, e, context, time_rel_pos_bias,focus_present_mask, video_mask) + + # decoder + for block in self.output_blocks: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, e, context, time_rel_pos_bias,focus_present_mask, video_mask, reference=xs[-1] if len(xs) > 0 else None) + + # head + x = self.out(x) + + # reshape back to (b c f h w) + x = rearrange(x, '(b f) c h w -> b c f h w', b = batch) + return x[:,:,1:] + + def _forward_single(self, module, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference=None): + if isinstance(module, ResidualBlock): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = x.contiguous() + x = module(x, e, reference) + elif isinstance(module, ResBlock): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = x.contiguous() + x = module(x, e, self.batch) + elif isinstance(module, SpatialTransformer): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, TemporalTransformer): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch) + x = 
module(x, context)
+            x = rearrange(x, 'b c f h w -> (b f) c h w')
+        elif isinstance(module, CrossAttention):
+            module = checkpoint_wrapper(module) if self.use_checkpoint else module
+            x = module(x, context)
+        elif isinstance(module, MemoryEfficientCrossAttention):
+            module = checkpoint_wrapper(module) if self.use_checkpoint else module
+            x = module(x, context)
+        elif isinstance(module, BasicTransformerBlock):
+            module = checkpoint_wrapper(module) if self.use_checkpoint else module
+            x = module(x, context)
+        elif isinstance(module, FeedForward):
+            x = module(x, context)
+        elif isinstance(module, Upsample):
+            x = module(x)
+        elif isinstance(module, Downsample):
+            x = module(x)
+        elif isinstance(module, Resample):
+            x = module(x, reference)
+        elif isinstance(module, TemporalAttentionBlock):
+            module = checkpoint_wrapper(module) if self.use_checkpoint else module
+            x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
+            x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
+            x = rearrange(x, 'b c f h w -> (b f) c h w')
+        elif isinstance(module, TemporalAttentionMultiBlock):
+            module = checkpoint_wrapper(module) if self.use_checkpoint else module
+            x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
+            x = module(x, time_rel_pos_bias, focus_present_mask, video_mask)
+            x = rearrange(x, 'b c f h w -> (b f) c h w')
+        elif isinstance(module, InitTemporalConvBlock):
+            module = checkpoint_wrapper(module) if self.use_checkpoint else module
+            x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
+            x = module(x)
+            x = rearrange(x, 'b c f h w -> (b f) c h w')
+        elif isinstance(module, TemporalConvBlock):
+            module = checkpoint_wrapper(module) if self.use_checkpoint else module
+            x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch)
+            x = module(x)
+            x = rearrange(x, 'b c f h w -> (b f) c h w')
+        elif isinstance(module, nn.ModuleList):
+            for block in module:
+                x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference)
+        else:
+            x = module(x)
+        return x
+
+
+
diff --git a/tools/modules/unet/util.py b/tools/modules/unet/util.py
new file mode 100644
index 0000000..1d6c6e6
--- /dev/null
+++ b/tools/modules/unet/util.py
@@ -0,0 +1,1741 @@
+import math
+import torch
+import xformers
+# # import open_clip
+# import xformers.ops
+import torch.nn as nn
+from torch import einsum
+from einops import rearrange, repeat
+from functools import partial
+import torch.nn.functional as F
+import torch.nn.init as init
+from ....lib.rotary_embedding_torch import RotaryEmbedding
+from fairscale.nn.checkpoint import checkpoint_wrapper
+
+# from .mha_flash import FlashAttentionBlock
+# from utils.registry_class import MODEL
+
+
+### collect all keys containing `prefix` and replace that prefix with `new_prefix`
+def load_Block(state, prefix, new_prefix=None):
+    if new_prefix is None:
+        new_prefix = prefix
+
+    state_dict = {}
+    state = {key:value for key,value in state.items() if prefix in key}
+    for key,value in state.items():
+        new_key = key.replace(prefix, new_prefix)
+        state_dict[new_key]=value
+    return state_dict
+
+
+def load_2d_pretrained_state_dict(state,cfg):
+
+    new_state_dict = {}
+
+    dim = cfg.unet_dim
+    num_res_blocks = cfg.unet_res_blocks
+    temporal_attention = cfg.temporal_attention
+    temporal_conv = cfg.temporal_conv
+    dim_mult = cfg.unet_dim_mult
+    attn_scales = cfg.unet_attn_scales
+
+    # params
+    enc_dims = [dim * u for u in [1] + dim_mult]
+    dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+    shortcut_dims = []
+    scale = 1.0
+
+    
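+    # copy the 2D UNet's embedding weights across unchanged, then walk the encoder/middle/decoder
+    # and remap each pretrained block's key prefix onto its new index in the temporal UNet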
#embeddings + state_dict = load_Block(state,prefix=f'time_embedding') + new_state_dict.update(state_dict) + state_dict = load_Block(state,prefix=f'y_embedding') + new_state_dict.update(state_dict) + state_dict = load_Block(state,prefix=f'context_embedding') + new_state_dict.update(state_dict) + + encoder_idx = 0 + ### init block + state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}',new_prefix=f'encoder.{encoder_idx}.0') + new_state_dict.update(state_dict) + encoder_idx += 1 + + shortcut_dims.append(dim) + for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + idx = 0 + idx_ = 0 + # residual (+attention) blocks + state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}.{idx}',new_prefix=f'encoder.{encoder_idx}.{idx_}') + new_state_dict.update(state_dict) + idx += 1 + idx_ = 2 + + if scale in attn_scales: + # block.append(AttentionBlock(out_dim, context_dim, num_heads, head_dim)) + state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}.{idx}',new_prefix=f'encoder.{encoder_idx}.{idx_}') + new_state_dict.update(state_dict) + # if temporal_attention: + # block.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb)) + in_dim = out_dim + encoder_idx += 1 + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + # downsample = ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 0.5, dropout) + state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}',new_prefix=f'encoder.{encoder_idx}.0') + new_state_dict.update(state_dict) + + shortcut_dims.append(out_dim) + scale /= 2.0 + encoder_idx += 1 + + # middle + # self.middle = nn.ModuleList([ + # ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 'none'), + # TemporalConvBlock(out_dim), + # AttentionBlock(out_dim, context_dim, num_heads, head_dim)]) + # if temporal_attention: + # self.middle.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb)) + # elif temporal_conv: + # self.middle.append(TemporalConvBlock(out_dim,dropout=dropout)) + # self.middle.append(ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 'none')) + # self.middle.append(TemporalConvBlock(out_dim)) + + + # middle + middle_idx = 0 + # self.middle = nn.ModuleList([ + # ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout), + # AttentionBlock(out_dim, context_dim, num_heads, head_dim)]) + state_dict = load_Block(state,prefix=f'middle.{middle_idx}') + new_state_dict.update(state_dict) + middle_idx += 2 + + state_dict = load_Block(state,prefix=f'middle.1',new_prefix=f'middle.{middle_idx}') + new_state_dict.update(state_dict) + middle_idx += 1 + + for _ in range(cfg.temporal_attn_times): + # self.middle.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb)) + middle_idx += 1 + + # self.middle.append(ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout)) + state_dict = load_Block(state,prefix=f'middle.2',new_prefix=f'middle.{middle_idx}') + new_state_dict.update(state_dict) + middle_idx += 2 + + + decoder_idx = 0 + for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + idx = 0 + idx_ = 0 + # residual (+attention) blocks + # block = nn.ModuleList([ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout)]) + state_dict = 
load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}') + new_state_dict.update(state_dict) + idx += 1 + idx_ += 2 + if scale in attn_scales: + # block.append(AttentionBlock(out_dim, context_dim, num_heads, head_dim)) + state_dict = load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}') + new_state_dict.update(state_dict) + idx += 1 + idx_ += 1 + for _ in range(cfg.temporal_attn_times): + # block.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb)) + idx_ +=1 + + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + + # upsample = ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 2.0, dropout) + state_dict = load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}') + new_state_dict.update(state_dict) + idx += 1 + idx_ += 2 + + scale *= 2.0 + # block.append(upsample) + # self.decoder.append(block) + decoder_idx += 1 + + # head + # self.head = nn.Sequential( + # nn.GroupNorm(32, out_dim), + # nn.SiLU(), + # nn.Conv3d(out_dim, self.out_dim, (1,3,3), padding=(0,1,1))) + state_dict = load_Block(state,prefix=f'head') + new_state_dict.update(state_dict) + + return new_state_dict + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, + torch.pow(10000, -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + +def exists(x): + return x is not None + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + +def prob_mask_like(shape, prob, device): + if prob == 1: + return torch.ones(shape, device = device, dtype = torch.bool) + elif prob == 0: + return torch.zeros(shape, device = device, dtype = torch.bool) + else: + mask = torch.zeros(shape, device = device).float().uniform_(0, 1) < prob + ### aviod mask all, which will cause find_unused_parameters error + if mask.all(): + mask[0]=False + return mask + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, max_bs=4096, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.max_bs = max_bs + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + if q.shape[0] > self.max_bs: + q_list = torch.chunk(q, q.shape[0] // self.max_bs, dim=0) + k_list 
= torch.chunk(k, k.shape[0] // self.max_bs, dim=0) + v_list = torch.chunk(v, v.shape[0] // self.max_bs, dim=0) + out_list = [] + for q_1, k_1, v_1 in zip(q_list, k_list, v_list): + out = xformers.ops.memory_efficient_attention( + q_1, k_1, v_1, attn_bias=None, op=self.attention_op) + out_list.append(out) + out = torch.cat(out_list, dim=0) + else: + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads = 8, + num_buckets = 32, + max_distance = 128 + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets = 32, max_distance = 128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype = torch.long, device = device) + k_pos = torch.arange(n, dtype = torch.long, device = device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. 
+ Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +class SpatialTransformerWithAdapter(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. 
+ Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True, + adapter_list=[], adapter_position_list=['', 'parallel', ''], + adapter_hidden_dim=None): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlockWithAdapter(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, + adapter_list=adapter_list, adapter_position_list=adapter_position_list, + adapter_hidden_dim=adapter_hidden_dim) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + +import os +_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION =="fp32": + with torch.autocast(enabled=False, device_type = 'cuda'): + q, k = q.float(), k.float() + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + else: + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... 
-> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = torch.einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class Adapter(nn.Module): + def __init__(self, in_dim, hidden_dim, condition_dim=None): + super().__init__() + self.down_linear = nn.Linear(in_dim, hidden_dim) + self.up_linear = nn.Linear(hidden_dim, in_dim) + self.condition_dim = condition_dim + if condition_dim is not None: + self.condition_linear = nn.Linear(condition_dim, in_dim) + + init.zeros_(self.up_linear.weight) + init.zeros_(self.up_linear.bias) + + def forward(self, x, condition=None, condition_lam=1): + x_in = x + if self.condition_dim is not None and condition is not None: + x = x + condition_lam * self.condition_linear(condition) + x = self.down_linear(x) + x = F.gelu(x) + x = self.up_linear(x) + x += x_in + return x + + +class MemoryEfficientCrossAttention_attemask(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=xformers.ops.LowerTriangularMask(), op=self.attention_op) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + + +class BasicTransformerBlock_attemask(nn.Module): + # ATTENTION_MODES = { + # "softmax": CrossAttention, # vanilla attention + # "softmax-xformers": MemoryEfficientCrossAttention + # } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + # attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + # assert attn_mode in self.ATTENTION_MODES + # attn_cls = CrossAttention + attn_cls = MemoryEfficientCrossAttention_attemask + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, 
dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward_(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class BasicTransformerBlockWithAdapter(nn.Module): + # ATTENTION_MODES = { + # "softmax": CrossAttention, # vanilla attention + # "softmax-xformers": MemoryEfficientCrossAttention + # } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False, + adapter_list=[], adapter_position_list=['parallel', 'parallel', 'parallel'], adapter_hidden_dim=None, adapter_condition_dim=None + ): + super().__init__() + # attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + # assert attn_mode in self.ATTENTION_MODES + # attn_cls = CrossAttention + attn_cls = MemoryEfficientCrossAttention + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + # adapter + self.adapter_list = adapter_list + self.adapter_position_list = adapter_position_list + hidden_dim = dim//2 if not adapter_hidden_dim else adapter_hidden_dim + if "self_attention" in adapter_list: + self.attn_adapter = Adapter(dim, hidden_dim, adapter_condition_dim) + if "cross_attention" in adapter_list: + self.cross_attn_adapter = Adapter(dim, hidden_dim, adapter_condition_dim) + if "feedforward" in adapter_list: + self.ff_adapter = Adapter(dim, hidden_dim, adapter_condition_dim) + + + def forward_(self, x, context=None, adapter_condition=None, adapter_condition_lam=1): + return checkpoint(self._forward, (x, context, adapter_condition, adapter_condition_lam), self.parameters(), self.checkpoint) + + def forward(self, x, context=None, adapter_condition=None, adapter_condition_lam=1): + if "self_attention" in self.adapter_list: + if self.adapter_position_list[0] == 'parallel': + # parallel + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + self.attn_adapter(x, adapter_condition, adapter_condition_lam) + elif self.adapter_position_list[0] == 'serial': + # serial + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn_adapter(x, adapter_condition, adapter_condition_lam) + else: + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + + if "cross_attention" in self.adapter_list: + if self.adapter_position_list[1] == 'parallel': + # parallel + x = self.attn2(self.norm2(x), context=context) + self.cross_attn_adapter(x, adapter_condition, adapter_condition_lam) + elif self.adapter_position_list[1] == 'serial': + x = self.attn2(self.norm2(x), context=context) + x + x = self.cross_attn_adapter(x, 
adapter_condition, adapter_condition_lam) + else: + x = self.attn2(self.norm2(x), context=context) + x + + if "feedforward" in self.adapter_list: + if self.adapter_position_list[2] == 'parallel': + x = self.ff(self.norm3(x)) + self.ff_adapter(x, adapter_condition, adapter_condition_lam) + elif self.adapter_position_list[2] == 'serial': + x = self.ff(self.norm3(x)) + x + x = self.ff_adapter(x, adapter_condition, adapter_condition_lam) + else: + x = self.ff(self.norm3(x)) + x + + return x + +class BasicTransformerBlock(nn.Module): + # ATTENTION_MODES = { + # "softmax": CrossAttention, # vanilla attention + # "softmax-xformers": MemoryEfficientCrossAttention + # } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + # attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + # assert attn_mode in self.ATTENTION_MODES + # attn_cls = CrossAttention + attn_cls = MemoryEfficientCrossAttention + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward_(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class UpsampleSR600(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + # TODO: to match input_blocks, remove elements of two sides + x = x[..., 1:-1, :] + if self.use_conv: + x = self.conv(x) + return x + + +class ResBlock(nn.Module): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
+ """ + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + use_temporal_conv=True, + use_image_dataset=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + self.use_temporal_conv = use_temporal_conv + + self.in_layers = nn.Sequential( + nn.GroupNorm(32, channels), + nn.SiLU(), + nn.Conv2d(channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + nn.GroupNorm(32, self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = nn.Conv2d(channels, self.out_channels, 1) + + if self.use_temporal_conv: + self.temopral_conv = TemporalConvBlock_v2(self.out_channels, self.out_channels, dropout=0.1, use_image_dataset=use_image_dataset) + # self.temopral_conv_2 = TemporalConvBlock(self.out_channels, self.out_channels, dropout=0.1, use_image_dataset=use_image_dataset) + + def forward(self, x, emb, batch_size): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return self._forward(x, emb, batch_size) + + def _forward(self, x, emb, batch_size): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + h = self.skip_connection(x) + h + + if self.use_temporal_conv: + h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size) + h = self.temopral_conv(h) + # h = self.temopral_conv_2(h) + h = rearrange(h, 'b c f h w -> (b f) c h w') + return h + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + +class Resample(nn.Module): + + def __init__(self, in_dim, out_dim, mode): + assert mode in ['none', 'upsample', 'downsample'] + super(Resample, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.mode = mode + + def forward(self, x, reference=None): + if self.mode == 'upsample': + assert reference is not None + x = F.interpolate(x, size=reference.shape[-2:], mode='nearest') + elif self.mode == 'downsample': + x = F.adaptive_avg_pool2d(x, output_size=tuple(u // 2 for u in x.shape[-2:])) + return x + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, embed_dim, out_dim, use_scale_shift_norm=True, + mode='none', dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.use_scale_shift_norm = use_scale_shift_norm + self.mode = mode + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.resample = Resample(in_dim, in_dim, mode) + self.embedding = nn.Sequential( + nn.SiLU(), + nn.Linear(embed_dim, out_dim * 2 if use_scale_shift_norm else out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(in_dim, out_dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, e, reference=None): + identity = self.resample(x, reference) + x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference)) + e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) + if self.use_scale_shift_norm: + scale, shift = e.chunk(2, dim=1) + x = self.layer2[0](x) * (1 + scale) + shift + x = self.layer2[1:](x) + else: + x = x + e + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + +class AttentionBlock(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): + # consider head_dim first, then num_heads + num_heads = dim // head_dim if head_dim else num_heads + head_dim = dim // num_heads + assert num_heads * head_dim == dim + super(AttentionBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = math.pow(head_dim, -0.25) + + # layers + self.norm = nn.GroupNorm(32, dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + if context_dim is not None: + self.context_kv = nn.Linear(context_dim, dim * 2) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x, context=None): + r"""x: [B, C, H, W]. + context: [B, L, C] or None. 
+ """ + identity = x + b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) + if context is not None: + ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1) + k = torch.cat([ck, k], dim=-1) + v = torch.cat([cv, v], dim=-1) + + # compute attention + attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.matmul(v, attn.transpose(-1, -2)) + x = x.reshape(b, c, h, w) + + # output + x = self.proj(x) + return x + identity + + +class TemporalAttentionBlock(nn.Module): + def __init__( + self, + dim, + heads = 4, + dim_head = 32, + rotary_emb = None, + use_image_dataset = False, + use_sim_mask = False + ): + super().__init__() + # consider num_heads first, as pos_bias needs fixed num_heads + # heads = dim // dim_head if dim_head else heads + dim_head = dim // heads + assert heads * dim_head == dim + self.use_image_dataset = use_image_dataset + self.use_sim_mask = use_sim_mask + + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + + self.norm = nn.GroupNorm(32, dim) + self.rotary_emb = rotary_emb + self.to_qkv = nn.Linear(dim, hidden_dim * 3)#, bias = False) + self.to_out = nn.Linear(hidden_dim, dim)#, bias = False) + + # nn.init.zeros_(self.to_out.weight) + # nn.init.zeros_(self.to_out.bias) + + def forward( + self, + x, + pos_bias = None, + focus_present_mask = None, + video_mask = None + ): + + identity = x + n, height, device = x.shape[2], x.shape[-2], x.device + + x = self.norm(x) + x = rearrange(x, 'b c f h w -> b (h w) f c') + + qkv = self.to_qkv(x).chunk(3, dim = -1) + + if exists(focus_present_mask) and focus_present_mask.all(): + # if all batch samples are focusing on present + # it would be equivalent to passing that token's values (v=qkv[-1]) through to the output + values = qkv[-1] + out = self.to_out(values) + out = rearrange(out, 'b (h w) f c -> b c f h w', h = height) + + return out + identity + + # split out heads + # q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h = self.heads) + # shape [b (hw) h n c/h], n=f + q= rearrange(qkv[0], '... n (h d) -> ... h n d', h = self.heads) + k= rearrange(qkv[1], '... n (h d) -> ... h n d', h = self.heads) + v= rearrange(qkv[2], '... n (h d) -> ... h n d', h = self.heads) + + + # scale + + q = q * self.scale + + # rotate positions into queries and keys for time attention + if exists(self.rotary_emb): + q = self.rotary_emb.rotate_queries_or_keys(q) + k = self.rotary_emb.rotate_queries_or_keys(k) + + # similarity + # shape [b (hw) h n n], n=f + sim = torch.einsum('... h i d, ... h j d -> ... 
h i j', q, k) + + # relative positional bias + + if exists(pos_bias): + # print(sim.shape,pos_bias.shape) + sim = sim + pos_bias + + if (focus_present_mask is None and video_mask is not None): + #video_mask: [B, n] + mask = video_mask[:, None, :] * video_mask[:, :, None] # [b,n,n] + mask = mask.unsqueeze(1).unsqueeze(1) #[b,1,1,n,n] + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + elif exists(focus_present_mask) and not (~focus_present_mask).all(): + attend_all_mask = torch.ones((n, n), device = device, dtype = torch.bool) + attend_self_mask = torch.eye(n, device = device, dtype = torch.bool) + + mask = torch.where( + rearrange(focus_present_mask, 'b -> b 1 1 1 1'), + rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), + rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), + ) + + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + if self.use_sim_mask: + sim_mask = torch.tril(torch.ones((n, n), device = device, dtype = torch.bool), diagonal=0) + sim = sim.masked_fill(~sim_mask, -torch.finfo(sim.dtype).max) + + # numerical stability + sim = sim - sim.amax(dim = -1, keepdim = True).detach() + attn = sim.softmax(dim = -1) + + # aggregate values + + out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v) + out = rearrange(out, '... h n d -> ... n (h d)') + out = self.to_out(out) + + out = rearrange(out, 'b (h w) f c -> b c f h w', h = height) + + if self.use_image_dataset: + out = identity + 0*out + else: + out = identity + out + return out + +class TemporalTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True, only_self_att=True, multiply_zero=False): + super().__init__() + self.multiply_zero = multiply_zero + self.only_self_att = only_self_att + self.use_adaptor = False + if self.only_self_att: + context_dim = None + if not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv1d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + if self.use_adaptor: + self.adaptor_in = nn.Linear(frames, frames) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv1d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + if self.use_adaptor: + self.adaptor_out = nn.Linear(frames, frames) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if self.only_self_att: + context = None + if not isinstance(context, list): + context = [context] + b, c, f, h, w = x.shape + x_in = x + x = self.norm(x) + + if not self.use_linear: + x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous() + x = self.proj_in(x) + # [16384, 16, 320] + if self.use_linear: + x = rearrange(x, '(b f) c h w -> b (h w) f c', 
f=self.frames).contiguous() + x = self.proj_in(x) + + if self.only_self_att: + x = rearrange(x, 'bhw c f -> bhw f c').contiguous() + for i, block in enumerate(self.transformer_blocks): + x = block(x) + x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous() + else: + x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous() + for i, block in enumerate(self.transformer_blocks): + # context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous() + # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) + for j in range(b): + context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + x[j] = block(x[j], context=context_i_j) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous() + if not self.use_linear: + # x = rearrange(x, 'bhw f c -> bhw c f').contiguous() + x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous() + x = self.proj_out(x) + x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous() + + if self.multiply_zero: + x = 0.0 * x + x_in + else: + x = x + x_in + return x + + +class TemporalTransformerWithAdapter(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True, only_self_att=True, multiply_zero=False, + adapter_list=[], adapter_position_list=['parallel', 'parallel', 'parallel'], + adapter_hidden_dim=None, adapter_condition_dim=None): + super().__init__() + self.multiply_zero = multiply_zero + self.only_self_att = only_self_att + self.use_adaptor = False + if self.only_self_att: + context_dim = None + if not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv1d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + if self.use_adaptor: + self.adaptor_in = nn.Linear(frames, frames) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlockWithAdapter(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + checkpoint=use_checkpoint, adapter_list=adapter_list, adapter_position_list=adapter_position_list, + adapter_hidden_dim=adapter_hidden_dim, adapter_condition_dim=adapter_condition_dim) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv1d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + if self.use_adaptor: + self.adaptor_out = nn.Linear(frames, frames) + self.use_linear = use_linear + + def forward(self, x, context=None, adapter_condition=None, adapter_condition_lam=1): + # note: if no context is given, cross-attention defaults to self-attention + if self.only_self_att: + context = None + if not isinstance(context, list): + context = [context] + b, c, f, h, w = x.shape + x_in = x + x = self.norm(x) + + if not 
self.use_linear: + x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous() + x = self.proj_in(x) + # [16384, 16, 320] + if self.use_linear: + x = rearrange(x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous() + x = self.proj_in(x) + + if adapter_condition is not None: + b_cond, f_cond, c_cond = adapter_condition.shape + adapter_condition = adapter_condition.unsqueeze(1).unsqueeze(1).repeat(1, h, w, 1, 1) + adapter_condition = adapter_condition.reshape(b_cond*h*w, f_cond, c_cond) + + if self.only_self_att: + x = rearrange(x, 'bhw c f -> bhw f c').contiguous() + for i, block in enumerate(self.transformer_blocks): + x = block(x, adapter_condition=adapter_condition, adapter_condition_lam=adapter_condition_lam) + x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous() + else: + x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous() + for i, block in enumerate(self.transformer_blocks): + # context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous() + # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) + for j in range(b): + context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + x[j] = block(x[j], context=context_i_j) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous() + if not self.use_linear: + # x = rearrange(x, 'bhw f c -> bhw c f').contiguous() + x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous() + x = self.proj_out(x) + x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous() + + if self.multiply_zero: + x = 0.0 * x + x_in + else: + x = x + x_in + return x + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + b, n, _, h = *x.shape, self.heads + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + + dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + + attn = self.attend(dots) + + out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class PreNormattention(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + x + +class TransformerV2(nn.Module): + def __init__(self, heads=8, dim=2048, dim_head_k=256, dim_head_v=256, dropout_atte = 0.05, mlp_dim=2048, dropout_ffn = 0.05, depth=1): + super().__init__() + self.layers = nn.ModuleList([]) + self.depth = depth + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNormattention(dim, Attention(dim, heads = heads, dim_head = dim_head_k, dropout = dropout_atte)), + FeedForward(dim, mlp_dim, dropout = dropout_ffn), + ])) + def forward(self, x): + # if self.depth + for attn, ff in self.layers[:1]: + x = attn(x) + x = ff(x) + x + if self.depth 
> 1: + for attn, ff in self.layers[1:]: + x = attn(x) + x = ff(x) + x + return x + +class TemporalTransformer_attemask(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True, only_self_att=True, multiply_zero=False): + super().__init__() + self.multiply_zero = multiply_zero + self.only_self_att = only_self_att + self.use_adaptor = False + if self.only_self_att: + context_dim = None + if not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv1d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + if self.use_adaptor: + self.adaptor_in = nn.Linear(frames, frames) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock_attemask(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv1d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + if self.use_adaptor: + self.adaptor_out = nn.Linear(frames, frames) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if self.only_self_att: + context = None + if not isinstance(context, list): + context = [context] + b, c, f, h, w = x.shape + x_in = x + x = self.norm(x) + + if not self.use_linear: + x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous() + x = self.proj_in(x) + # [16384, 16, 320] + if self.use_linear: + x = rearrange(x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous() + x = self.proj_in(x) + + if self.only_self_att: + x = rearrange(x, 'bhw c f -> bhw f c').contiguous() + for i, block in enumerate(self.transformer_blocks): + x = block(x) + x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous() + else: + x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous() + for i, block in enumerate(self.transformer_blocks): + # context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous() + # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) + for j in range(b): + context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + x[j] = block(x[j], context=context_i_j) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous() + if not self.use_linear: + # x = rearrange(x, 'bhw f c -> bhw c f').contiguous() + x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous() + x = self.proj_out(x) + x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous() + + if self.multiply_zero: + x = 0.0 * x + x_in + else: + x = x + x_in + return x + +class TemporalAttentionMultiBlock(nn.Module): + def __init__( + self, + 
dim, + heads=4, + dim_head=32, + rotary_emb=None, + use_image_dataset=False, + use_sim_mask=False, + temporal_attn_times=1, + ): + super().__init__() + self.att_layers = nn.ModuleList( + [TemporalAttentionBlock(dim, heads, dim_head, rotary_emb, use_image_dataset, use_sim_mask) + for _ in range(temporal_attn_times)] + ) + + def forward( + self, + x, + pos_bias = None, + focus_present_mask = None, + video_mask = None + ): + for layer in self.att_layers: + x = layer(x, pos_bias, focus_present_mask, video_mask) + return x + + +class InitTemporalConvBlock(nn.Module): + + def __init__(self, in_dim, out_dim=None, dropout=0.0,use_image_dataset=False): + super(InitTemporalConvBlock, self).__init__() + if out_dim is None: + out_dim = in_dim#int(1.5*in_dim) + self.in_dim = in_dim + self.out_dim = out_dim + self.use_image_dataset = use_image_dataset + + # conv layers + self.conv = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0))) + + # zero out the last layer params,so the conv block is identity + # nn.init.zeros_(self.conv1[-1].weight) + # nn.init.zeros_(self.conv1[-1].bias) + nn.init.zeros_(self.conv[-1].weight) + nn.init.zeros_(self.conv[-1].bias) + + def forward(self, x): + identity = x + x = self.conv(x) + if self.use_image_dataset: + x = identity + 0*x + else: + x = identity + x + return x + +class TemporalConvBlock(nn.Module): + + def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset= False): + super(TemporalConvBlock, self).__init__() + if out_dim is None: + out_dim = in_dim#int(1.5*in_dim) + self.in_dim = in_dim + self.out_dim = out_dim + self.use_image_dataset = use_image_dataset + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding = (1, 0, 0))) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0))) + + # zero out the last layer params,so the conv block is identity + # nn.init.zeros_(self.conv1[-1].weight) + # nn.init.zeros_(self.conv1[-1].bias) + nn.init.zeros_(self.conv2[-1].weight) + nn.init.zeros_(self.conv2[-1].bias) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.conv2(x) + if self.use_image_dataset: + x = identity + 0*x + else: + x = identity + x + return x + +class TemporalConvBlock_v2(nn.Module): + def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset=False): + super(TemporalConvBlock_v2, self).__init__() + if out_dim is None: + out_dim = in_dim # int(1.5*in_dim) + self.in_dim = in_dim + self.out_dim = out_dim + self.use_image_dataset = use_image_dataset + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding = (1, 0, 0))) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0))) + self.conv3 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0))) + self.conv4 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0))) + + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv4[-1].weight) + nn.init.zeros_(self.conv4[-1].bias) + + def forward(self, x): + identity = x + x 
= self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + + if self.use_image_dataset: + x = identity + 0.0 * x + else: + x = identity + x + return x + + +class DropPath(nn.Module): + r"""DropPath but without rescaling and supports optional all-zero and/or all-keep. + """ + def __init__(self, p): + super(DropPath, self).__init__() + self.p = p + + def forward(self, *args, zero=None, keep=None): + if not self.training: + return args[0] if len(args) == 1 else args + + # params + x = args[0] + b = x.size(0) + n = (torch.rand(b) < self.p).sum() + + # non-zero and non-keep mask + mask = x.new_ones(b, dtype=torch.bool) + if keep is not None: + mask[keep] = False + if zero is not None: + mask[zero] = False + + # drop-path index + index = torch.where(mask)[0] + index = index[torch.randperm(len(index))[:n]] + if zero is not None: + index = torch.cat([index, torch.where(zero)[0]], dim=0) + + # drop-path multiplier + multiplier = x.new_ones(b) + multiplier[index] = 0.0 + output = tuple(u * self.broadcast(multiplier, u) for u in args) + return output[0] if len(args) == 1 else output + + def broadcast(self, src, dst): + assert src.size(0) == dst.size(0) + shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1) + return src.view(shape) + + + From 626e7afc02230297b6f553675ea1c32c29971314 Mon Sep 17 00:00:00 2001 From: Isi <86603298+Isi-dev@users.noreply.github.com> Date: Thu, 1 Aug 2024 21:57:50 +0100 Subject: [PATCH 5/5] Add files via upload --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 6755761..ba7639a 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,9 @@ conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda= ``` If not installed, then: + + +``` pip install opencv-python pip install pytorch_lightning pip install lightning_utilities #if not installed @@ -47,6 +50,8 @@ pip install einops pip install args pip install modelscope +``` + Download the required models (Around 14GB) after installing modelscope :
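
The temporal transformer variants added above (`TemporalTransformer`, `TemporalTransformerWithAdapter`, `TemporalTransformer_attemask`) all rely on the reshaping trick described in their docstrings: the spatial grid is folded into the batch dimension so that attention runs only along the frame axis, and the result is unfolded back into a video tensor. Below is a minimal, self-contained sketch of that round trip; shapes are illustrative only, and the GroupNorm, Conv1d/Linear projections, and the attention blocks themselves are omitted.

```
import torch
from einops import rearrange

# Illustrative shapes: batch 2, 320 channels, 16 frames, 32x32 latent grid.
b, c, f, h, w = 2, 320, 16, 32, 32
x = torch.randn(b, c, f, h, w)

# Fold the spatial positions into the batch so each (h, w) location becomes
# an independent sequence of f frame tokens, as in the non-linear path above.
tokens = rearrange(x, 'b c f h w -> (b h w) c f')   # [(2*32*32), 320, 16]
tokens = rearrange(tokens, 'bhw c f -> bhw f c')    # sequence length = frames

# ... temporal self-attention over the f dimension would be applied here ...

# Unfold back to the original video layout.
x_out = rearrange(tokens, '(b h w) f c -> b c f h w', b=b, h=h, w=w)
assert x_out.shape == x.shape
```

The residual behaviour at the end of those blocks follows the same pattern as the conv blocks: when `multiply_zero` (or `use_image_dataset`) is set, the branch output is multiplied by zero so the module starts as an identity mapping and only contributes once its zero-initialized projection is trained.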