From 00a5b59d93eea836808d5e69695b1f1c1329e36f Mon Sep 17 00:00:00 2001 From: Isi <86603298+Isi-dev@users.noreply.github.com> Date: Thu, 1 Aug 2024 21:39:54 +0100 Subject: [PATCH] Add files via upload --- tools/__init__.py | 3 + tools/__pycache__/__init__.cpython-310.pyc | Bin 0 -> 256 bytes tools/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 200 bytes tools/datasets/__init__.py | 2 + .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 254 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 198 bytes .../__pycache__/image_dataset.cpython-310.pyc | Bin 0 -> 2897 bytes .../__pycache__/image_dataset.cpython-39.pyc | Bin 0 -> 2846 bytes .../__pycache__/video_dataset.cpython-310.pyc | Bin 0 -> 3464 bytes .../__pycache__/video_dataset.cpython-39.pyc | Bin 0 -> 3402 bytes tools/datasets/image_dataset.py | 86 + tools/datasets/video_dataset.py | 118 ++ tools/inferences/__init__.py | 2 + .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 293 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 237 bytes ...erence_unianimate_entrance.cpython-310.pyc | Bin 0 -> 12891 bytes ...ference_unianimate_entrance.cpython-39.pyc | Bin 0 -> 12529 bytes ...e_unianimate_long_entrance.cpython-310.pyc | Bin 0 -> 13012 bytes ...ce_unianimate_long_entrance.cpython-39.pyc | Bin 0 -> 13136 bytes .../inference_unianimate_entrance.py | 546 ++++++ .../inference_unianimate_long_entrance.py | 508 +++++ tools/modules/__init__.py | 7 + .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 438 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 382 bytes .../__pycache__/autoencoder.cpython-310.pyc | Bin 0 -> 16687 bytes .../__pycache__/autoencoder.cpython-39.pyc | Bin 0 -> 17098 bytes .../__pycache__/clip_embedder.cpython-310.pyc | Bin 0 -> 6728 bytes .../__pycache__/clip_embedder.cpython-39.pyc | Bin 0 -> 6645 bytes .../__pycache__/config.cpython-310.pyc | Bin 0 -> 3845 bytes .../modules/__pycache__/config.cpython-39.pyc | Bin 0 -> 3696 bytes .../embedding_manager.cpython-310.pyc | Bin 0 -> 5244 bytes .../embedding_manager.cpython-39.pyc | Bin 0 -> 5161 bytes tools/modules/autoencoder.py | 698 +++++++ tools/modules/clip_embedder.py | 241 +++ tools/modules/config.py | 206 ++ tools/modules/diffusions/__init__.py | 1 + .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 240 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 184 bytes .../diffusion_ddim.cpython-310.pyc | Bin 0 -> 24748 bytes .../__pycache__/diffusion_ddim.cpython-39.pyc | Bin 0 -> 28560 bytes .../__pycache__/losses.cpython-310.pyc | Bin 0 -> 1337 bytes .../__pycache__/losses.cpython-39.pyc | Bin 0 -> 1279 bytes .../__pycache__/schedules.cpython-310.pyc | Bin 0 -> 4726 bytes .../__pycache__/schedules.cpython-39.pyc | Bin 0 -> 4653 bytes tools/modules/diffusions/diffusion_ddim.py | 1121 +++++++++++ tools/modules/diffusions/diffusion_gauss.py | 498 +++++ tools/modules/diffusions/losses.py | 28 + tools/modules/diffusions/schedules.py | 166 ++ tools/modules/embedding_manager.py | 179 ++ tools/modules/unet/__init__.py | 2 + .../unet/__pycache__/__init__.cpython-310.pyc | Bin 0 -> 235 bytes .../unet/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 179 bytes .../unet_unianimate.cpython-310.pyc | Bin 0 -> 15790 bytes .../unet_unianimate.cpython-39.pyc | Bin 0 -> 15915 bytes .../unet/__pycache__/util.cpython-310.pyc | Bin 0 -> 42320 bytes .../unet/__pycache__/util.cpython-39.pyc | Bin 0 -> 44250 bytes tools/modules/unet/mha_flash.py | 103 + tools/modules/unet/unet_unianimate.py | 659 +++++++ tools/modules/unet/util.py | 1741 +++++++++++++++++ 59 files 
changed, 6915 insertions(+) create mode 100644 tools/__init__.py create mode 100644 tools/__pycache__/__init__.cpython-310.pyc create mode 100644 tools/__pycache__/__init__.cpython-39.pyc create mode 100644 tools/datasets/__init__.py create mode 100644 tools/datasets/__pycache__/__init__.cpython-310.pyc create mode 100644 tools/datasets/__pycache__/__init__.cpython-39.pyc create mode 100644 tools/datasets/__pycache__/image_dataset.cpython-310.pyc create mode 100644 tools/datasets/__pycache__/image_dataset.cpython-39.pyc create mode 100644 tools/datasets/__pycache__/video_dataset.cpython-310.pyc create mode 100644 tools/datasets/__pycache__/video_dataset.cpython-39.pyc create mode 100644 tools/datasets/image_dataset.py create mode 100644 tools/datasets/video_dataset.py create mode 100644 tools/inferences/__init__.py create mode 100644 tools/inferences/__pycache__/__init__.cpython-310.pyc create mode 100644 tools/inferences/__pycache__/__init__.cpython-39.pyc create mode 100644 tools/inferences/__pycache__/inference_unianimate_entrance.cpython-310.pyc create mode 100644 tools/inferences/__pycache__/inference_unianimate_entrance.cpython-39.pyc create mode 100644 tools/inferences/__pycache__/inference_unianimate_long_entrance.cpython-310.pyc create mode 100644 tools/inferences/__pycache__/inference_unianimate_long_entrance.cpython-39.pyc create mode 100644 tools/inferences/inference_unianimate_entrance.py create mode 100644 tools/inferences/inference_unianimate_long_entrance.py create mode 100644 tools/modules/__init__.py create mode 100644 tools/modules/__pycache__/__init__.cpython-310.pyc create mode 100644 tools/modules/__pycache__/__init__.cpython-39.pyc create mode 100644 tools/modules/__pycache__/autoencoder.cpython-310.pyc create mode 100644 tools/modules/__pycache__/autoencoder.cpython-39.pyc create mode 100644 tools/modules/__pycache__/clip_embedder.cpython-310.pyc create mode 100644 tools/modules/__pycache__/clip_embedder.cpython-39.pyc create mode 100644 tools/modules/__pycache__/config.cpython-310.pyc create mode 100644 tools/modules/__pycache__/config.cpython-39.pyc create mode 100644 tools/modules/__pycache__/embedding_manager.cpython-310.pyc create mode 100644 tools/modules/__pycache__/embedding_manager.cpython-39.pyc create mode 100644 tools/modules/autoencoder.py create mode 100644 tools/modules/clip_embedder.py create mode 100644 tools/modules/config.py create mode 100644 tools/modules/diffusions/__init__.py create mode 100644 tools/modules/diffusions/__pycache__/__init__.cpython-310.pyc create mode 100644 tools/modules/diffusions/__pycache__/__init__.cpython-39.pyc create mode 100644 tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-310.pyc create mode 100644 tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-39.pyc create mode 100644 tools/modules/diffusions/__pycache__/losses.cpython-310.pyc create mode 100644 tools/modules/diffusions/__pycache__/losses.cpython-39.pyc create mode 100644 tools/modules/diffusions/__pycache__/schedules.cpython-310.pyc create mode 100644 tools/modules/diffusions/__pycache__/schedules.cpython-39.pyc create mode 100644 tools/modules/diffusions/diffusion_ddim.py create mode 100644 tools/modules/diffusions/diffusion_gauss.py create mode 100644 tools/modules/diffusions/losses.py create mode 100644 tools/modules/diffusions/schedules.py create mode 100644 tools/modules/embedding_manager.py create mode 100644 tools/modules/unet/__init__.py create mode 100644 tools/modules/unet/__pycache__/__init__.cpython-310.pyc create mode 100644 
tools/modules/unet/__pycache__/__init__.cpython-39.pyc create mode 100644 tools/modules/unet/__pycache__/unet_unianimate.cpython-310.pyc create mode 100644 tools/modules/unet/__pycache__/unet_unianimate.cpython-39.pyc create mode 100644 tools/modules/unet/__pycache__/util.cpython-310.pyc create mode 100644 tools/modules/unet/__pycache__/util.cpython-39.pyc create mode 100644 tools/modules/unet/mha_flash.py create mode 100644 tools/modules/unet/unet_unianimate.py create mode 100644 tools/modules/unet/util.py
diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000..33ef13c --- /dev/null +++ b/tools/__init__.py @@ -0,0 +1,3 @@ +from .datasets import * +from .modules import * +from .inferences import *
diff --git a/tools/__pycache__/__init__.cpython-310.pyc b/tools/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74bb8278b4d5ad8bfbcfbf761c96156e10650dc3 GIT binary patch literal 256 (binary data omitted)
diff --git a/tools/__pycache__/__init__.cpython-39.pyc b/tools/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f55644d5d92c88f526b3325d5918d3feaaf5e5e GIT binary patch literal 200 (binary data omitted)
diff --git a/tools/datasets/__init__.py b/tools/datasets/__init__.py new file mode 100644 index 0000000..f1b217f --- /dev/null +++ b/tools/datasets/__init__.py @@ -0,0 +1,2 @@ +from .image_dataset import * +from .video_dataset import *
diff --git a/tools/datasets/__pycache__/__init__.cpython-310.pyc b/tools/datasets/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be3d985797983558a345dab5b5d3699f7a547167 GIT binary patch literal 254 (binary data omitted)
diff --git a/tools/datasets/__pycache__/__init__.cpython-39.pyc b/tools/datasets/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6726b131c60fac9c8a6d9607115635f43748edb GIT binary patch literal 198 (binary data omitted)
diff --git a/tools/datasets/__pycache__/image_dataset.cpython-310.pyc b/tools/datasets/__pycache__/image_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77f0b69b97e1f5c03a02280cadce19a15302555f GIT binary patch literal 2897 (binary data omitted)
diff --git a/tools/datasets/__pycache__/image_dataset.cpython-39.pyc b/tools/datasets/__pycache__/image_dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7ce1cb7c2a2edf56f6cbf1b737da6f7c12a0898 GIT binary patch literal 2846 (binary data omitted)
diff --git a/tools/datasets/__pycache__/video_dataset.cpython-310.pyc b/tools/datasets/__pycache__/video_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8c0008633bfee0df40a629989d67d77ca21793d GIT binary patch literal 3464 (binary data omitted)
diff --git a/tools/datasets/__pycache__/video_dataset.cpython-39.pyc b/tools/datasets/__pycache__/video_dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ceb6e95f5e8a0b7d2189a9ea436324b30bf34686 GIT binary patch literal 3402 (binary data omitted)
0: + mid_frame = frame_list[0] + vit_frame = self.vit_transforms(mid_frame) + frame_tensor = self.transforms(frame_list) + video_data[:len(frame_list), ...]
= frame_tensor + else: + vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) + except: + vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) + ref_frame = copy(video_data[0]) + + return ref_frame, vit_frame, video_data, caption + diff --git a/tools/datasets/video_dataset.py b/tools/datasets/video_dataset.py new file mode 100644 index 0000000..cdc45de --- /dev/null +++ b/tools/datasets/video_dataset.py @@ -0,0 +1,118 @@ +import os +import cv2 +import json +import torch +import random +import logging +import tempfile +import numpy as np +from copy import copy +from PIL import Image +from torch.utils.data import Dataset +from ...utils.registry_class import DATASETS + + +@DATASETS.register_class() +class VideoDataset(Dataset): + def __init__(self, + data_list, + data_dir_list, + max_words=1000, + resolution=(384, 256), + vit_resolution=(224, 224), + max_frames=16, + sample_fps=8, + transforms=None, + vit_transforms=None, + get_first_frame=False, + **kwargs): + + self.max_words = max_words + self.max_frames = max_frames + self.resolution = resolution + self.vit_resolution = vit_resolution + self.sample_fps = sample_fps + self.transforms = transforms + self.vit_transforms = vit_transforms + self.get_first_frame = get_first_frame + + image_list = [] + for item_path, data_dir in zip(data_list, data_dir_list): + lines = open(item_path, 'r').readlines() + lines = [[data_dir, item] for item in lines] + image_list.extend(lines) + self.image_list = image_list + + + def __getitem__(self, index): + data_dir, file_path = self.image_list[index] + video_key = file_path.split('|||')[0] + try: + ref_frame, vit_frame, video_data, caption = self._get_video_data(data_dir, file_path) + except Exception as e: + logging.info('{} get frames failed... 
with error: {}'.format(video_key, e)) + caption = '' + video_key = '' + ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0]) + vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) + video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0]) + return ref_frame, vit_frame, video_data, caption, video_key + + + def _get_video_data(self, data_dir, file_path): + video_key, caption = file_path.split('|||') + file_path = os.path.join(data_dir, video_key) + + for _ in range(5): + try: + capture = cv2.VideoCapture(file_path) + _fps = capture.get(cv2.CAP_PROP_FPS) + _total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT) + stride = round(_fps / self.sample_fps) + cover_frame_num = (stride * self.max_frames) + if _total_frame_num < cover_frame_num + 5: + start_frame = 0 + end_frame = _total_frame_num + else: + start_frame = random.randint(0, _total_frame_num-cover_frame_num-5) + end_frame = start_frame + cover_frame_num + + pointer, frame_list = 0, [] + while(True): + ret, frame = capture.read() + pointer +=1 + if (not ret) or (frame is None): break + if pointer < start_frame: continue + if pointer >= end_frame - 1: break + if (pointer - start_frame) % stride == 0: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = Image.fromarray(frame) + frame_list.append(frame) + break + except Exception as e: + logging.info('{} read video frame failed with error: {}'.format(video_key, e)) + continue + + video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0]) + if self.get_first_frame: + ref_idx = 0 + else: + ref_idx = int(len(frame_list)/2) + try: + if len(frame_list)>0: + mid_frame = copy(frame_list[ref_idx]) + vit_frame = self.vit_transforms(mid_frame) + frames = self.transforms(frame_list) + video_data[:len(frame_list), ...] 
= frames + else: + vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) + except: + vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0]) + ref_frame = copy(frames[ref_idx]) + + return ref_frame, vit_frame, video_data, caption + + def __len__(self): + return len(self.image_list) + + diff --git a/tools/inferences/__init__.py b/tools/inferences/__init__.py new file mode 100644 index 0000000..db0383b --- /dev/null +++ b/tools/inferences/__init__.py @@ -0,0 +1,2 @@ +from .inference_unianimate_entrance import * +from .inference_unianimate_long_entrance import *
diff --git a/tools/inferences/__pycache__/__init__.cpython-310.pyc b/tools/inferences/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cac82efe421b0620827d06c2f12dc67c1a58cf6c GIT binary patch literal 293 (binary data omitted)
diff --git a/tools/inferences/__pycache__/__init__.cpython-39.pyc b/tools/inferences/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d0533cc269d526cf21afc6be4a7d3ed5b1b17fa GIT binary patch literal 237 (binary data omitted)
diff --git a/tools/inferences/__pycache__/inference_unianimate_entrance.cpython-310.pyc b/tools/inferences/__pycache__/inference_unianimate_entrance.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fc212a6bcade8839065fac9c0d1b842226d9c7b GIT binary patch literal 12891 (binary data omitted)
diff --git a/tools/inferences/__pycache__/inference_unianimate_entrance.cpython-39.pyc b/tools/inferences/__pycache__/inference_unianimate_entrance.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd13e731299764d78bdd5f3a561714411798cc04 GIT binary patch literal 12529 (binary data omitted)
diff --git a/tools/inferences/__pycache__/inference_unianimate_long_entrance.cpython-310.pyc b/tools/inferences/__pycache__/inference_unianimate_long_entrance.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dba0b9d99637ba72bb725cb0b100f1645e2decb GIT binary patch literal 13012 (binary data omitted)
diff --git a/tools/inferences/inference_unianimate_entrance.py
b/tools/inferences/inference_unianimate_entrance.py new file mode 100644 index 0000000..14ceb8f --- /dev/null +++ b/tools/inferences/inference_unianimate_entrance.py @@ -0,0 +1,546 @@ +''' +/* +*Copyright (c) 2021, Alibaba Group; +*Licensed under the Apache License, Version 2.0 (the "License"); +*you may not use this file except in compliance with the License. +*You may obtain a copy of the License at + +* http://www.apache.org/licenses/LICENSE-2.0 + +*Unless required by applicable law or agreed to in writing, software +*distributed under the License is distributed on an "AS IS" BASIS, +*WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +*See the License for the specific language governing permissions and +*limitations under the License. +*/ +''' + +import os +import re +import os.path as osp +import sys +sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-4])) +import json +import math +import torch +# import pynvml +import logging +import numpy as np +from PIL import Image +import torch.cuda.amp as amp +from importlib import reload +import torch.distributed as dist +import torch.multiprocessing as mp +import random +from einops import rearrange +import torchvision.transforms as T +from torch.nn.parallel import DistributedDataParallel + +from ...utils import transforms as data +from ..modules.config import cfg +from ...utils.seed import setup_seed +from ...utils.multi_port import find_free_port +from ...utils.assign_cfg import assign_signle_cfg +from ...utils.distributed import generalized_all_gather, all_reduce +from ...utils.video_op import save_i2vgen_video, save_t2vhigen_video_safe, save_video_multiple_conditions_not_gif_horizontal_3col +from ...tools.modules.autoencoder import get_first_stage_encoding +from ...utils.registry_class import INFER_ENGINE, MODEL, EMBEDDER, AUTO_ENCODER, DIFFUSION +from copy import copy +import cv2 + + +# @INFER_ENGINE.register_function() +def inference_unianimate_entrance(steps, useFirstFrame, reference_image, ref_pose, pose_sequence, frame_interval, max_frames, resolution, cfg_update, **kwargs): + for k, v in cfg_update.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + if not 'MASTER_ADDR' in os.environ: + os.environ['MASTER_ADDR']='localhost' + os.environ['MASTER_PORT']= find_free_port() + cfg.pmi_rank = int(os.getenv('RANK', 0)) + cfg.pmi_world_size = int(os.getenv('WORLD_SIZE', 1)) + + if cfg.debug: + cfg.gpus_per_machine = 1 + cfg.world_size = 1 + else: + cfg.gpus_per_machine = torch.cuda.device_count() + cfg.world_size = cfg.pmi_world_size * cfg.gpus_per_machine + + if cfg.world_size == 1: + return worker(0, steps, useFirstFrame, reference_image, ref_pose, pose_sequence, frame_interval, max_frames, resolution, cfg, cfg_update) + else: + return mp.spawn(worker, nprocs=cfg.gpus_per_machine, args=(cfg, cfg_update)) + return cfg + + +def make_masked_images(imgs, masks): + masked_imgs = [] + for i, mask in enumerate(masks): + # concatenation + masked_imgs.append(torch.cat([imgs[i] * (1 - mask), (1 - mask)], dim=1)) + return torch.stack(masked_imgs, dim=0) + +def load_video_frames(ref_image_tensor, ref_pose_tensor, pose_tensors, train_trans, vit_transforms, train_trans_pose, max_frames=32, frame_interval=1, resolution=[512, 768], get_first_frame=True, vit_resolution=[224, 224]): + for _ in range(5): + try: + num_poses = len(pose_tensors) + numpyFrames = [] + numpyPoses = [] + + # Convert tensors to numpy arrays and prepare lists + for i in range(num_poses): + frame = 
ref_image_tensor.squeeze(0).cpu().numpy() # Convert to numpy array + # if i == 0: + # print(f'ref image is ({frame})') + numpyFrames.append(frame) + + pose = pose_tensors[i].squeeze(0).cpu().numpy() # Convert to numpy array + numpyPoses.append(pose) + + # Convert reference pose tensor to numpy array + pose_ref = ref_pose_tensor.squeeze(0).cpu().numpy() # Convert to numpy array + + # Sample max_frames poses for video generation + stride = frame_interval + total_frame_num = len(numpyFrames) + cover_frame_num = (stride * (max_frames - 1) + 1) + + if total_frame_num < cover_frame_num: + print(f'_total_frame_num ({total_frame_num}) is smaller than cover_frame_num ({cover_frame_num}), the sampled frame interval is changed') + start_frame = 0 + end_frame = total_frame_num + stride = max((total_frame_num - 1) // (max_frames - 1), 1) + end_frame = stride * max_frames + else: + start_frame = 0 + end_frame = start_frame + cover_frame_num + + frame_list = [] + dwpose_list = [] + + print(f'end_frame is ({end_frame})') + + for i_index in range(start_frame, end_frame, stride): + if i_index < len(numpyFrames): # Check index within bounds + i_frame = numpyFrames[i_index] + i_dwpose = numpyPoses[i_index] + + # Convert numpy arrays to PIL images + # i_frame = np.clip(i_frame, 0, 1) + i_frame = (i_frame - i_frame.min()) / (i_frame.max() - i_frame.min()) #Trying this in place of clip + i_frame = Image.fromarray((i_frame * 255).astype(np.uint8)) + i_frame = i_frame.convert('RGB') + # i_dwpose = np.clip(i_dwpose, 0, 1) + i_dwpose = (i_dwpose - i_dwpose.min()) / (i_dwpose.max() - i_dwpose.min()) #Trying this in place of clip + i_dwpose = Image.fromarray((i_dwpose * 255).astype(np.uint8)) + i_dwpose = i_dwpose.convert('RGB') + + # if i_index == 0: + # print(f'i_frame is ({np.array(i_frame)})') + + frame_list.append(i_frame) + dwpose_list.append(i_dwpose) + + if frame_list: + # random_ref_frame = np.clip(numpyFrames[0], 0, 1) + random_ref_frame = (numpyFrames[0] - numpyFrames[0].min()) / (numpyFrames[0].max() - numpyFrames[0].min()) #Trying this in place of clip + random_ref_frame = Image.fromarray((random_ref_frame * 255).astype(np.uint8)) + if random_ref_frame.mode != 'RGB': + random_ref_frame = random_ref_frame.convert('RGB') + # random_ref_dwpose = np.clip(pose_ref, 0, 1) + random_ref_dwpose = (pose_ref - pose_ref.min()) / (pose_ref.max() - pose_ref.min()) #Trying this in place of clip + random_ref_dwpose = Image.fromarray((random_ref_dwpose * 255).astype(np.uint8)) + if random_ref_dwpose.mode != 'RGB': + random_ref_dwpose = random_ref_dwpose.convert('RGB') + + # Apply transforms + ref_frame = frame_list[0] + vit_frame = vit_transforms(ref_frame) + random_ref_frame_tmp = train_trans_pose(random_ref_frame) + random_ref_dwpose_tmp = train_trans_pose(random_ref_dwpose) + misc_data_tmp = torch.stack([train_trans_pose(ss) for ss in frame_list], dim=0) + video_data_tmp = torch.stack([train_trans(ss) for ss in frame_list], dim=0) + dwpose_data_tmp = torch.stack([train_trans_pose(ss) for ss in dwpose_list], dim=0) + + # Initialize tensors + video_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + misc_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + random_ref_frame_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + random_ref_dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0]) + + # Copy data to tensors + video_data[:len(frame_list), ...] 
= video_data_tmp + misc_data[:len(frame_list), ...] = misc_data_tmp + dwpose_data[:len(frame_list), ...] = dwpose_data_tmp + random_ref_frame_data[:, ...] = random_ref_frame_tmp + random_ref_dwpose_data[:, ...] = random_ref_dwpose_tmp + + return vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data + + except Exception as e: + logging.info(f'Error reading video frame: {e}') + continue + + return None, None, None, None, None, None # Return default values if all attempts fail + + +def worker(gpu, steps, useFirstFrame, reference_image, ref_pose, pose_sequence, frame_interval, max_frames, resolution, cfg, cfg_update): + ''' + Inference worker for each gpu + ''' + for k, v in cfg_update.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + cfg.gpu = gpu + cfg.seed = int(cfg.seed) + cfg.rank = cfg.pmi_rank * cfg.gpus_per_machine + gpu + setup_seed(cfg.seed + cfg.rank) + + if not cfg.debug: + torch.cuda.set_device(gpu) + torch.backends.cudnn.benchmark = True + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + torch.backends.cudnn.benchmark = False + if not dist.is_initialized(): + dist.init_process_group(backend='gloo', world_size=cfg.world_size, rank=cfg.rank) + + # [Log] Save logging and make log dir + # log_dir = generalized_all_gather(cfg.log_dir)[0] + inf_name = osp.basename(cfg.cfg_file).split('.')[0] + # test_model = osp.basename(cfg.test_model).split('.')[0].split('_')[-1] + + cfg.log_dir = osp.join(cfg.log_dir, '%s' % (inf_name)) + os.makedirs(cfg.log_dir, exist_ok=True) + log_file = osp.join(cfg.log_dir, 'log_%02d.txt' % (cfg.rank)) + cfg.log_file = log_file + reload(logging) + logging.basicConfig( + level=logging.INFO, + format='[%(asctime)s] %(levelname)s: %(message)s', + handlers=[ + logging.FileHandler(filename=log_file), + logging.StreamHandler(stream=sys.stdout)]) + # logging.info(cfg) + logging.info(f"Running UniAnimate inference on gpu {gpu}") + + # [Diffusion] + diffusion = DIFFUSION.build(cfg.Diffusion) + + # [Data] Data Transform + train_trans = data.Compose([ + data.Resize(cfg.resolution), + data.ToTensor(), + data.Normalize(mean=cfg.mean, std=cfg.std) + ]) + + train_trans_pose = data.Compose([ + data.Resize(cfg.resolution), + data.ToTensor(), + ] + ) + + # Defines transformations for data to be fed into a Vision Transformer (ViT) model. 
+    vit_transforms = T.Compose([
+        data.Resize(cfg.vit_resolution),
+        T.ToTensor(),
+        T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)])
+
+    # [Model] embedder
+    clip_encoder = EMBEDDER.build(cfg.embedder)
+    clip_encoder.model.to(gpu)
+    with torch.no_grad():
+        _, _, zero_y = clip_encoder(text="")
+
+    # [Model] autoencoder
+    autoencoder = AUTO_ENCODER.build(cfg.auto_encoder)
+    autoencoder.eval()  # freeze
+    for param in autoencoder.parameters():
+        param.requires_grad = False
+    autoencoder.cuda()
+
+    # [Model] UNet
+    if "config" in cfg.UNet:
+        cfg.UNet["config"] = cfg
+    cfg.UNet["zero_y"] = zero_y
+    model = MODEL.build(cfg.UNet)
+    # Resolve the UniAnimate checkpoint relative to the repository root
+    current_directory = os.path.dirname(os.path.abspath(__file__))   # inferences folder
+    parent_directory = os.path.dirname(current_directory)            # tools folder
+    root_directory = os.path.dirname(parent_directory)               # UniAnimate folder
+    unifiedModel = os.path.join(root_directory, 'checkpoints/unianimate_16f_32f_non_ema_223000.pth')
+    state_dict = torch.load(unifiedModel, map_location='cpu')
+    if 'state_dict' in state_dict:
+        state_dict = state_dict['state_dict']
+    if 'step' in state_dict:
+        resume_step = state_dict['step']
+    else:
+        resume_step = 0
+    status = model.load_state_dict(state_dict, strict=True)
+    logging.info('Load model from {} with status {}'.format(unifiedModel, status))
+    model = model.to(gpu)
+    model.eval()
+    if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE:
+        print("Avoiding DistributedDataParallel to reduce memory usage")
+        model.to(torch.float16)
+    else:
+        model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model
+    torch.cuda.empty_cache()
+
+    # The reference image and the pose sequence inputs are consumed from here on
+    test_list = cfg.test_list_path
+    num_videos = len(test_list)
+    logging.info(f'There are {num_videos} videos; each will be processed {cfg.round} times')
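+    # Note (assumption, based on the commented-out unpacking below and on
+    # inference_unianimate_long_entrance.py): each entry of cfg.test_list_path is
+    # expected to be a triple of the form
+    #     [frame_interval, <reference_image_path>, <pose_sequence_dir>]
+    # This entrance, however, takes reference_image, ref_pose and pose_sequence
+    # directly from the worker arguments, so the list is only used for counting and rounds.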
+    # test_list = [item for item in test_list for _ in range(cfg.round)]
+    test_list = [item for _ in range(cfg.round) for item in test_list]
+
+    # for idx, file_path in enumerate(test_list):
+
+    # Inputs from any user interface can be wired in here.
+    # Expected inputs: ref_image_key, pose_seq_key, frame_interval, max_frames, resolution
+    # cfg.frame_interval, ref_image_key, pose_seq_key = file_path[0], file_path[1], file_path[2]
+
+    manual_seed = int(cfg.seed + cfg.rank)
+    setup_seed(manual_seed)
+    # logging.info(f"[{idx}]/[{len(test_list)}] Begin to sample {ref_image_key}, pose sequence from {pose_seq_key} init seed {manual_seed} ...")
+
+    # Load and preprocess the reference image, reference pose, and pose sequence
+    vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data = load_video_frames(reference_image, ref_pose, pose_sequence, train_trans, vit_transforms, train_trans_pose, max_frames, frame_interval, resolution)
+    misc_data = misc_data.unsqueeze(0).to(gpu)
+    vit_frame = vit_frame.unsqueeze(0).to(gpu)
+    dwpose_data = dwpose_data.unsqueeze(0).to(gpu)
+    random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu)
+    random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu)
+
+    ### save for visualization
+    misc_backups = copy(misc_data)
+    frames_num = misc_data.shape[1]
+    misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w')
+    mv_data_video = []
+
+    ### local image (first frame)
+    image_local = []
+    if 'local_image' in cfg.video_compositions:
+        frames_num = misc_data.shape[1]
+        bs_vd_local = misc_data.shape[0]
+        # repeat the first frame across all frames and use it as the local-image condition
+        image_local = misc_data[:,:1].clone().repeat(1,frames_num,1,1,1)
+        image_local_clone = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local)
+        image_local = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local)
+        if hasattr(cfg, "latent_local_image") and cfg.latent_local_image:
+            with torch.no_grad():  # disable gradient tracking
+                temporal_length = frames_num
+                # The encoder compresses the input into a lower-dimensional latent representation (a "latent vector" or "encoding").
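+                # Rough sketch of what the next two lines do (see tools/modules/autoencoder.py
+                # later in this patch; shapes assume cfg.scale == 8, the usual SD-VAE factor):
+                #     posterior = autoencoder.encode(frame)      # DiagonalGaussianDistribution
+                #     z = get_first_stage_encoding(posterior)    # 0.18215 * posterior.sample()
+                # With resolution == [512, 768] the latent is [B, 4, 96, 64] (C=4, H/8, W/8),
+                # which is then repeated along the temporal axis.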
+ encoder_posterior = autoencoder.encode(video_data[:,0]) + local_image_data = get_first_stage_encoding(encoder_posterior).detach() #use without affecting the gradients of the original model + image_local = local_image_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40] + + + + ### encode the video_data + # bs_vd = misc_data.shape[0] + misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w') + # misc_data_list = torch.chunk(misc_data, misc_data.shape[0]//cfg.chunk_size,dim=0) + + + with torch.no_grad(): + + random_ref_frame = [] + if 'randomref' in cfg.video_compositions: + random_ref_frame_clone = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w') + if hasattr(cfg, "latent_random_ref") and cfg.latent_random_ref: + + temporal_length = random_ref_frame_data.shape[1] + encoder_posterior = autoencoder.encode(random_ref_frame_data[:,0].sub(0.5).div_(0.5)) + random_ref_frame_data = get_first_stage_encoding(encoder_posterior).detach() + random_ref_frame_data = random_ref_frame_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40] + + random_ref_frame = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w') + + + if 'dwpose' in cfg.video_compositions: + bs_vd_local = dwpose_data.shape[0] + dwpose_data_clone = rearrange(dwpose_data.clone(), 'b f c h w -> b c f h w', b = bs_vd_local) + if 'randomref_pose' in cfg.video_compositions: + dwpose_data = torch.cat([random_ref_dwpose_data[:,:1], dwpose_data], dim=1) + dwpose_data = rearrange(dwpose_data, 'b f c h w -> b c f h w', b = bs_vd_local) + + + y_visual = [] + if 'image' in cfg.video_compositions: + with torch.no_grad(): + vit_frame = vit_frame.squeeze(1) + y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1) # [60, 1024] + y_visual0 = y_visual.clone() + + # print(torch.get_default_dtype()) + + with amp.autocast(enabled=True): + # pynvml.nvmlInit() + # handle=pynvml.nvmlDeviceGetHandleByIndex(0) + # meminfo=pynvml.nvmlDeviceGetMemoryInfo(handle) + cur_seed = torch.initial_seed() + # logging.info(f"Current seed {cur_seed} ...") + + print(f"Number of frames to denoise: {frames_num}") + noise = torch.randn([1, 4, frames_num, int(cfg.resolution[1]/cfg.scale), int(cfg.resolution[0]/cfg.scale)]) + noise = noise.to(gpu) + # print(f"noise: {noise.shape}") + + + if hasattr(cfg.Diffusion, "noise_strength"): + b, c, f, _, _= noise.shape + offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device) + noise = noise + cfg.Diffusion.noise_strength * offset_noise + # print(f"offset_noise dtype: {offset_noise.dtype}") + # print(f' offset_noise is ({offset_noise})') + + + + # add a noise prior + noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 949), noise=noise) + + # construct model inputs (CFG) + full_model_kwargs=[{ + 'y': None, + "local_image": None if len(image_local) == 0 else image_local[:], + 'image': None if len(y_visual) == 0 else y_visual0[:], + 'dwpose': None if len(dwpose_data) == 0 else dwpose_data[:], + 'randomref': None if len(random_ref_frame) == 0 else random_ref_frame[:], + }, + { + 'y': None, + "local_image": None, + 'image': None, + 'randomref': None, + 'dwpose': None, + }] + + # for visualization + full_model_kwargs_vis =[{ + 'y': None, + "local_image": None if len(image_local) == 0 else image_local_clone[:], + 'image': None, + 'dwpose': None if len(dwpose_data_clone) == 0 else dwpose_data_clone[:], + 'randomref': None if len(random_ref_frame) == 0 else random_ref_frame_clone[:, :3], + }, + { + 'y': None, + "local_image": None, + 'image': 
None, + 'randomref': None, + 'dwpose': None, + }] + + + partial_keys = [ + ['image', 'randomref', "dwpose"], + ] + + if useFirstFrame: + partial_keys = [ + ['image', 'local_image', "dwpose"], + ] + print('Using First Frame Conditioning!') + + + for partial_keys_one in partial_keys: + model_kwargs_one = prepare_model_kwargs(partial_keys = partial_keys_one, + full_model_kwargs = full_model_kwargs, + use_fps_condition = cfg.use_fps_condition) + model_kwargs_one_vis = prepare_model_kwargs(partial_keys = partial_keys_one, + full_model_kwargs = full_model_kwargs_vis, + use_fps_condition = cfg.use_fps_condition) + noise_one = noise + + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + clip_encoder.cpu() # add this line + autoencoder.cpu() # add this line + torch.cuda.empty_cache() # add this line + + # print(f' noise_one is ({noise_one})') + + + video_data = diffusion.ddim_sample_loop( + noise=noise_one, + model=model.eval(), + model_kwargs=model_kwargs_one, + guide_scale=cfg.guide_scale, + ddim_timesteps=steps, + eta=0.0) + + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + # if run forward of autoencoder or clip_encoder second times, load them again + clip_encoder.cuda() + autoencoder.cuda() + video_data = 1. / cfg.scale_factor * video_data + video_data = rearrange(video_data, 'b c f h w -> (b f) c h w') + chunk_size = min(cfg.decoder_bs, video_data.shape[0]) + video_data_list = torch.chunk(video_data, video_data.shape[0]//chunk_size, dim=0) + decode_data = [] + for vd_data in video_data_list: + gen_frames = autoencoder.decode(vd_data) + decode_data.append(gen_frames) + video_data = torch.cat(decode_data, dim=0) + video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = cfg.batch_size).float() + + # Check sth + + # print(f' video_data is of shape ({video_data.shape})') + # print(f' video_data is ({video_data})') + + del model_kwargs_one_vis[0][list(model_kwargs_one_vis[0].keys())[0]] + del model_kwargs_one_vis[1][list(model_kwargs_one_vis[1].keys())[0]] + + video_data = extract_image_tensors(video_data.cpu(), cfg.mean, cfg.std) + + # synchronize to finish some processes + if not cfg.debug: + torch.cuda.synchronize() + dist.barrier() + + return video_data + +@torch.no_grad() +def extract_image_tensors(video_tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): + # Unnormalize the video tensor + mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1) # ncfhw + std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1) # ncfhw + video_tensor = video_tensor.mul_(std).add_(mean) # unnormalize back to [0,1] + video_tensor.clamp_(0, 1) + + images = rearrange(video_tensor, 'b c f h w -> b f h w c') + images = images.squeeze(0) + images_t = [] + for img in images: + img_array = np.array(img) # Convert PIL Image to numpy array + img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0).float() # Convert to tensor and CHW format + img_tensor = img_tensor.permute(0, 2, 3, 1) + images_t.append(img_tensor) + + logging.info('Images data extracted!') + images_t = torch.cat(images_t, dim=0) + return images_t + +def prepare_model_kwargs(partial_keys, full_model_kwargs, use_fps_condition=False): + if use_fps_condition is True: + partial_keys.append('fps') + partial_model_kwargs = [{}, {}] + for partial_key in partial_keys: + partial_model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key] + partial_model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key] + return partial_model_kwargs \ No newline at end of file diff --git 
a/tools/inferences/inference_unianimate_long_entrance.py b/tools/inferences/inference_unianimate_long_entrance.py new file mode 100644 index 0000000..5e841f5 --- /dev/null +++ b/tools/inferences/inference_unianimate_long_entrance.py @@ -0,0 +1,508 @@ +''' +/* +*Copyright (c) 2021, Alibaba Group; +*Licensed under the Apache License, Version 2.0 (the "License"); +*you may not use this file except in compliance with the License. +*You may obtain a copy of the License at + +* http://www.apache.org/licenses/LICENSE-2.0 + +*Unless required by applicable law or agreed to in writing, software +*distributed under the License is distributed on an "AS IS" BASIS, +*WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +*See the License for the specific language governing permissions and +*limitations under the License. +*/ +''' + +import os +import re +import os.path as osp +import sys +sys.path.insert(0, '/'.join(osp.realpath(__file__).split('/')[:-4])) +import json +import math +import torch +# import pynvml +import logging +import cv2 +import numpy as np +from PIL import Image +from tqdm import tqdm +import torch.cuda.amp as amp +from importlib import reload +import torch.distributed as dist +import torch.multiprocessing as mp +import random +from einops import rearrange +import torchvision.transforms as T +import torchvision.transforms.functional as TF +from torch.nn.parallel import DistributedDataParallel + +from ...utils import transforms as data +from ..modules.config import cfg +from ...utils.seed import setup_seed +from ...utils.multi_port import find_free_port +from ...utils.assign_cfg import assign_signle_cfg +from ...utils.distributed import generalized_all_gather, all_reduce +from ...utils.video_op import save_i2vgen_video, save_t2vhigen_video_safe, save_video_multiple_conditions_not_gif_horizontal_3col +from ...tools.modules.autoencoder import get_first_stage_encoding +from ...utils.registry_class import INFER_ENGINE, MODEL, EMBEDDER, AUTO_ENCODER, DIFFUSION +from copy import copy +import cv2 + + +@INFER_ENGINE.register_function() +def inference_unianimate_long_entrance(cfg_update, **kwargs): + for k, v in cfg_update.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + if not 'MASTER_ADDR' in os.environ: + os.environ['MASTER_ADDR']='localhost' + os.environ['MASTER_PORT']= find_free_port() + cfg.pmi_rank = int(os.getenv('RANK', 0)) + cfg.pmi_world_size = int(os.getenv('WORLD_SIZE', 1)) + + if cfg.debug: + cfg.gpus_per_machine = 1 + cfg.world_size = 1 + else: + cfg.gpus_per_machine = torch.cuda.device_count() + cfg.world_size = cfg.pmi_world_size * cfg.gpus_per_machine + + if cfg.world_size == 1: + worker(0, cfg, cfg_update) + else: + mp.spawn(worker, nprocs=cfg.gpus_per_machine, args=(cfg, cfg_update)) + return cfg + + +def make_masked_images(imgs, masks): + masked_imgs = [] + for i, mask in enumerate(masks): + # concatenation + masked_imgs.append(torch.cat([imgs[i] * (1 - mask), (1 - mask)], dim=1)) + return torch.stack(masked_imgs, dim=0) + +def load_video_frames(ref_image_path, pose_file_path, train_trans, vit_transforms, train_trans_pose, max_frames=32, frame_interval = 1, resolution=[512, 768], get_first_frame=True, vit_resolution=[224, 224]): + + for _ in range(5): + try: + dwpose_all = {} + frames_all = {} + for ii_index in sorted(os.listdir(pose_file_path)): + if ii_index != "ref_pose.jpg": + dwpose_all[ii_index] = Image.open(pose_file_path+"/"+ii_index) + frames_all[ii_index] = 
Image.fromarray(cv2.cvtColor(cv2.imread(ref_image_path), cv2.COLOR_BGR2RGB))
+                    # frames_all[ii_index] = Image.open(ref_image_path)
+
+            pose_ref = Image.open(os.path.join(pose_file_path, "ref_pose.jpg"))
+            first_eq_ref = False
+
+            # sample max_frames poses for video generation
+            stride = frame_interval
+            _total_frame_num = len(frames_all)
+            if max_frames == "None":
+                max_frames = (_total_frame_num-1)//frame_interval + 1
+            cover_frame_num = (stride * (max_frames-1) + 1)
+            if _total_frame_num < cover_frame_num:
+                print(f'_total_frame_num ({_total_frame_num}) is smaller than cover_frame_num ({cover_frame_num}); the sampled frame interval is changed')
+                start_frame = 0  # we set start_frame = 0 because the pose alignment is performed on the first frame
+                end_frame = _total_frame_num
+                # recompute the stride so that max_frames samples span the whole sequence
+                stride = max((_total_frame_num-1)//(max_frames-1), 1)
+                end_frame = stride*max_frames
+            else:
+                start_frame = 0  # we set start_frame = 0 because the pose alignment is performed on the first frame
+                end_frame = start_frame + cover_frame_num
+
+            frame_list = []
+            dwpose_list = []
+            random_ref_frame = frames_all[list(frames_all.keys())[0]]
+            if random_ref_frame.mode != 'RGB':
+                random_ref_frame = random_ref_frame.convert('RGB')
+            random_ref_dwpose = pose_ref
+            if random_ref_dwpose.mode != 'RGB':
+                random_ref_dwpose = random_ref_dwpose.convert('RGB')
+            for i_index in range(start_frame, end_frame, stride):
+                # the first branch is unreachable here because first_eq_ref is always False above
+                if i_index == start_frame and first_eq_ref:
+                    i_key = list(frames_all.keys())[i_index]
+                    i_frame = frames_all[i_key]
+                    if i_frame.mode != 'RGB':
+                        i_frame = i_frame.convert('RGB')
+                    i_dwpose = frames_pose_ref
+                    if i_dwpose.mode != 'RGB':
+                        i_dwpose = i_dwpose.convert('RGB')
+                    frame_list.append(i_frame)
+                    dwpose_list.append(i_dwpose)
+                else:
+                    # added
+                    if first_eq_ref:
+                        i_index = i_index - stride
+
+                    i_key = list(frames_all.keys())[i_index]
+                    i_frame = frames_all[i_key]
+                    if i_frame.mode != 'RGB':
+                        i_frame = i_frame.convert('RGB')
+                    i_dwpose = dwpose_all[i_key]
+                    if i_dwpose.mode != 'RGB':
+                        i_dwpose = i_dwpose.convert('RGB')
+                    frame_list.append(i_frame)
+                    dwpose_list.append(i_dwpose)
+            have_frames = len(frame_list) > 0
+            middle_index = 0
+            if have_frames:
+                ref_frame = frame_list[middle_index]
+                vit_frame = vit_transforms(ref_frame)
+                random_ref_frame_tmp = train_trans_pose(random_ref_frame)
+                random_ref_dwpose_tmp = train_trans_pose(random_ref_dwpose)
+                misc_data_tmp = torch.stack([train_trans_pose(ss) for ss in frame_list], dim=0)
+                video_data_tmp = torch.stack([train_trans(ss) for ss in frame_list], dim=0)
+                dwpose_data_tmp = torch.stack([train_trans_pose(ss) for ss in dwpose_list], dim=0)
+
+            video_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
+            dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
+            misc_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
+            random_ref_frame_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])  # e.g. [32, 3, 768, 512] for resolution=[512, 768]
+            random_ref_dwpose_data = torch.zeros(max_frames, 3, resolution[1], resolution[0])
+            if have_frames:
+                video_data[:len(frame_list), ...] = video_data_tmp
+                misc_data[:len(frame_list), ...] = misc_data_tmp
+                dwpose_data[:len(frame_list), ...] = dwpose_data_tmp
+                random_ref_frame_data[:,...] = random_ref_frame_tmp
+                random_ref_dwpose_data[:,...]
= random_ref_dwpose_tmp + + break + + except Exception as e: + logging.info('{} read video frame failed with error: {}'.format(pose_file_path, e)) + continue + + return vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data, max_frames + + + +def worker(gpu, cfg, cfg_update): + ''' + Inference worker for each gpu + ''' + for k, v in cfg_update.items(): + if isinstance(v, dict) and k in cfg: + cfg[k].update(v) + else: + cfg[k] = v + + cfg.gpu = gpu + cfg.seed = int(cfg.seed) + cfg.rank = cfg.pmi_rank * cfg.gpus_per_machine + gpu + setup_seed(cfg.seed + cfg.rank) + + if not cfg.debug: + torch.cuda.set_device(gpu) + torch.backends.cudnn.benchmark = True + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + torch.backends.cudnn.benchmark = False + dist.init_process_group(backend='nccl', world_size=cfg.world_size, rank=cfg.rank) + + # [Log] Save logging and make log dir + log_dir = generalized_all_gather(cfg.log_dir)[0] + inf_name = osp.basename(cfg.cfg_file).split('.')[0] + test_model = osp.basename(cfg.test_model).split('.')[0].split('_')[-1] + + cfg.log_dir = osp.join(cfg.log_dir, '%s' % (inf_name)) + os.makedirs(cfg.log_dir, exist_ok=True) + log_file = osp.join(cfg.log_dir, 'log_%02d.txt' % (cfg.rank)) + cfg.log_file = log_file + reload(logging) + logging.basicConfig( + level=logging.INFO, + format='[%(asctime)s] %(levelname)s: %(message)s', + handlers=[ + logging.FileHandler(filename=log_file), + logging.StreamHandler(stream=sys.stdout)]) + logging.info(cfg) + logging.info(f"Running UniAnimate inference on gpu {gpu}") + + # [Diffusion] + diffusion = DIFFUSION.build(cfg.Diffusion) + + # [Data] Data Transform + train_trans = data.Compose([ + data.Resize(cfg.resolution), + data.ToTensor(), + data.Normalize(mean=cfg.mean, std=cfg.std) + ]) + + train_trans_pose = data.Compose([ + data.Resize(cfg.resolution), + data.ToTensor(), + ] + ) + + vit_transforms = T.Compose([ + data.Resize(cfg.vit_resolution), + T.ToTensor(), + T.Normalize(mean=cfg.vit_mean, std=cfg.vit_std)]) + + # [Model] embedder + clip_encoder = EMBEDDER.build(cfg.embedder) + clip_encoder.model.to(gpu) + with torch.no_grad(): + _, _, zero_y = clip_encoder(text="") + + + # [Model] auotoencoder + autoencoder = AUTO_ENCODER.build(cfg.auto_encoder) + autoencoder.eval() # freeze + for param in autoencoder.parameters(): + param.requires_grad = False + autoencoder.cuda() + + # [Model] UNet + if "config" in cfg.UNet: + cfg.UNet["config"] = cfg + cfg.UNet["zero_y"] = zero_y + model = MODEL.build(cfg.UNet) + state_dict = torch.load(cfg.test_model, map_location='cpu') + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + if 'step' in state_dict: + resume_step = state_dict['step'] + else: + resume_step = 0 + status = model.load_state_dict(state_dict, strict=True) + logging.info('Load model from {} with status {}'.format(cfg.test_model, status)) + model = model.to(gpu) + model.eval() + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + model.to(torch.float16) + else: + model = DistributedDataParallel(model, device_ids=[gpu]) if not cfg.debug else model + torch.cuda.empty_cache() + + + + test_list = cfg.test_list_path + num_videos = len(test_list) + logging.info(f'There are {num_videos} videos. 
with {cfg.round} times') + test_list = [item for _ in range(cfg.round) for item in test_list] + + for idx, file_path in enumerate(test_list): + cfg.frame_interval, ref_image_key, pose_seq_key = file_path[0], file_path[1], file_path[2] + + manual_seed = int(cfg.seed + cfg.rank + idx//num_videos) + setup_seed(manual_seed) + logging.info(f"[{idx}]/[{len(test_list)}] Begin to sample {ref_image_key}, pose sequence from {pose_seq_key} init seed {manual_seed} ...") + + + vit_frame, video_data, misc_data, dwpose_data, random_ref_frame_data, random_ref_dwpose_data, max_frames = load_video_frames(ref_image_key, pose_seq_key, train_trans, vit_transforms, train_trans_pose, max_frames=cfg.max_frames, frame_interval =cfg.frame_interval, resolution=cfg.resolution) + cfg.max_frames_new = max_frames + misc_data = misc_data.unsqueeze(0).to(gpu) + vit_frame = vit_frame.unsqueeze(0).to(gpu) + dwpose_data = dwpose_data.unsqueeze(0).to(gpu) + random_ref_frame_data = random_ref_frame_data.unsqueeze(0).to(gpu) + random_ref_dwpose_data = random_ref_dwpose_data.unsqueeze(0).to(gpu) + + ### save for visualization + misc_backups = copy(misc_data) + frames_num = misc_data.shape[1] + misc_backups = rearrange(misc_backups, 'b f c h w -> b c f h w') + mv_data_video = [] + + + ### local image (first frame) + image_local = [] + if 'local_image' in cfg.video_compositions: + frames_num = misc_data.shape[1] + bs_vd_local = misc_data.shape[0] + image_local = misc_data[:,:1].clone().repeat(1,frames_num,1,1,1) + image_local_clone = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local) + image_local = rearrange(image_local, 'b f c h w -> b c f h w', b = bs_vd_local) + if hasattr(cfg, "latent_local_image") and cfg.latent_local_image: + with torch.no_grad(): + temporal_length = frames_num + encoder_posterior = autoencoder.encode(video_data[:,0]) + local_image_data = get_first_stage_encoding(encoder_posterior).detach() + image_local = local_image_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40] + + + + ### encode the video_data + bs_vd = misc_data.shape[0] + misc_data = rearrange(misc_data, 'b f c h w -> (b f) c h w') + misc_data_list = torch.chunk(misc_data, misc_data.shape[0]//cfg.chunk_size,dim=0) + + + with torch.no_grad(): + + random_ref_frame = [] + if 'randomref' in cfg.video_compositions: + random_ref_frame_clone = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w') + if hasattr(cfg, "latent_random_ref") and cfg.latent_random_ref: + + temporal_length = random_ref_frame_data.shape[1] + encoder_posterior = autoencoder.encode(random_ref_frame_data[:,0].sub(0.5).div_(0.5)) + random_ref_frame_data = get_first_stage_encoding(encoder_posterior).detach() + random_ref_frame_data = random_ref_frame_data.unsqueeze(1).repeat(1,temporal_length,1,1,1) # [10, 16, 4, 64, 40] + + random_ref_frame = rearrange(random_ref_frame_data, 'b f c h w -> b c f h w') + + + if 'dwpose' in cfg.video_compositions: + bs_vd_local = dwpose_data.shape[0] + dwpose_data_clone = rearrange(dwpose_data.clone(), 'b f c h w -> b c f h w', b = bs_vd_local) + if 'randomref_pose' in cfg.video_compositions: + dwpose_data = torch.cat([random_ref_dwpose_data[:,:1], dwpose_data], dim=1) + dwpose_data = rearrange(dwpose_data, 'b f c h w -> b c f h w', b = bs_vd_local) + + + y_visual = [] + if 'image' in cfg.video_compositions: + with torch.no_grad(): + vit_frame = vit_frame.squeeze(1) + y_visual = clip_encoder.encode_image(vit_frame).unsqueeze(1) # [60, 1024] + y_visual0 = y_visual.clone() + + + with amp.autocast(enabled=True): + # 
pynvml.nvmlInit() + # handle=pynvml.nvmlDeviceGetHandleByIndex(0) + # meminfo=pynvml.nvmlDeviceGetMemoryInfo(handle) + cur_seed = torch.initial_seed() + logging.info(f"Current seed {cur_seed} ..., cfg.max_frames_new: {cfg.max_frames_new} ....") + + noise = torch.randn([1, 4, cfg.max_frames_new, int(cfg.resolution[1]/cfg.scale), int(cfg.resolution[0]/cfg.scale)]) + noise = noise.to(gpu) + + # add a noise prior + noise = diffusion.q_sample(random_ref_frame.clone(), getattr(cfg, "noise_prior_value", 939), noise=noise) + + if hasattr(cfg.Diffusion, "noise_strength"): + b, c, f, _, _= noise.shape + offset_noise = torch.randn(b, c, f, 1, 1, device=noise.device) + noise = noise + cfg.Diffusion.noise_strength * offset_noise + + # construct model inputs (CFG) + full_model_kwargs=[{ + 'y': None, + "local_image": None if len(image_local) == 0 else image_local[:], + 'image': None if len(y_visual) == 0 else y_visual0[:], + 'dwpose': None if len(dwpose_data) == 0 else dwpose_data[:], + 'randomref': None if len(random_ref_frame) == 0 else random_ref_frame[:], + }, + { + 'y': None, + "local_image": None, + 'image': None, + 'randomref': None, + 'dwpose': None, + }] + + # for visualization + full_model_kwargs_vis =[{ + 'y': None, + "local_image": None if len(image_local) == 0 else image_local_clone[:], + 'image': None, + 'dwpose': None if len(dwpose_data_clone) == 0 else dwpose_data_clone[:], + 'randomref': None if len(random_ref_frame) == 0 else random_ref_frame_clone[:, :3], + }, + { + 'y': None, + "local_image": None, + 'image': None, + 'randomref': None, + 'dwpose': None, + }] + + + partial_keys = [ + ['image', 'randomref', "dwpose"], + ] + if hasattr(cfg, "partial_keys") and cfg.partial_keys: + partial_keys = cfg.partial_keys + + for partial_keys_one in partial_keys: + model_kwargs_one = prepare_model_kwargs(partial_keys = partial_keys_one, + full_model_kwargs = full_model_kwargs, + use_fps_condition = cfg.use_fps_condition) + model_kwargs_one_vis = prepare_model_kwargs(partial_keys = partial_keys_one, + full_model_kwargs = full_model_kwargs_vis, + use_fps_condition = cfg.use_fps_condition) + noise_one = noise + + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + clip_encoder.cpu() # add this line + autoencoder.cpu() # add this line + torch.cuda.empty_cache() # add this line + + video_data = diffusion.ddim_sample_loop( + noise=noise_one, + context_size=cfg.context_size, + context_stride=cfg.context_stride, + context_overlap=cfg.context_overlap, + model=model.eval(), + model_kwargs=model_kwargs_one, + guide_scale=cfg.guide_scale, + ddim_timesteps=cfg.ddim_timesteps, + eta=0.0, + context_batch_size=getattr(cfg, "context_batch_size", 1) + ) + + if hasattr(cfg, "CPU_CLIP_VAE") and cfg.CPU_CLIP_VAE: + # if run forward of autoencoder or clip_encoder second times, load them again + clip_encoder.cuda() + autoencoder.cuda() + + + video_data = 1. 
/ cfg.scale_factor * video_data # [1, 4, h, w] + video_data = rearrange(video_data, 'b c f h w -> (b f) c h w') + chunk_size = min(cfg.decoder_bs, video_data.shape[0]) + video_data_list = torch.chunk(video_data, video_data.shape[0]//chunk_size, dim=0) + decode_data = [] + for vd_data in video_data_list: + gen_frames = autoencoder.decode(vd_data) + decode_data.append(gen_frames) + video_data = torch.cat(decode_data, dim=0) + video_data = rearrange(video_data, '(b f) c h w -> b c f h w', b = cfg.batch_size).float() + + text_size = cfg.resolution[-1] + cap_name = re.sub(r'[^\w\s]', '', ref_image_key.split("/")[-1].split('.')[0]) # .replace(' ', '_') + name = f'seed_{cur_seed}' + for ii in partial_keys_one: + name = name + "_" + ii + file_name = f'rank_{cfg.world_size:02d}_{cfg.rank:02d}_{idx:02d}_{name}_{cap_name}_{cfg.resolution[1]}x{cfg.resolution[0]}.mp4' + local_path = os.path.join(cfg.log_dir, f'{file_name}') + os.makedirs(os.path.dirname(local_path), exist_ok=True) + captions = "human" + del model_kwargs_one_vis[0][list(model_kwargs_one_vis[0].keys())[0]] + del model_kwargs_one_vis[1][list(model_kwargs_one_vis[1].keys())[0]] + + save_video_multiple_conditions_not_gif_horizontal_3col(local_path, video_data.cpu(), model_kwargs_one_vis, misc_backups, + cfg.mean, cfg.std, nrow=1, save_fps=cfg.save_fps) + + # try: + # save_t2vhigen_video_safe(local_path, video_data.cpu(), captions, cfg.mean, cfg.std, text_size) + # logging.info('Save video to dir %s:' % (local_path)) + # except Exception as e: + # logging.info(f'Step: save text or video error with {e}') + + logging.info('Congratulations! The inference is completed!') + # synchronize to finish some processes + if not cfg.debug: + torch.cuda.synchronize() + dist.barrier() + +def prepare_model_kwargs(partial_keys, full_model_kwargs, use_fps_condition=False): + + if use_fps_condition is True: + partial_keys.append('fps') + + partial_model_kwargs = [{}, {}] + for partial_key in partial_keys: + partial_model_kwargs[0][partial_key] = full_model_kwargs[0][partial_key] + partial_model_kwargs[1][partial_key] = full_model_kwargs[1][partial_key] + + return partial_model_kwargs diff --git a/tools/modules/__init__.py b/tools/modules/__init__.py new file mode 100644 index 0000000..db82a43 --- /dev/null +++ b/tools/modules/__init__.py @@ -0,0 +1,7 @@ +from .clip_embedder import FrozenOpenCLIPEmbedder +from .autoencoder import DiagonalGaussianDistribution, AutoencoderKL +from .clip_embedder import * +from .autoencoder import * +from .unet import * +from .diffusions import * +from .embedding_manager import * diff --git a/tools/modules/__pycache__/__init__.cpython-310.pyc b/tools/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..550286961e2e242d1f42613892457173415c90da GIT binary patch literal 438 zcmYjNOHRWu5VeyOqS7GPu;Kz?!2u9LR4t#13Xq_iD3ayYPBqA%$aYkfHCxu4fh%Rp ziYu_fra>^4-}8H(8IP>9tdD%WyuMXEjL=6A|0D9^+;4mBfgy$klDNP)oFfs*xQJy^ zBtJS9=rQ=Qh&t|dyw&lx=i~w9y-)uZ%5i&eY4kH!cQsd|>E!-G&N*Y;LmbTd6)PY(ONfljk^Nvg!3Icq==%k!kZfpq_o^(0VvRcn3WCe=pl_j+{ zj?M-Dy(kxrby^bTXSBa2nJE}5kks+gX)SCiHEV=VAOsXRLXPWoo1Sjmv+zrH*TVqq J;j~)={0%uweAoa0 literal 0 HcmV?d00001 diff --git a/tools/modules/__pycache__/__init__.cpython-39.pyc b/tools/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4c514b48df2ae693a08da97bb7c20f6dc69e63a GIT binary patch literal 382 zcmYjMyH3L}6t(jxQE4PtnD~J(u&^P7s9Ij40wm~SMT*?Ui7xUZ@?(IRm6>ngm$EYP 
z3ryHF5?ssY_#EAPY|}g+BVVs??@fvk`V8WKMP8iwZLU1jsKy)#OfU}Th({tOu}Dbr zt7DEHgCFx~;B?^8z+0Y^2UMhA{w`GDe(yr77bfpICTEN5`*X2kbEBl(&`SFs^QhsTp``q!#s@8~4hiIdEd zaScu5jE^*AcQXgFtpt^{W%>i@Q5XpDeKb2RmxgImf`KPhE+tGQ2x?hrmExups&<^2 cvZ_FWt*XQBxzA2E?rHe1nRuAs9h?p82haaz1^@s6 literal 0 HcmV?d00001 diff --git a/tools/modules/__pycache__/autoencoder.cpython-310.pyc b/tools/modules/__pycache__/autoencoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87ec4adb3dd92f9976335f46500b40709989b970 GIT binary patch literal 16687 zcmbt*36NarUEjO<_3NIVo}Qi=&5Wd#Mtk>~m9sL&Ny@c{rPW5*%)+h~LPt)gr{9}V zxB3`;?~Sw?zeyHL1}YU{a4<(Op~pbQ8tqcW>w5qG?(=_t-|NlJ<_$cr{rVR+zp!Z-zr)D*lR@NJykXlk4Bzn0 zwo%ppX4RCqRkh^buG;eNR5SALcCyv1smpqutYw%+Cs)lQ#q#a;Os7ySbY`owh&z7f zmQgMG5Bgc(yJc4A0<&M*H~feEygzfxylGbFgF?T&Z@ggm1%LLI;m-!y*1}C2Ig7)b zqCbb6xnNP!mr%Ck&m(0%Fk4c(dL)=lTPgbs$XN)MrR)mIF8WJIS<>xRSCMkWUq;Gu znsO8=EB-1{R@0PYNI7~>$}#^q${bJ297mZ4?rHgie-dR*rez*LnKl0uQck5QCw%i| zqw??_?!)>W8?a&}h38(m_+st(^|LRYd;X=QaPGp(7hk&YkykEWcyYaftrsHYm2lWf@B>hzj^#Z1gs4-Az<@i*IN|Lmn#!a#+W znqjlmyL7hK*}U<}h1&II*Y8~qYuh~))i>I~AmdVFCyaWXTDRv1;UKa0O1Jr3x7n#j z!KJ9zYloLQJ%6W-0`;A!7jzpatWIy=NDAFvx83Xpb=8b+6gXxJplGQX{M}gtXfR>$ zFzn)$*)1Xv87;GA#h9OO60vSOSUP8Y>-O*cbpH>&@};LQR-7c?-Rab}RBvZHM1Kf3 zF4w!=pdBW5fF!rRxrq@}Y{ixntJ_U-FR=7_Pj#?ZGtKU>hMGkYStkuhJV}lP>+R-l zP(oBDuxKid?yGnI-Y5Rwm%jOzo|%{gcXYT#=MjmxN#;#U&BjLG+Bfz%A&c83Id;oI z-fBFap7G7djg6b;kAR+j6qN8uvtqCB9!r-NyA(!&YWCDxZ*wiWu^p^Ew!8G$TDKRi z-3X$!W@oz{bb@Xa_?2vu$4+;{sNQV^$?}9ssJDaKX1xJ4Ow8Ts^<3Es zqS|Itg;5QI+6rpYM6%ppMXWLkI}L>?J8LDu`=s&OthYF3;9v?IsQHI3N(3G;hE617a^ zPQhKxY5{ShHXFOXv16)l zjf$~*(~Rb>Sm8IaMPP5q&qQ8--gi;PK`m$B!U(fc{^M~Lwac-`7(xp%BnH~~YtdqB zseeS^asnnU zIoVFuNV3msVo0(nRV3MSf$+j4>+67!mYU}f-Lt*!wKIN_zq(WJMzscF$s7s0HkObJ zIvass^P3%YjFoK6AemFa7RWYGwT+$4O^hw0>RrE^WIH{K5JZ>V2|-q!dSdm$#M!P# z5C(ozb?coV$<{YQNx_6Rn^?kH;+(i90zu-i3z3a;$FF4S6g8+ z-`G(K3=1<=0U5`Qq_kb1N-3&dZ>+?DndEA6#_{ z3g!v(2*Mu1D}b&EK4UFf1^kL;*>wJJ_s9gPo*wLYgDidUkns%J`z|7sQ6xOyx(x}n zYV6sO72AD#AMAV2iJX20EI$)F=ZudZ*#qJ9K|1@UpShAz??Wu>yNG2`FAG$6R}GMQ z_QS^2B_O{S3!5VHMk-v&lD>cYqKV60JqwQNMm z{5Y$Hm=4;C6tY{sC;jxDyLc|5HJ37>5$zUV3LyMd;I9!a)+#evGC1vC-A|l0a4>O5 zD#;ACzm0u^5JBWc1@RyF=h(di;p>Bjw%bi0E-RAj4+fV75p99>Cr-1w*{gW!84OMR z83N)~VuikX8G))B`n7IwJu$CTGglazk}B$Hl=^dcLrRe{HX9Yb?(VKl4)@*)6G{gM!Xck!D{BB4ECdBuDp)#( z4V=+ckZn;u{+$> zAB&_=${L!P|(9(2-$E>55kmwAXrqh%)Q$CfI9DMO>$bGi;f z)WaRj0NMreGe9YO6U(eA+ED?TQY8JvAz3SqCS7zWaNX^ls9Ms>@<{U?L#ivRKD7zi z&7Oj=^A+-B3y4VA0*l+QyEM54Y1xVCaq&4tq48%Cpcx_Z#)$0}3~ArIVyQW(T^gS- zN#BNCf&sqow4gk;GQNA;?7Nt{Y)~b;F6vRmI=wDJK|PFC)q?;K0<)c<136s7P+;i6 zjGjFcE|8`D1agJo$iY2`ZEXJ(iVn+8Fm4JqzOq83STb%}kb{l21rMnmcCXcXJ5l0Z z51Lz-qXSc8=G2xI*$wmyf*XPQB*#U4GJlUWnJn1gRt+BN^qjBI%?gG`-5wZ6#zxAq zzyzokP}?Ogy0I00XKYaSzGUrX_q^D}#zC+5tt*Z?6K6-U@bOeHM}4Vj^xZg%5H#_A z7Mgl4_WJpK$gF)6s)P+zlEuG+e>;2u{mm%!)GzECpZ);!_(7k!z5GO8&O8MT4TJs)$X4!(3}N z0!2zp+^APmTN^^*8k6&)P1oAZD}k_pRJ>D_4I4d}y)znFuCw|K92_?^+J;rCGaP`( ze9%j8OEA6ayoyl@d#HL0%E)JjB2}tPpq1X0fEWrOi!#_oUi4W0 z4^mbTKK@6Z2ZC9dWGJKNCYU;xnmr4}aB)a6-~s{a^X;nYPe>%%Y`9wsN6(rYyOV2Cm(^J2w1f4;In zbCKp2tgC#jroCe{9cwi1+B!!J74;IwN$pkD2t@E~G1MfW+)+IO!Qq%8VYk_C4^FST zufCH>9Bne28bghs-5wZ3f>jq-gzQkA1mG*^|D#2ZooIE5$7 z-PP$yjjT%2-r$k)h=`Clf&l68ekzqIQcSf_Y#MEeFF%7<22i%91gEU(X{|a!dysb^qeJfS2zh<$ zbEbF=ZQr?-2f`9@iLgAq0K%q->-sZ(0r9LLum=R5P0AO1a4>kf&(pj4GM$^};iM>B zy9xmTcMV*k!BAf@(JO~QESSS9hAX*ejm~0m6l;vtDygl05QWC@Hr1si^)E8(c>-CP z(Kc`ye+jRk0}adFWA~2n{us$I(&l)B10|=HvB*XYba@XsYSjN|k7RqedH<3PaF#%T z)kAvMC@_JJ>~7Gv`XNL{{i+{kbgCO-{jV~e&fr5kIRxj~313cd+{%Y18D3=b^X4XO 
z7!gm_I;^ktb@c-%d0796CmuU^8-Igskw1$OfB0$-%KnLF+(RF~$>GS1j~=qDsRdEV z9ZFh4Ra`dpBLr;j{zUd!Z(&jw*J=N`f$BYBp^35JAKmVzi z2Z9dV?D{}3gXPX)xlMf_2$#GZ1HvUwxlj@&bAG}SVhgF4P+pt;e{oN2V75GS)f3qpf|^R3yy5!b8Ri5*6MVzW>l z4)0#5Raa_rx(n@&Qhg@;B*7HU@;ZozdyZgwy9U6_e2Pt;& zklu$o^aLVYIlN|MiJ>EF+~Tgdt*md}#_8qqKpV8y)sLYDyn&y`pZZ$_q^T*rPS58x zW_%C9)TC&Jxb~G(7R4U8x`vlRLNPeXTh`2Ai!Cu95BMdQ4_=?83YbF4D|?f zV-q&E*@sgB+OfS4>}$Ep(2-YI3%V|_Jo9Z?Zf!>J>w3nJxDiYe08qXuJia zWS6JffeoLc6qev`Vu%_CHxYt&kEo~J5cLzo-ez z$oP+VQvri>nS9CNE}+Opky&p7MT2V z1Q!TK3rG&`q1QJS5y%5?aLi>5EIS;ezw_q8X#OLP&3zs`Tma(AS-fwY2O!d@Wms3_ zY*q^kK>Lca6_a;hxf_tGG!evxPs@O7L1Ct!(bnk)*>w;B7eq6oVMv^CI11`Uk_%=- z^hh$->g}CCmJvz+g!TR@09H%p%@OM3)BNU6R~$n1HkcO-;V&aA$=9!K!6VyipYaEa z=CR7ZAebaLQF<1*XhmeYU|vVStQ_1paLhN$W6PClQU~s#IR5Uc(?3MhxF%)1!GU}5 zx|e1Y(;9du9`ZLnV$u$9dVp)6vjw|piO?d^mdFUZbIA-2>bgPn5u5;CSwG*TQBkpz z0z90+4a=8-;^ zw!>b>+bNMU@!D5xwT*vAZQQ5Fc~a?E9>)SVmW-x6py0}Yg2`0WKPDiqRu)ENs)YXj z4N{YN9%N4}sN4ujGPpj_d1{)7Gc5U3;(ko7LM1s_`fb%~>GEfAPeApy>5tC(X+Wx+ zQU8{m6}q^iG}ySo(a&%!QbzqJ4*1^@9KsqNz)Zq*WacfNg_Aw-j?I8UcyDgqEp8LQ{cH_ zilCOM9wRdKnGck}`8b@&p90bi;3m9}&=*+-hdEO%;V4I1qBA@VwW!s5K$LI6e@=Z2 z^>CJsQuGbD5~13^L{OHRFcgcs`HGs;a<{{>(m{>g%H zWeNnR59$tKL}$50YbYmYLSRIm^{8x;xOcH_NOAJ-c-HR1_@bx0N&OmXj`26a0THdV zNO_@1pQoU8%Pie3J*t>~4-*>focagIh0x4B7e!rRD)*?x1OD*FUq;bs{(vJ|%>%#< zsNkn^QyXUF0p5MmbN5Nl-zR;B)O+Bh=LV!#q1=PQG|^HnCJu>nnl{a=awL_Q*AkB; z$z!2|)G;hHb_hNF1@t7!0!8gKB@HNqRCQRm@Yb8<;ku5fW?U8%+Xhtg#{WVDR0CEA zs^K|4L!gTk$H5j6(npyL0sC9dC1=m&(GiYkY)EPC4b4F@|*~JIBmE^=*c=5dcb7H(46cy!>`X2vI50!Xqdkh4h@GKtHV z*c+XW`T(P%Q2mM&3wxUpod4>ph#jOyIYOY;<=RxkiMf%O4WZe@yq;Ls(`!aTzd~M= z{yW45j+LaVe?^<-Gq0hP0LwPdZAx0jg{|1=g0OJU%zQ!WIOYnOH|Cc~kXmXjp$mafl znSV%dUyk)UazD+oG=|=5riu5;_Iqgc;U8P8A805;tsYm6N2gF)C-_;8{htwBAQ(;Y zaYTO+!xleos)CP6PPkF#2SYua=($DIzRuze*_yum+yXDv5P|cGY45eIT1U1~a_}%bMBsmfi250VcioQPLeJ{wS?r$!Odv3u zHmJ{uRL?2>jb@?q4(WR|)Qmz|`-5lVyJ$fd<}*hm+tVXgQ_mi*M9- zKj(_NPc@(F{)`+sm}BN_LA?(xzy(jPBev2V2cbLMCUIe0JfCE*)w|7hyIygVEN+-} zcG{6>EJ+C{0#71t{0~gdQ>dn|X6n28HJqp7mONpN_q>P3!7K{WrYGwo3kZJ#(;O52 z1(k!nI16zPBLNqV<;pUTju65PNI+MvGxKbYN?-&=LjvXAr5MFR#|o)-FWE!U_bRdVZaB#ZT-URWRBlm0Izy&`H&z?ifKY^gC=R1 zf776 zIVrlVh`hwrMRC`N-Qof!P*`|Ll4nINDR=?EPqI>^2CGc){Ha008EwdU)D@`0fKLv0)2N%7B@&t&9;Fz#$x z1pIk#Ic`C6XaagrSa_*~#ve;~Z#$@2#9dllp9&4*d8yA5`L_!W|8*F52UuEug|~J~ z&!^u6-S;7$Yt@hX$Q4;A8vMB{RFbNUEwxO*3ukJT;3&Z{0@{P}1(wF` z^wu?nM|B35`*-kh!)8#A)PH4rmkCa>#=`^~j2ryk4oT#bqByM|)7OTdVz3&|g;ev9(t}5k08qS9>U3p?P?z##iX! 
zRK}_dc2X0p%g|6g(^dpUt)e%$^2kVK z9n|5#7hMi4o`7z;yQzN?zC|65(+eB42h0T9cP&1 z4ss&g`9*FnC9V%aRx+OVp+pM-of3HopF8AbLCZ^#l@3&G+y<8PVk#@Yd6c3{@&?$| z_rrTn!uj|hw?C&xa`jQ9lvG3Hm(icssZm)z+Q69m3&RmFsv8l$2u$Prq}Lwv;8m8j zKc^rUuLr0ulJ78KR^jGPi9|t&DUtXD%L&d+ia$+PR7c^V1=9v+m$SCvM+#aDirgEC z!9yhAXNhg6(T-<1cHS=H-XvT=z_KGE38^`iAPHS)OmrEF>ktBLcW!_o0~vT<@kdGY zP3(md1ee`J{K235cIgF?Qn-X#Ko0(v_;*tMY5Zv>@$C#<1i*>IPG|du`Vzw)F7G!k zpYC?WJy2a3L{E>-*3_@C=3fPmu z;irnav~Z`B%>I6Y&$5i$+i2n2F@2ek#`8#_7akcUxS?+6kdJc`Gk+jn8+G{ypJDw00yo3P3*lT;e!HJP6dkth*GR#i5q#2E7 zv|qP&y`lvayCDg%kP0bMfPqm+6;Vh*2xoByD1y0fe)JDWRVb3>CO`m_#5~XY^_rj{Ijf-KuNqGTBYnGE8IBtL7xdYUei#)xu`6T14FTo$E$*%D>Zh{p@wKS_+Kb z^q%3r%+LAx>*h7HIujInvwOyC4Zq+QuN!_ba9ibTHge_$Ia7WKIi+Au(&tfj+Mhwn zOklL6baf#prmf8SW#p8DBT{w|W#{~Pq|EDft4l~(@Q)znNSbmKDU1FRQkK$`V@NrA zAmx~U9A%EDWsalFZ3kMu-9LdcC(<&vq0AlrGE$b)l-qstDWh`NO|Jh+#Y&0~J$?R( z+M_FvJaP8XCzImY$DTU>59)7-2_!G);X zZHE^&yZ&}N2rtyPqi)b?bp1e`+Paz)JKavZ*$L{Z8C@-~cMHI<6mjb2GC)HT5a7eG zi&w@uk3eL!%$5~nq`v7}zJ0^N{GFAx8-M@Py?_7E`yV)8aguyzd$YEty4zbJ`a`&} zUhi~*c9_@!lHB_0Dn?MT6)`ojI-Ml2wz^>ysAg9!cUPCAt6RbH%Xb!D zzTD|X%U6SFxw*O34mN{M6!?{FlE;2^!l>S91j%h@oAtGBr`~>TeLD=B_0HL57^&vP z?Wo!9ByLz|O}$F)`JfYaRWi5IjUF2|_^47ybl7RduFf79QBn*W^>$EOtv7&7iMcbm zmP>0vR9kJTFsfltYe7w#Xm-}PoFP}=u<~Z!ESj!0Z>rl7Z!mH6Awu#AzYicV$Yj7^ z%r&eUkc2^=8_?hdusvS7j7}HE7<9mJ%Y{m8SegXek|AbE#VOBlO7^*p6 z*g!Q8@|*V;kW=uF_=_l8^q2gj_@4?&!ibiAlC>}*kn`=xnSC=hzgotg-I+{1G^B>t zKfaP=Krw5pNYKQtnZ?XgIdm>g%M_@{LjZ()Go&!V+hT~wxlPq(VCC# z*zvPBP}(OZdn<|OcegJBOLlIO-jn>!fc|chQ6*lZ+wLmR+f9-Z_)wv_*6FICwh>$n zD+@^`+};XQ;?-)+PBW_2l42VSrc-No*Dlvpl9N?t;UxR0W<^OhMb#vGHW22NWPKeF zd{;B<*L|ehxqRAB@|U*jov7A8ESV;?)kYF>!RE!luldbQb&QqlW+zIfRj>wH3RLal z_UbCemQnSN-$}BYU5pTvklhY}>znnYBtu?RU9hQ*t*BB=oG@swVoiMXx7u7&r`c+L zO9hdtgKzoj9wew&Fm~4-9YDiJV>rj4=~CXN|9W% z@Rzx-%awxWp-#J7_YXZc@PhwSxvbN8||wAsP5*(L~}vMEOAMt#%XG#ERsP z>EMzqFnbM5GI5%n)o#U8UxJ|}&Q?8I7wkx^&{vNkt?Gt;trJ{H%#CVhgP{oqtiBkf z&fpEp0ESt{9((we(%;T)<9j=ftz$X`cbo4Mctgr+?1Z6az$(B3`fIO<0~Jes5&n{# ztSiX?!dOp%XsI4V=EOLYa<@|x0;pZAM~!vJ3IFvd7P379D7mvdKG1{3$7C`^KL`5= z-f$TJ!vr_oGk3xFFkm2yMo1!K3?)MbzHh5nqOAISf+GO?85y_ABAl4el{UMSv|*AP z2<6F%xSI3GDv9KRbjxGKcjm_@po@JJ4deNz2gits-3>-hPB0ifI66jeWAqNl$-(F| zkqy$Yu303abRHFFVjdz`E7{>>)N!<`-at?#_*#P72sma9L>VgUaf%FAT82N35~0`# z>D2NEcIRNFvE8A{H((aze|mntxo2U2EZ-u#I5<-<0_5Fq#Fl7C=hdSe>sJF*+~L#` zYdxHt*~FxB>G(Fo3#hEs5Fre16t-&du&t_XhGUc91o#S?E+a75B+vtS->L}#Zc%_^ zGIoD!EOj5ss#5^_jvlqFJBcE(!NDhYz|~ycM}EkKGqCFA1L&j+ogYIZmL8zdaG8nv zv89&4jiGkzIvo!os;LfU0O5`N3{c8m#WHJpb`;l1EdnGCiCQ%fOcxzmPG@U7s?KPA za;UD3Ak{@ypV)-#YF9x_`7Pw7TaY&$lM)9AwJ^Q~Y1y&qaq&4tq46^aPo5b8Wl39i~;1%qkdt z1mYB4m=Q*&o?;gy_{sIjQFWH34UwA?967fKDLsjzgR)~Rn-Yz$tPm-dESm;ae`9UI zLn>gM%eC%yl(<)d=GuC+Z)(h(*s>D4fjU8OHBc1usd+RhN5%^_xK;g!iZ9z0x;cU2 zQ2_?Vk+G3-EHDA;17H9x;-VW{Z!^PkY*0nLVC`miz1YRpLH+fu4M(lU*mSFtkzSu0^qNCoP~+3S@*{m^&_@n^x%V4i4b|Ae@0x8Io$`GK=I4Drc8b3@ z;Ei3BgSHcAM1j_5NX&$>E=F5eg$9aEgF>>c)r)E&&7ruh&Q0WsD(Y%vs!SygRUp+x zgQ91u4K@uvbGg}op(&B5TT^Qnh0ZlT=S9P8=UBSDj}tN?{UJk3kvv)IhLGl?epQvnf&c!(H?fQZ^cv3!;Dl ziMM2xMJ7>sEh5xUSu*b>Utbz$D#OOc7(0y=%08i*2mw+Jc}Z9jYYRx~L*9n>W1At5 zBD?4Kc5DJ29H6AwNTeaL5GT|E+HY+AB;2V>kz9HwcU_iuZQ(0QtDN#Lw2OTkbqmO zH`_dyxFq`1zOk=0DBrJuK2ZWJE9t=?iWuC8C z=gw8;Xux&qn*j{Ce66M(NHpeZG`HG1M{E=IW0_;WPn4ER1<&!^eAe>- zb6(c7foeyk9)~jaux&o!0|2-1RXC}y0<3_(3P0<6@FO_l;xwum@h3o&s}z!%5q$ls zrXO@y)E#K|{6UWbnJV9qdW8=bM5an51a>||I>b5HVC=>DIJcUIT2i1|lA24_3N}c+ z92I5tfkc!~>NNz^8q}8(+(STmQ}+|l5LI7Bz?q62s`dIajLH(r>fVd!yYOm7zUY`c zM<*vWG$=`V{f8ohuZRii;-tb`DOIM(F;z#g2{9_z@)Jnc0|0**6P>cEr?m1k#so+- zHeO)@#}ELqqkC|7D5x+UpoEW^4p>1k9XL|;H3YH-!-XIc{|w%cNzmaeXHwJDW6`Q9 
zboDo(nEFm7iSoi1^+gLdU!*Qg?b8yKpdFHt4&OB&KiBG;z!l^s#Q{3-9VuctW9-kNu`ByOKkiVHZGH; zS$X{67;lY_9HSU^tWGX1W07s>>+-@khy4$CNH&LC_v`FmFdejTNbec}#*mTSy-45c zJCQ$v4AK?%_d_qy)ggG#uK3C{mk=P?!I2H8D;%!oyUbNsG9sX?6?myvR@8T(&|xD; zJn^Ezt@m3fk(w?P)rYQmzwBo`ZwiLrXNTz$A3J1;Q){A`H>F};Dc%4RcAh`Q$GA4438@agL&k)>I%vN*hHwMd#16z z-PwTs8Psuj0o{$h-%O(og*ch=Sr9r5d~eML&VioUO6)N56PtzdJAvwkT6Nt(>$(f= zl2Y$r!uJvkw^w%dUZnjQ-jKdxx}cU|cDr~>t*md}z){^|UmLVm)OVu>T!G(%KlS|tY->WV)3f0Us_O(3Q(8pYU*b*K z-#zBeT`xKfZBnd^(IjW#;uQ?zU^oj^1}@+;b0%hI$=zV-t3^*@I&N+OfR{ zylc6O(2LnZrqn6onkFy4JXNE>wKuUITvK``Gb|pnB zEW_I|OpS$9Na5wXYmjg^LvI7X(h3&Hs*>>mfwngjd#igzulchq`A(M9DsqzPM(cr^ zSmVUnXnzESLi!MiHAVO{%BJ%Iv`W1Ax7<&A4zVdeOw;L~+PULJw_~HNeh7_@5j=Ty z7O!l_lL%m$CbmNi6qqTk42gLchh^9n7iuB4U=ggLml`NoY#j0+HP0o6s|^N=97eh; z>Ic!X`bL7`TFR;wCRfesY^iTR!Nl5Xsvc`_&D0)&;4arO^?{@wC9s6B>Vcx8D_v9? zN7}^hwA=rTI&zLJD+-&#HJ456(>u#Ax?ZR_P%qGch)@!wtQB#XYt?s!n0j;y$rB?` zKGOaZZ(pR``J%njKUOi%RdB}c9V%Z)h3bb1euRLnkMtFh#LkO?Ml;B$3@?_lLA}^IH#|L&Gg_AeInCwBg4O2%Bo3`sP_URNWhVbP!5M<#(%p+Fw>v%ckP|M@ zVQSfZ^7*;p%!j;M6b4qr!>JFr+*w43D(aQ6tjN=>7N$;{AF*E}jgv0@d^Pf&6xnM9viX?Nn-rf$h|It!E%6jhuz)H!S zg;Q~5lKkp+N8CX5HW(HR;P*>v{qh>TvEBA*zrSGIRrPZO<3uJ(PJs(nL?(&M!9idV zNAk~?N0utpoA%v9@p_f*-cnI|4{~3?E3yaPnSx?U1Mk@5q%kpM`*_^PuV-z+Xj&k& zL{h_ffYG_6fRnjS5Iu||zm1i1O&SvwJ1N4mIXH4oro`D9u6I?`fX8zx!U69fcfc>W zxI};f9sC6%BXmHeh({F^bhpK3hgKn8q{5X9=tMw=Yb@o`1=Pw3x=;b53Pw7~)9*lVA*}E`R^4|LW~15)L*qPd1^KDe~g-?IK;&tNQ4Mm@&?e~{n|fgt!m zky{{cKaO->H{oG7j=T{XvIH-~lOO8^rqadUXBA$Cp2cB2?+(K|^CRds@ZrU*)zYK2JDV?Iy;2jOrSzloRXQn(6FBXmQSaSpVE zqZ?_7uJ1I|qVDbiL7v5bPQlR)&0TXVXz)Nj2?Is##eqO9s$au+<>Hq5b%uU};6ntm z)W5~ha1VZ!QEry{ZGgUjC;8z15J*fP$;qw^w}vd^OSv^C5HP@WxFtLRx<@t#C#RM; zcCEa%GdH@}DJN3Dh-xFe>A!@~JX>aG9$L0un7KzW^#IH1@7MoEbmkt4qK>ec z1IqD$V^Ce1&f<*GnS&pHHxI^|Zxvd_827Nm(+~DO zQ9l#sN7_dXw1zZ&<~=OC=g2&!(ivesq{)3K;aJ<$jJXS;HN1E}{3~LZf>eCV3Z1k`dOxB!|rN?BLMZsi0!AxImG6wwO*TO zI596KWYe|~nj_b*A_ zzl%1XWb0(G&?b&4rhGO?TQ@X{a}56}ZhOVK(wOxnHWGowQFr)YPKul>Iz zJ4FwW`g4MWK=6AU$=~(AL2~t%Y~_|xi%9(DQMCSSV07X=z4ckM^7`7VSb4aaJv@=r z8sm>~emrkE3@?9)Q={k*=`f~O?acNEc{t-I5t4H~f;YSuK;IN^L7r-sEk+i#-;)pm zsVOFYYK)PRILY!>Ixg8~E|C^i< z(+zVT^TEZHnZZ0>l8Jp7)qZhgxATfApN%cowv~<6`VYecCUzQ;5zT~F2Q7$~u<6x4 zs;L*<;6Iqy;1Qc!eE_w`moxyTrxiBihe|;VPbe-+k5<^Oa^R&DVk)otYypnWFj}w1A_eyYMG= z|4qubn=*+XMCkmvBzw8uX|~(-ikoC{PjYj+9f>}k%mBIJ3$G1bC(EdzF9+*SGHSRJ zfR7;vYrN7wFm|Uf1JOD)qS9k$kxx3C2 zW~#6`>Ou^+#6QB*1u-*=ac+0occUrRh5a%U!+5QgqCD@`BQz7uqE6v%KNHivuZWVmiDawQxBIwO+k@Brf;L(Of(qFQ6@49mbgi3}{@cSwOAn zICI%hKNZiR%zQjg+sZU9{SL0-i?NJ_zUmDdmx4}4OK3BD%|hPM zc-r@l8N1WyVWx*V(#kjLRxpoa(m%fb==YD%vD$Inql9lew8ZVv362RZ+>x%~e9Q|; z@9=SQKX1Uvkf*5Aafrsk=aM`t zY6-`SEPj%eA~pCM1lRBDC!E$U3Xj$ftr@k&=`U$#;y-Y*u67lvk0PSzbxvlaiQyDO zy>K660a?+HGxRcoWrCXoe*;jR>VL%`m+|4_+g6GTss0VQNw!&Q2bXc}{J&V?lT7mT zu(47(9TI;?Q?Il7ca}RtAnZV_+ioP@9EM%?GNVnlPwvb-ntuFs%dY~+ z7$3nK<}pxTJ1t_I8-HLN zxBTgBqwbHcLJsDJ+L`X-Myd(nROHqYmXhxy7(sfFPQ-IRJ@b%A!~?>XW8-s} zuqAthz=62Uj)_}-K{|T=QBa?Y7Su-wv}{#>jWBNB2=PD1^p61~#kJ<;poR~0_)$XQ z!A-v|%ws_5136kk@uPSLa`ZV&Xhe=qi5xBcPvvNgQ+1rA7XiWdqryOHQugsg6|zaB z;cSfCh5eK{1CpuiJGkprj?1f*&c$e|RccKm1>bYY@?=FG&qSpjKBh$pO3Yc2c$Apl z4b*CtYxB`UJWF|m@R37HETV-aY2j#eEIQ6HT!VuP1pMCJIi0sVpZ+h1%$}1GACa;c z%OXApLh08)_RjYx6=&lkXeAwm$n7J$vzQ6wHl=X9fKoCl=Dp3lbW+b@m(qEub~~j0 zQtd=M2c|%&{r2b%v`DGFR9lYAK8`?lalN;7w09@;n58|C%B4?4FVnNE-IcD>EWEIz z>vS@ev+8e&6xdTk{T%_RRz#?_f+%j)-vd1u6 zMUiTD2fl9HDj%^;C`dJTc%C7ldO9;T>CFA& zHWCkC(aM$L1lNP{!EYoYQnL`@9!H1LQP~{gKr||c(E*)IO2gr#282>AYLTcvPZK$) zvTP(nAWUdDJdbI@lkjEup_1<$dD1hO>;cI~IR-9OrmM*rx);6dtz}vnJd1LX|3j(B 
zPc8Z~l~Qp=I9%UdxifQCe{pxX%}U8fUFh|ALv9K#E0EaW&KK7`E;eWfWeZn`p+-~A zknrf3bOt$YKPSTdb>!w!0_#9tGM@LLN(&J^EHV=(NsyUEEi*-4I#9TAvt80lsl0sS zFh!T-ZN^J)#F-QbM?K!p?M>^ETw=MeZFY%BHyW69Q4tVL~95SWYl*T=;3K;wIn}0*o8n{Ywzh zpE^CrkO)DMdP5=TF++f8S(JY_hVsv|9h+`U$~$7%5lF!!z>PBQ1v}}vg}4Hy``Gc4 zo&ZXsN7)0-QjLEu^s<@Pij2ZV>H^a7H_yKlnlC(ouzDQ;j?lI@x2~#pGwk8}lg9d~ zPDgwi)wzE3)bNN{yuIq1S$ZBSeLFm*zh=d^Ex1h;h6)2QlubZErWwLp7`h!GDGk1T z)un|YonZF25qvw#xZR5_eD^3<6+}lL>hZ#MM5=4~J6!9*f+ybn`gZsO*i|EZ6#!{7 orB7?d#zhF+peg2EbN-un{l9$s`~u+Tlo!eml)n(aGC~*s8^hmF(EtDd literal 0 HcmV?d00001 diff --git a/tools/modules/__pycache__/clip_embedder.cpython-310.pyc b/tools/modules/__pycache__/clip_embedder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b534ca19978798620fe29c1490cce21f77b58a26 GIT binary patch literal 6728 zcmcIoNpl;=6`mb~!NOHsB-v)X#D-O(<0`u%JC2t~lqE-^9gA|65jtZCJtPMNW}u#d zM4|?Vn5y`Y(lOVVtjdKs&)xd zuiy7|hlPTp;CJi7Z&$vWQ}Gp;AyjIHs)n(&Wgsy*~A00oMV&h z3|jJEDuz-nK&=@ub(M#0FZi(O1($DJ|7f9c$79Unk1siM_}z}YXf9s!=6daOF|P#C zO2``?pYwuhh`w_u&c)tFEQbiAUWsCn!wXyW*l$!~@3H-fzjX21#aCYwMzzH%dp9w+ zrD~=-ReI!+Xg$e3(H|gUWlPyscNA4gtVF44_mn9Af%1hK>osHBl+ubbJL-m8?)>*NyW^Frbzid#%{$b4cz zjwB|(mMGh!JuTcZXfno(b&cOY7(wZ-#yl zN3Vo1id(JwO}DujhrGJxHekT|{2f0KdbQcw)r7^odw$guPLq4EwI6tlTd)ULNs!6* zC|8;kX4GnWT-dJb2Y&3j!n*O^rwbq75?UAuqgjd9M27i1s5Cs0soaTV4SKKmbq~Fb z8gvrbzDnX;m3vsMR3*0(Ft@+fkuasYd<35s);&LJRqDc&$_uAa*>LM#u!;{_E8O$i zo-pf`P3V@D57%cKVq!6juQ!@?ui*u;#}+sbc_}Z9$g8jLQF^UWfkEUK_#DdoJ0zZ^ zma%G!a}0Aa9j^*aHpOVOa;T=r!*HMne}+b??yf4bE^Jyz{3ftBA3wW(CWe}J#*_9_LuQb= zN0oH${V3qk&L{n&IfxkAC$z6M(mv6)u_J4S-06CvLw9wtWb$+9DeO30_W~adH=V{J zEMvU3@3(M44eml-TQbsF7!;_4B?m@AZ6BJC$RL>AuWpLRn>ZVTV2diST8Ag8Whe4%2P@6$8X|!Tb%#ei! zZRmDQ!mr>H!mbAhAsD->ON;Wis8-!LWJvxds^3P6iV)-&MHsS(CyY5i#F#@CagpD^ zJE~;X-jB@NH}M(@F&SB{X*J+9{Qc!+t)q90M2`{4d-cefGYTU0S)d9~x!Yp#3vs4q zZ)bOu=tAru_N$4urFYc3nKaK-EC!V@20Mq7tM2AR76=CzM9fSLo$JR-kC(I9K6mP2 zbv+U~f+K$pW6I;SpWJ;!z-sWP&>-}hAB(&nxlxt-&6rln+3&V%rj{lvq>G^XLgj}o zejUy4BR!&UI;GC2nrfJn|17FS4K{yeh|Pzs_5qJ3=b}V2mpeUy0k>cdehD>Fk#3lf z?r=ymVWkWE6BHLx@5bE~;L=j^X?9?gDj@L62e< zF_!2ogsm8wX2RD0B3XQ5KW&?~Ii4MLIQ6R{9r#ft{f0BV3>ecZN-&WuM{Ait`5-{u2q_J)XyD6Q#B#Zn3NF{e>!M zM;%^)Dy>NGMPXSQAqMywSkQ@l?I=exlPrRo)8%W<0AJIO@U>hg-zg+H@U3Fk^xIF@{Hto`Gk1f zRAO-tXfrLD9YKKO%`u9y}c*YC>b3YymXF_Q-$8WR^NQ{9`lb9hfOM-k&!X7_QIf@%8TcT~1 zMj$O{3L^|W@U<2Q=Ov*o31=z%==zQ8jmoMgvaKMx+aiWm&fQ=0LH595R>9!u4Sl{P zV-SiWWrtPb ziYt}ia%0JTCU+Ov!bOUxG)S2Y9XRsD&@J^@M7Lw9(+ej#`{@LVzk*Nl?~!AE1ESC8eeE;?dY1prwOT z$+jsEa^My?(i!`)ALkacW5H|81Wcw&uuJJtjz^?tKlrC)fDWpX&!W={l#fZsLz%YW z^B_*)Z3PG9QfAPN+PF;Z^ICtgrcF($?CH%|Tf{|&$Y0VSjn zY^{BIs7M!y{|tjl+TvnL=tlVGG?NO6J0u8xhaXbqmX}cXbK^R|pL&+Y`58P?7thqg)m6Hbkhg0lt~YRlxsGPKX~XU8d=Ll? 
zX-bxG4Ge#u@85bv!r9OB0K>|YI4abmdG4+HQOq~fdst3RDT-;O$J0t--IV8ePH{xi q5~<$64CP!YYzh0V^vwUZic)u)HuequX)on9OP^G&YeVE*l+z6on literal 0 HcmV?d00001 diff --git a/tools/modules/__pycache__/clip_embedder.cpython-39.pyc b/tools/modules/__pycache__/clip_embedder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73c8c162659dbde0b8c113bcb5288cdc50569115 GIT binary patch literal 6645 zcmdT}&2JmW72nw}E|(N#OR_8{PC~><8@9F5q)FDi?%6HB0tj0R|7Am^e8q+5CDskfY4pywXvRG=sQ3kC}J_hw0oqDeVG)2|ig?R+eE z=e>FJe!n-W&ds?RuGj1@HnvY|+TW=$epzU|jS_tU!8OhrS|z=+3R7*pqN}@6G1T3x zSn6&!GnEWWpLLqqimU2+BiGDV@=WvdyK{Tmdm1;mc~9e}pQ#sij2%{)AB|Yt#)v%{ zIW`)}a0esKXyiC&*R@i1A9K{oB_^GVSAKl)!i9_1YSc4+kv?=q>kz(1I$-_`bJw}C zsaH(C{F%1FxWlveve3lkxqDj0=6ODcdxjVIJnjxZ>1XRMKmI$$7x>~mtCHhO`~*hw zpJ}F6nZs;W7h_Qkqm7nm z`l9HEwHBT&qFIdn?O6RpnvH4{%N$edlFF3x29y^}f z!PmHUUF&Dc(x|mNeO=nzza7+k>9&OrLj)k+0vMV+=Uq~m!(7{s0@ z?aS|cbn)7CnQg%^Uab+d<&!nx!#Z9QMr?T1kb5KJM$(E~qP8XG@H%PV4x&!AAuaW! zbeq*}ui=NA@s_kVgztBKX*H^M@x83Nc>&*+i{(~)so8G$O+So%eo=^)C}pJ?`Hc-x zpx2t!;TOaT+Ts+6CrF&6VdnLj%zBt>HHtj1uRS<-{*85bgoxIID5$sAKMaF+Lbbx{ zajVsc)-k_M!;jWUJ@+i(t=r7T-GOu=QS|OJ(>-!^ zXv!LyczEfkkk6b)iJpVNN-Dp>Dx$ut?`apbpSwNA;VYhTYe}uBUVgW4rdpc*ZCKWA zS4Fky$IwmA3IA3n5PsxsiYo8x;wijVTIvb$Eb6H^Lp3tlw@@NldyVDs*L`Zn{L_`H z(TsXuGaxh@b_!$de7*Yi4h~d4f|2ERy4%8mzj3!lR>{N*5XXv+0E{0 z(OT>xVzNZvF?#G)A+2+U#h@0ztFM>EtN20CO#vnu} zXPB6mw9{pM56?zw=g*E0>66&rgo}bcB?uMQ6vC8PxnCeS2@$3MC$2}ZfnbYz-t%q& z9tI;B&*QC{=ZOV+0|nkRu)j<-a?I4~_L?6?Eg{Ho;-?UDZV>8gexre(KT88j6sdNG z1o>bdB_h0VbkowA{twdvNIxtrITrr13$A*qh6fH^DnO!VHm5nMI8BzjiqqtH9!$mN zbG(3ihR^e3xI27Bahe5innk{(IL!&ZjFAIQb4s2*yx=z!dJdV5prD(gHqEK8@BfWD zJ|r<=`5G<4@FKlw&yfwHKS7K!6mX6ShGIOzP@G=2=OzvqiUVwNVoGGv!b<>yG zP8i+l5R0kg?ra4hM4%L#{v;=%AR@s2wNgHf2#UiLZi8JR+_V}Uit2*+MLK~gltw&* z(F&2bwrJJ-D5|);8dp8_RZ*m$Du+KqHByh9BKihIN?A;nf0#EwTiCN~Wg2yP{wT-$ z8VQUbrX!vwQ6h1$SJl*sl{8VtS;?6rr#6L~kPD13hKvc%QtSl0M61IWhfu5nAt;7g z?EC=6(j{EQ&r$}G;-$3G{+FnKV$v!;p+BPisp(6s``pn}{_l|w@jaUV_el`uOlm?w z_72LU=&yHQIZ~%t96mm^c;LcCYO2k59X0ykG>lwP80i8d>8TpZp^LT~ zT@g{bsd@j()V%+cLP<7;lE?W1Ffq%Y@LfR4BG8gDf2)$`%lsrp=F|y50Ve=Y%HrXL zSpx$;5-}%Daww0dgBntUKqdMW#289`l{}tMGOG|WGnT_U$lQ#j{x%Hb20!{>l=`612;npbJ=s{j{;vDw(;9$?Q*tdaIjWKo-uiiiqxreVu( z$wGu@!*qCRGF*Dx;{0b(o9TpQHTj~ zZr4K^vuml`59tiDK=S~Ckmy)F8?Is{_68E<6pYgBPk4B>7-5h~o5gpik0Nv_lP=}Q z)S>b?eLE1}MN8`2vC6r|_P5&F1W|uPi4-r?i#V@XX5AMaZ-XgNMJPA2!KY{pbx&cH zTK{^U9A|c&M6YMLeoCXmGHe-~SKenEXsa*ej={RHp1y&cx|A8qsRiY)D%;zChzn%U z;6W4IOYM(yH26A+H%O=?f-+?BCJB|+**H-NLGP?+tJ?M@nwvPcs?4iM5B^p6YsA?;NuQ-UT%JQ^eU!p|BJP#+-FR<>p$F9H(V0|RhGw}BF^c~{E2P-o4D?Qbo zPO?hrxsdo6^~u?i-~AaS8sN>@uR6b!9sM7H{a;8+BH|i$!dS}v`eQZzw*&rPM+=~L zA0MU$pg$w|*2_4sFPG-U6}(iGNqj)!DhZVWsbokUpei0q{+fm#l@a}nnxjlxT%+cJ zhZm`-+`Nj~=O_`K!cOGZi;nGPoSb8TJQJ6o!CUuJ*W%%pMzXNHr0q$g})g@vACO_0Sj_xSx7s5 zk#AabqS-i#2N_E7cZ~BL$Q$26kZ0RkO3pV>Fb0o#j literal 0 HcmV?d00001 diff --git a/tools/modules/__pycache__/config.cpython-310.pyc b/tools/modules/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33c5510b640249dbddb300ccbd039d16c56a4d07 GIT binary patch literal 3845 zcmZ8jU5p&X5#E`dot>TC`@?60v5gl4HV3}*-C+zE<30P%h6Dcv`yiQIGMU@zog3f$ zX?pg~cf1fK5)TPUq$r6JSqcy-5mFwaJVfCo;w?e~1bK;Qi4YMbuu3BSLjnfp>)t*4 zk{Qj{)z#hARn=A1ll1l31w7w;@<*-DHx~+jk)4fCA39U``0iq%Kn2pmg41m^M^k-0 zDmsQHd3soiN=`}9#n6n(PFc}fXgL-ulR>2?HK)RCW<6CUlgdx)-_V>IE1d6ps{iSN zGr($WV7TB6`USF{(OAVDxNZ8)2I_W2-QKgOYoWxwv==x^JAh+K?mmU%iZh{bGAPg}jeTFEeKZUl zrwQOBY`hlj9n=rpPd4yQ>~!E64b)w#e>dF&`XE(-_hRgj8oQ5fLjN$_&)WOcGSHJ|-!73!6QtYV5*^HM1^ASTB`D)5U 
z1qOL;81P(FBJYB`l4Qh%oiUz-?L0_gXK*(1e70e&V}~9F-2LO@QD+e1Mc+)&YPC6- zGc&U%h1pDa5HtVMv+w=ko9fBe8~*BJFaLFCT{JNG-6xKIc+bSE4gUu#ujgaWz0&Z%{`$7c z+|OT?gZ3YPmfmN-(ePhb{@qWn--neyoc>n$)~Bz@!5t3W`^M>)Ag!}|!usWl4ZpiN zh|9lz?EPPz`F%q6hy_bFa&>14!=hh~z<2IKm^7ETsCs!GyULXuien0%ce@eY z>%GztWy+Ein~Q657Ntpslg2t0%Lo&jHu7+0GD<{$8%Ke|Z7!ur5a(h@_FE};VGjjr zT}`sMG{oaeqVGsMPgtz-?x7PMq*0;sIH0G5%(v_S`dxda(IZvlbJAWw5UdG%4ZC@9 z%tD^2psvEB>l|T9$2Av)j<4KYRIbZLAwvu*`>*B;&-O|7(6THA?h?x&5lXJHwVe@4 zLdf?mPmWHwcTP->vd9}vvv$mA3|7jc=*Bb=Y(zZ~6&5WZ0pz*bX~y!*Lt2s8ay(1c zSbREV@$`w=Gsom46dO?m14kM}s8MUXdb=fHy3H_1-E=iiGO4>%V{9Ra;SjZ9DfCuZ zmK}xO{tdr)`+>2`?=*tv_f!i0==Ao#zj?rJjOf`-a`g5mw!HT8$L}=!gWn#1^E<1@ z4*MD|oe$b^(00cs58O3AK0cP_ zi))8^Lp=B7xzKuHabcCu;LlamivN!w_rk>3d7i{0rEXH#C>-tz3;e}g*nWoaRq-ig z&bLvEuZERV<0OT0RE>Pm+4e}l+g^B1I{F-vMwPyEiK?E?=I-GAEN0RqQd}+y9fIj= z79u2V=X_uA+g<2BF@uttWSF+MP63xm7{u+Q&G**l{BU(|eIaaT8QL&esb7trz4a)O zu^lpwnQ%fLG?f5^Lp!@FEM`rJu7^Ri(34nic})g7Kn6r<$Z}TqvWSnI5+!x)q8ui^ zAH=>WCwxp=D-4Cjgr}LjYw$~l3ypB0Wiok$o^VlYwtP{P`DgkN883_eGz#P`5yGj6 zYGE|n3pcf@N3UGme?M z@=6m`zl{>mi#$u}tn4mJO6pYQOUh2kb?jb84%mwb9c<-C@N4&u>iI(6e9nNXdW&90 zvNTXCEL9o#6)K%Lkc~^d40$&=6{&|8BRE9OLy2^IHBwlry2!$TI^(&-n%-)+@OSTM zJQdd`sYILexf;9=Ra7do{S z1Y|pf@mTN3HQuT!Mxrb&fDWP(CP>y+mPC$qrdRNs+H$bALCUg9QCAg1XRy}`;ugcL z-DJ+DE2R3X>E?9}v42H`$^=xFOcJw;64SP_|pgH=QtvB zKH$OmWPUn{TB~!j?n)3-vVtH-j>}55J}}>GBkv^Ldhld>&N} zKP`N|Tj2i{@a-35LSXca82rd~_4^}(X8p)J{(sb0I-Z58UaZfUx%YX)BH zl5WaYQdZf*DoZmj=~Z3Lz?{$Y66oG*eP9kihHh#$c95T>@otc>>xQOZHgx@Sqo{vo z8u}%pq^S={1q_$BN{q1A{C#eS9L3JAO@}PFoDDWm41Kv#Iz+E%}yoI&_ zck4BKAZfdnxtAtACHq#|;Tf!YbvwSDv~#@(-afjCZdRlJ!)>tJsPb%%M$k^8okVL5 z+g+;VYVCfyg?3Yw_RwCsb-h6Qs6n^Ue!89RpgY%<_Mm`w(0<^ZbQkb0IsiNX-R{=9 z-3_X#sV1m0QdG9{u9p^CJ2YV)Igbw4qf4Ew$oX7V7RzJ92!}vpV zM0@aldYF#VBXn%NfYtNV0zNb>1<<(wIgDQn&mj58!*R_)|byed1lGEc-lVM&p}?Z@SzfG7{6^L6*L<7 zM+L*^o9OnDT?T&*Gy4XB8%4kcP#*Yr?pWZ@+~^%%Xcqeou#?&*ltJ~e3hAR2JK+kp zPzgiIl}cnu95hkP~Yx ziJX)r0S{f@5n&>8Jy{3@leG@N&d3pg>-(p??|oN2ueSPg*s@toM%2!xXm=T9K9i#x zloPv&8^9XohR(o7KN4a)Lzi z^s8^Tysx~syL$PD@2ElR*FT6KvEFZaulIiO{aX%U<+m5V=70FGn~*c`+cKYjXdKe_PBmR-Jlj-?lu_yVG!SxnSvCshUY4d^d`f;1qWPX=K& z?7iJ;P7Hil=Ai>aWkEq{RnxV06sGKYsyme^8xW@QN09l56?L3l<|0;1)#!PniV3$m z%YM{e5whl{Y3OJ#3t6EgiXljDOIZ*{3CqoZsY zVX!8xHSFewA@fC|!@35?Y;a^*IicN^d3=?2SGysZf($vX&AgE@A~^)MJ<(Iez*%7l zBtpqGwze-oQSil~-rUrzbN}qz6bsy`ILSgr)36d*IgDx|*ocZEt1MVX3aFE@;)JD% zi^L+i^HdV8vG9D%!o@SEFPv19P;O)m3>*&+P^H%P4|j{eblN_Top?2k5~aIRV|tl~ zaEMN$lJBmvBsmVfy<30u(PPtJeB9!%9jF@K)Z(5$eK2pen)&2*HG1!JJKuflUmv%; zBVV8S;2Wzak9i&a{IS(ne+~AY4J{oJvo`PWw#p}{<8da2IsAHSZP!q@t2_&N=FH5^ zKQJ>hGaaYhwWGr!k-F+!Xrr_^u*zfb=PIhj|3{E>eRle)h{9$`8)RE39nLkzy>2Qk zFF}CnFqA6mGgRXnVP)3^DWL*&Bkc}W?|zDl%=IrRZ(l+d=n8l#(nZwXF&KQ5g-qE+ znk%HyC)mGcB23~e6^HoG_oMsV5-MwyVA{b3aiOw`hgp<~gN@6czk0B->}N@W){m|= zZiLXmMi8k0_nE*%@Az4fjSF=8k)D2Dh>V>DlfgaHX|%g(}dpAW}-k)IK#;Jur@h z5poP+)L|)N74|@Josf2|$5Usc7>%mV4VT$9VhA4ux*NoHHGs3VW_ESJ$3D0%+@3BW zh!5Nm%<)->%T8bgA`PWQ?SO0>+>v&T2i`_UDT@4Bc9l-bVj-)Ww_bvyHl!{_#Cwjq zZtQaGsGCjJybN_<7=EVGS=(Jv6m+s{Vny4jxVAOyr~#{s(7{$-0KX2ds-Z8`1!#}y z(${qpQuTpi+tlSzY*O*;d@`f-Dynb6dTte)^=A1 z^?h(v=lw+12dOk1EG;wqM96}=$gOra3s(j)U}3HL>v5KaEV+K%b?pi5anv>k!NZcC z>)Z8f2*@mk@mO!GHSVgeNV1|VfDW?iN66St5(Tz-VOaI-dXKMdkgbASj}Lo1 z>@Zy3ZDw!ZB=ygYPHwx7*xwYPGXdo#kww8!Q^)Yff;>Zj#Mn4aUon#FaG!>M&n_Oh za#$k#aBYcRK{*OZ4VRjW9f`?0yB4Dp*%!QY(s&*BL9pO}++^mDViMe;yC c_6851wh7t~-%rhYIbZyuT*OzbQD5_SkK2 zb$e%b)K(rQAx|p^5C}>Hb_68LFW_e&@qpB?Jnb8AAYPX5RM~S)0&R7js#BLfb?SVl z&T+p~Dj0a)|J%n~ADlCce^BG-W25mB-uR~|gdteh@QCuR(PJKCn#X&L8_ejLo;Bsz zo{dpv*XiZF9C((nyKXP<1RL$*`sMGguU%fH2F}I-)S<|WTMGO 
zn*2eNKWXx@CZA|>Uz0yK8Et6#kd(BM3wYx?io_VR)R?dlOPOFR1~kAo8eGz(2UIk8 zmhTV3$ZvMrgUn6(J7Lt`3&D2cM7E<=<|tl#-;hO&|Bt-z!sg95l<{UeZg={d7y7-e zyEk9)ciWNZ@5cV1FO%SnZaC}MY!2h3-}9rs2;&*kzZtbxqjoPy!p)@L@5YZ5A+x+r1A>>SB-C@qOZD%27#VMI(H zuQN%<(8PP}A))AryC_1#W5V!wp5e#yrg|-L<{e{;30pXC&E3&v z?&U;9RMGE>ny7>3#k^PmEr><21UmN)^xze7{L5K&eWv5fTE!a|C5+s7K&ooyNv!EZ z+NJg!-uO9~OJeN9peBH8+ei(cC+66iP|ZH5IjMtoPQdJZH?>jblR`3=vQDv68kZ-= zJ|7iQOFotIaV4ov;0BUf%JZRQ-%zn|f2DR=OIpKF9GaOY6&EH7-iO(4cLa=tfFCgfQHO$4u+W%-bunpWObU^ z?zN`!ciVB2<-=&$3nlQcK%1G}Fe0GI=605g`bsxh{$PKyLi=@OI0%Czvt$s7Mh%J! zyIZn~E^mI&4VvM0zbo)9t+CjvO}9?!)!X!{S{67KElun#LXc0{;}>Yvr6laZpFw%h zPH&L-kg}C*dwIQI(e7o=&Tb%ESmv}lNOq#kY>PWkz$xJ*LCdSCGo}O3(s=WSj00tQ zMg28(rty8aDIcCDh-2SBbVqU->){SI;`1m9MTZygbGXACUgs7suqt0JXAe0w7q4=r@539k;7mKlN4Q!#>|J)v*fdAx$Ql9j z9X>WE1`z!hAlez_z_C(i#TeVE4ZOH$JA&C1;L&-xZ#6>3y*TvIAe?okx6Tf3P~|_I|waS;1A$o=1#8*Ns90(#c^qZarky{ ztCr+C<+P-|0Q$?LN?J@SD+W9Qd{bK8$2Ean+!)n>>l*m=w5;enX4W8geFa`8oqwCY zn;X~DIrtW}N8Ojs3w8!0sJ~?jF6?_8<1K2B7StO2*cvW+=hSM7|GD!Q(?#g14xh6? zUxU69XpF<>I2+2{D91xSr+}-s2`J)Ph5cB;op#)qhpX8R24Uui@OHZy$`u@KX18JT zv3!K;9B_OL75P=wv%4M2P!cZWqr|9Py>{j0YwOpqzw+9;fBnMhrS;70wWG|9!fx0L zBj7YfERfkiEAH_pXpY$o68R)*jV1XsG0Mue0ZLbOGJLpqNHFP8}*-U`xzmdry3pqK;2K<_gIT9;oda8cJ>Cie_qi6l&A1&`d3* zHWO-drfRdHHnf%6a!^zbYIBszax<05SFxkVKpUp%4*J{qSFjSK(G?QKIAEza7yN;=|)yL5yTv2sC%tpX)j~N!ENn|lzAXi zQEnpN+`;WT2$Jm+CoO+~ufI#1qEytyW~_&txexaiPi=`-KDG-wM8bmFg@;OlMSKS@ZGz$efeCYpi2M{0Q%vB%=2PofUoO=*o{hdQDJ|F=@^O3tzPQ*627b2> zN7(L1@(P&pD)rb?1$6s?(D+bDA&|@rfDA3ONEL_kB<)*86Pc|vc~U_)(ElFZnBt@| z`~$oNE2`J}G&RxpkL8pm<{hYsPJ`5xsH3LfpYkZBFsLqN3kZI3o$`B(6o})urs40` zG4a7eJPk*~{X1Y}76pU)wsCF;-C-DKW<2aQOnvNypo`FpV%gZMD~dEy_IzY%fegio zt!9rLkHKhYJRAFad@goSEMw~y-+TI$eSFU3WBQ=1nZj>nHl%$Ry-JF%8w#_qS#^CD zl)x07#|@?tOupxcms$vCiJ0D`8v93#QK@5r?S18xke@c=m~^zo18BaEhCuG0kc^!CAa8WZoze`rW)UCLmw=hiF-T1r=TZGKJI@6G3QQpeh`o-0o zmu~o1FRfl!|K62LYwOqi8&_Uk-;gh$Uw)YiIs%;=mh~D)y;Qc6IhWO6JWAFh$EtR& k{!wvX)Y|hdlDGc^y&MYbNYOz z&rwvXRT!ST_ujwp!NZLGlNtvf7mc%cvmc`nOz<8Hh*~|?=K<%Y&+2oFan`p3XVT*a zE@s(1uU`sEy3gtP{c=!7-R!NqI8C-tZlx`Ui;gcHJTb#^JY=q=A&9+%TJ1_e6|4U zHkkYb{dO_6#4uLR=f?9p?5B&I6_uinp6MUJ3pn^~+nlA1r&#uvC&4HEUd5Zwb?xp# zF+Dy($30#wTw(24!WPawYt;ChQI|#w*7b_&8J3=;S>pwiOOVUIL94xc7J$(xYM5ow ztfGOMOSPM?3xY~%RU4Gn@O79~O~yAj6B%b=M@HfTK;mQoi=2awTPKy9Z4O~&ei$Zc zl82$1Go}qonwYyt+PNfaQnf3QjQVjN%VqQ(oC#kP^+TGenzRr6(!OchH8cTmUuhsc8emJ^NR7d$_5* z_)Z?DLN#e+((g?8?jRfrRp)o&GQn!DT0356VVfk@S_faXXS7Y;)WJRJ0N)i`w`FCV<=d)VtG~tKaYE;H7y4Yv0?dE zgIBEzKV^Lx?@83QmX11u)4P_}reofH$c6cMwd?#fJV6Jpv&%lhmC)ty@~dpk9@!&j z1iW{xvAx5V**pBTA@J;tO6YS6Z;6fF!UbAKh_{7C^!cPoNF!z>npq8I+;Bs;-Il@z|F>y0XTweY)$2bSOs;9sVW?+Z?<@F>-B zZ3px4?Z4i9BrkRAMa}pCjMqogMYWh-V(OhscG*8;rK$DI;D5h{a?6# zzLABmGIj6a~D^X z-A_{Gr*SXt$0<;n4P>rdU={cFqqN3uN4adF)>@E{Q%7A}HsA^ShqF*MWAwXC{)XbI z{3d3~(=>hNAdd7>HGi;c4>D@{U^gHbOp_i1X6+#bsK`o7MPuw5+OjC0kqIooyD1jX~6>$lBWKW`C z?xHmZ_*VEFVC?ZaZvfN{YtgE5%W86hwU2U{pS64dd~5L-s3%*tb#O#m;-QS1(}$9h6*w-`E3-fvNj6&J?c1%BB<_}18$FfJCMY%T&7}$3W`ql zJ)xSP$Qs^G@^x)jbot!Dq>Mr}=6Sry8&Kr}-#YbqiaNabkSioNxTm(OXlS+B2h?^3 zYBR3TE*z~k8*1|=YIC7Bw6)qwP*e$O^R&uJyDF0}Vq5ur6fOUVQvPo>%p6n0TiDO2 zz};0rlW#rpzg2K7h^Bl@TMZqfNe>YHz)A%F&?@+jMY+}_O;FQE0+@px`c7*(#>Y5W zbjvs*Z63%lw38^eH*oU~qkR3yNz2Rl`nx2Q&K>~)=TnDUpKi?`6Nz>{Bn2HJVL?j~ zppwJtDZLRE;z6rGYzH50gW>>z344Nw<^&NFOyIyy70#itQeo~p7h}(9Ex4=ltM~#u zakU=}!`=XHFd3wh@GVzq#GNRhH;9D6heir%ch>-}pdL7+iv4wx5Ue8!bawe@VD6Zgh2zzmrTK71q z4~Oqe0^g^x@WH2eTAl&>v*=I`g@NX#@oq)E%{Wtbw%Kpl=Flrq4?!12v02d6G)xDw zA0i`*WGs#dsvbHVebLlxckb`-xs37~#uwo@x1Km=9bd5ckUl8arNCPo4cQz_UZ+FX 
zm_{rtmX1tSG{Y2^XDx26=a+(iT`RCc`v-1$QkZtThP)vfq-9K1%d(;H-&s#CzdCk>8$BYrc${|9-}r$|6~ z?Lj2*=fDw_E$%JBMT)zZ<^W4s_Ax$U3|vRK6<*;B*2C6T>kCwPcn9QYlN8t_!g3Z{ zwrZ4wrO2<4gjK@O{3U^UH4I^dJu{-?ZutU9{}L4!sUW{3348KYDt<)88z@w561lfy zSob7)*J6HyrqJyqn<&VLvag`<-AM*(^4NdZpYk8~Ykt*t{EF}Sww~Q;8d#BK2D3Va zI-RDcZ)~P*ax+m+RSwko1xc5eD(gP_vqh(KYHl-6df90icLKRTRU2Ob literal 0 HcmV?d00001 diff --git a/tools/modules/autoencoder.py b/tools/modules/autoencoder.py new file mode 100644 index 0000000..756d188 --- /dev/null +++ b/tools/modules/autoencoder.py @@ -0,0 +1,698 @@ +import os +import torch +import logging +import collections +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from ...utils.registry_class import AUTO_ENCODER,DISTRIBUTION + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +@torch.no_grad() +def get_first_stage_encoding(encoder_posterior, scale_factor=0.18215): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return scale_factor * z + + +@AUTO_ENCODER.register_class() +class AutoencoderKL(nn.Module): + def __init__(self, + ddconfig, + embed_dim, + pretrained=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False, + use_vid_decoder=False, + **kwargs): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = ema_decay is not None + + if pretrained is not None: + # modules + current_directory = os.path.dirname(os.path.abspath(__file__)) + # tools + parent_directory = os.path.dirname(current_directory) + # uniAnimate + root_directory = os.path.dirname(parent_directory) + pretrained = os.path.join(root_directory, 'checkpoints/v2-1_512-ema-pruned.ckpt') + self.init_from_ckpt(pretrained, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + sd_new = collections.OrderedDict() + for k in keys: + if k.find('first_stage_model') >= 0: + k_new = k.split('first_stage_model.')[-1] + sd_new[k_new] = sd[k] + self.load_state_dict(sd_new, strict=True) + logging.info(f"Restored from {path}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def encode_firsr_stage(self, x, scale_factor=1.0): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + z = get_first_stage_encoding(posterior, scale_factor) + return z + + def encode_ms(self, x): + hs = 
self.encoder(x, True) + h = hs[-1] + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + hs[-1] = h + return hs + + def decode(self, z, **kwargs): + z = self.post_quant_conv(z) + dec = self.decoder(z, **kwargs) + return dec + + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) + log["reconstructions_ema"] = xrec_ema + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
+ return x + + +@AUTO_ENCODER.register_class() +class AutoencoderVideo(AutoencoderKL): + def __init__(self, + ddconfig, + embed_dim, + pretrained=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + use_vid_decoder=True, + learn_logvar=False, + **kwargs): + use_vid_decoder = True + super().__init__(ddconfig, embed_dim, pretrained, ignore_keys, image_key, colorize_nlabels, monitor, ema_decay, learn_logvar, use_vid_decoder, **kwargs) + + def decode(self, z, **kwargs): + # z = self.post_quant_conv(z) + dec = self.decoder(z, **kwargs) + return dec + + def encode(self, x): + h = self.encoder(x) + # moments = self.quant_conv(h) + moments = h + posterior = DiagonalGaussianDistribution(moments) + return posterior + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x + + + +@DISTRIBUTION.register_class() +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +# -------------------------------modules-------------------------------- + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, 
+ kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = 
with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, return_feat=False): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if return_feat: + hs[-1] = h + return hs + else: + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + 
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels, curr_res, curr_res) + # logging.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z, **kwargs): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + + + diff --git a/tools/modules/clip_embedder.py b/tools/modules/clip_embedder.py new file mode 100644 index 0000000..cc711ce --- /dev/null +++ b/tools/modules/clip_embedder.py @@ -0,0 +1,241 @@ +import os +import torch +import logging +import open_clip +import numpy as np +import torch.nn as nn +import torchvision.transforms as T + +from ...utils.registry_class import EMBEDDER + + +@EMBEDDER.register_class() +class FrozenOpenCLIPEmbedder(nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + #"pooled", + "last", + "penultimate" + ] + def __init__(self, pretrained, arch="ViT-H-14", device="cuda", max_length=77, + freeze=True, layer="last"): 
+ super().__init__() + assert layer in self.LAYERS + + # modules + current_directory = os.path.dirname(os.path.abspath(__file__)) + # tools + parent_directory = os.path.dirname(current_directory) + # uniAnimate + root_directory = os.path.dirname(parent_directory) + pretrained = os.path.join(root_directory, 'checkpoints/open_clip_pytorch_model.bin') + + + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +@EMBEDDER.register_class() +class FrozenOpenCLIPVisualEmbedder(nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + #"pooled", + "last", + "penultimate" + ] + def __init__(self, pretrained, vit_resolution=(224, 224), arch="ViT-H-14", device="cuda", max_length=77, + freeze=True, layer="last"): + super().__init__() + assert layer in self.LAYERS + + # modules + current_directory = os.path.dirname(os.path.abspath(__file__)) + # tools + parent_directory = os.path.dirname(current_directory) + # uniAnimate + root_directory = os.path.dirname(parent_directory) + pretrained = os.path.join(root_directory, 'checkpoints/open_clip_pytorch_model.bin') + + + model, _, preprocess = open_clip.create_model_and_transforms( + arch, device=torch.device('cpu'), pretrained=pretrained) + + del model.transformer + self.model = model + data_white = np.ones((vit_resolution[0], vit_resolution[1], 3), dtype=np.uint8)*255 + self.white_image = preprocess(T.ToPILImage()(data_white)).unsqueeze(0) + + self.device = device + self.max_length = max_length # 77 + if freeze: + self.freeze() + self.layer = layer # 'penultimate' + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, image): + # tokens = open_clip.tokenize(text) + z = self.model.encode_image(image.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, 
attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + + +@EMBEDDER.register_class() +class FrozenOpenCLIPTextVisualEmbedder(nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + #"pooled", + "last", + "penultimate" + ] + def __init__(self, pretrained, arch="ViT-H-14", device="cuda", max_length=77, + freeze=True, layer="last", **kwargs): + super().__init__() + assert layer in self.LAYERS + + # modules + current_directory = os.path.dirname(os.path.abspath(__file__)) + # tools + parent_directory = os.path.dirname(current_directory) + # uniAnimate + root_directory = os.path.dirname(parent_directory) + pretrained = os.path.join(root_directory, 'checkpoints/open_clip_pytorch_model.bin') + + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained) + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + + def forward(self, image=None, text=None): + + xi = self.model.encode_image(image.to(self.device)) if image is not None else None + tokens = open_clip.tokenize(text) + xt, x = self.encode_with_transformer(tokens.to(self.device)) + return xi, xt, x + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + xt = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection + return xt, x + + + def encode_image(self, image): + return self.model.visual(image) + + def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + + return self(text) \ No newline at end of file diff --git a/tools/modules/config.py b/tools/modules/config.py new file mode 100644 index 0000000..9a8cc40 --- /dev/null +++ b/tools/modules/config.py @@ -0,0 +1,206 @@ +import torch +import logging +import os.path as osp +from datetime import datetime +from easydict import EasyDict +import os + +cfg = EasyDict(__name__='Config: VideoLDM Decoder') + +# -------------------------------distributed training-------------------------- +pmi_world_size = int(os.getenv('WORLD_SIZE', 1)) +gpus_per_machine = torch.cuda.device_count() +world_size = pmi_world_size * gpus_per_machine +# 
----------------------------------------------------------------------------- + + +# ---------------------------Dataset Parameter--------------------------------- +cfg.mean = [0.5, 0.5, 0.5] +cfg.std = [0.5, 0.5, 0.5] +cfg.max_words = 1000 +cfg.num_workers = 8 +cfg.prefetch_factor = 2 + +# PlaceHolder +cfg.resolution = [448, 256] +cfg.vit_out_dim = 1024 +cfg.vit_resolution = 336 +cfg.depth_clamp = 10.0 +cfg.misc_size = 384 +cfg.depth_std = 20.0 + +cfg.save_fps = 8 + +cfg.frame_lens = [32, 32, 32, 1] +cfg.sample_fps = [4, ] +cfg.vid_dataset = { + 'type': 'VideoBaseDataset', + 'data_list': [], + 'max_words': cfg.max_words, + 'resolution': cfg.resolution} +cfg.img_dataset = { + 'type': 'ImageBaseDataset', + 'data_list': ['laion_400m',], + 'max_words': cfg.max_words, + 'resolution': cfg.resolution} + +cfg.batch_sizes = { + str(1):256, + str(4):4, + str(8):4, + str(16):4} +# ----------------------------------------------------------------------------- + + +# ---------------------------Mode Parameters----------------------------------- +# Diffusion +cfg.Diffusion = { + 'type': 'DiffusionDDIM', + 'schedule': 'cosine', # cosine + 'schedule_param': { + 'num_timesteps': 1000, + 'cosine_s': 0.008, + 'zero_terminal_snr': True, + }, + 'mean_type': 'v', # [v, eps] + 'loss_type': 'mse', + 'var_type': 'fixed_small', + 'rescale_timesteps': False, + 'noise_strength': 0.1, + 'ddim_timesteps': 50 +} +cfg.ddim_timesteps = 50 # official: 250 +cfg.use_div_loss = False +# classifier-free guidance +cfg.p_zero = 0.9 +cfg.guide_scale = 3.0 + +# clip vision encoder +cfg.vit_mean = [0.48145466, 0.4578275, 0.40821073] +cfg.vit_std = [0.26862954, 0.26130258, 0.27577711] + +# sketch +cfg.sketch_mean = [0.485, 0.456, 0.406] +cfg.sketch_std = [0.229, 0.224, 0.225] +# cfg.misc_size = 256 +cfg.depth_std = 20.0 +cfg.depth_clamp = 10.0 +cfg.hist_sigma = 10.0 + +# Model +cfg.scale_factor = 0.18215 +cfg.use_checkpoint = True +cfg.use_sharded_ddp = False +cfg.use_fsdp = False +cfg.use_fp16 = True +cfg.temporal_attention = True + +cfg.UNet = { + 'type': 'UNetSD', + 'in_dim': 4, + 'dim': 320, + 'y_dim': cfg.vit_out_dim, + 'context_dim': 1024, + 'out_dim': 8, + 'dim_mult': [1, 2, 4, 4], + 'num_heads': 8, + 'head_dim': 64, + 'num_res_blocks': 2, + 'attn_scales': [1 / 1, 1 / 2, 1 / 4], + 'dropout': 0.1, + 'temporal_attention': cfg.temporal_attention, + 'temporal_attn_times': 1, + 'use_checkpoint': cfg.use_checkpoint, + 'use_fps_condition': False, + 'use_sim_mask': False +} + +# auotoencoder from stabel diffusion +cfg.guidances = [] +cfg.auto_encoder = { + 'type': 'AutoencoderKL', + 'ddconfig': { + 'double_z': True, + 'z_channels': 4, + 'resolution': 256, + 'in_channels': 3, + 'out_ch': 3, + 'ch': 128, + 'ch_mult': [1, 2, 4, 4], + 'num_res_blocks': 2, + 'attn_resolutions': [], + 'dropout': 0.0, + 'video_kernel_size': [3, 1, 1] + }, + 'embed_dim': 4, + 'pretrained': 'models/v2-1_512-ema-pruned.ckpt' +} +# clip embedder +cfg.embedder = { + 'type': 'FrozenOpenCLIPEmbedder', + 'layer': 'penultimate', + 'pretrained': 'models/open_clip_pytorch_model.bin' +} +# ----------------------------------------------------------------------------- + +# ---------------------------Training Settings--------------------------------- +# training and optimizer +cfg.ema_decay = 0.9999 +cfg.num_steps = 600000 +cfg.lr = 5e-5 +cfg.weight_decay = 0.0 +cfg.betas = (0.9, 0.999) +cfg.eps = 1.0e-8 +cfg.chunk_size = 16 +cfg.decoder_bs = 8 +cfg.alpha = 0.7 +cfg.save_ckp_interval = 1000 + +# scheduler +cfg.warmup_steps = 10 +cfg.decay_mode = 'cosine' + +# acceleration 
+cfg.use_ema = True +if world_size<2: + cfg.use_ema = False +cfg.load_from = None +# ----------------------------------------------------------------------------- + + +# ----------------------------Pretrain Settings--------------------------------- +cfg.Pretrain = { + 'type': 'pretrain_specific_strategies', + 'fix_weight': False, + 'grad_scale': 0.2, + 'resume_checkpoint': 'models/jiuniu_0267000.pth', + 'sd_keys_path': 'models/stable_diffusion_image_key_temporal_attention_x1.json', +} +# ----------------------------------------------------------------------------- + + +# -----------------------------Visual------------------------------------------- +# Visual videos +cfg.viz_interval = 1000 +cfg.visual_train = { + 'type': 'VisualTrainTextImageToVideo', +} +cfg.visual_inference = { + 'type': 'VisualGeneratedVideos', +} +cfg.inference_list_path = '' + +# logging +cfg.log_interval = 100 + +### Default log_dir +cfg.log_dir = 'outputs/' +# ----------------------------------------------------------------------------- + + +# ---------------------------Others-------------------------------------------- +# seed +cfg.seed = 8888 +cfg.negative_prompt = 'Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms' +# ----------------------------------------------------------------------------- + diff --git a/tools/modules/diffusions/__init__.py b/tools/modules/diffusions/__init__.py new file mode 100644 index 0000000..c025248 --- /dev/null +++ b/tools/modules/diffusions/__init__.py @@ -0,0 +1 @@ +from .diffusion_ddim import * diff --git a/tools/modules/diffusions/__pycache__/__init__.cpython-310.pyc b/tools/modules/diffusions/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49d69df26e7312c8db0882fdae951239163708e2 GIT binary patch literal 240 zcmYjMy$ZrW48GGvM12M4I`{x0qI7a`(Jpcj)U*)kT}pFB++Cf024Bh5$yadlswjSt zgzpEEaJSo50jKjN>xX%-$v+x_OhZX0G*qZ_wNltvs;K$UL5lQE4;GPL_mHxz5@H94 zWUs1h2K(KfJ|3Aw(ozfw?b*a^vtu%c13Bb_K*6QoZ1ePMm|Z2FEjY$hz9J-!W{&K} jr^YkqVk~2ohrn;M7yxqQ0fxo#)*62Tllo4nAL*zUeIZ3S literal 0 HcmV?d00001 diff --git a/tools/modules/diffusions/__pycache__/__init__.cpython-39.pyc b/tools/modules/diffusions/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1daaa1f05f3cc9e86879defe5259534f6c18524b GIT binary patch literal 184 zcmYe~<>g`kf|lOtY4Sk&F^Gc(44TX@8G%BYjJFuI z{4^P(_);>{(n^an^Yh|UQZjQ_G88cbrNP86J7=qy(Bjmh;+V|h%&h#F(7a5?yv*Fh wlGK=z{QR8anB4r7(wx-d7`Pe5G4b)4d6^~g@p=W7w>WGd3hY2OegHq)$ literal 0 HcmV?d00001 diff --git a/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-310.pyc b/tools/modules/diffusions/__pycache__/diffusion_ddim.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1604c918835946075a50b53c4a0f03930ef4fd51 GIT binary patch literal 24748 zcmd6P3!Ge6dEeZ}%-nhG?Ck99d$k&^o@-08WWX3BY-7s?1C}3xKo&WZ)t=ea?9R@t z&YhLyxk~8fr#inC;52G2P#cPpTfo_tfLF z&G?k+MSO2PexMnjR(*)?tH%$jjOs@_L$PekkY^xf#s*bR4HbsfaG|F#JQ7kP2#?mo zV+fDq*-;poOQ{L9^HO+vR2e5jQ@gHmmQGx?0jCT~jg4o&FidK@NlDlv?p2JUb zdNvKs-U*lr`w>+*SHu*U)rxPsm;m1{tKyPxpDPy5PuYGXXw5gOzHz}1yJu#W3!5Q5 zh$^02z5lLLZ*dFSJympz3zbv%SC;0Ue9J@m^To2NoOkof6IlRISjH=ZVToWI(K;aN#D#D6X=8M03@xX zF=}KDJ&c!GJRLt1Z};N4mIWw;uyaDwn0I9Y8!_jcS-*;1QJ*q+{ zL!0)AU^xuG?{#(TYjz0O8WH?;r+b%mK^13-bNT9%%Y`Y&?`=*}fVX-a{nHTu-)Y1D z`cB9%&*+&YKeki=db0+D%f)IwKY%wk1(3wz*oJdGYuZNg`s#3VaKUUJZWK7NY5Pt-U!GYi z)~ZMhDsl(3(!YpVhQ#kPz2M9|YLJz^#q9EtABZbuS+G{U95yeNzmP^N{1H 
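A minimal usage sketch (illustrative, not from this patch): the AutoencoderKL added above in tools/modules/autoencoder.py is the Stable Diffusion first-stage VAE, and the snippet below shows the expected encode/decode round trip with the 0.18215 scale applied by get_first_stage_encoding. The ddconfig values are copied from cfg.auto_encoder in tools/modules/config.py; the import path, the pretrained=None shortcut that skips checkpoint loading, and the division by 0.18215 before decode are illustrative assumptions.

import torch
from tools.modules.autoencoder import AutoencoderKL, get_first_stage_encoding  # assumed import path

# ddconfig copied from cfg.auto_encoder['ddconfig'] in tools/modules/config.py
ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
                ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2, attn_resolutions=[], dropout=0.0)

# pretrained=None skips init_from_ckpt, so no v2-1_512-ema-pruned.ckpt is needed for this sketch
ae = AutoencoderKL(ddconfig, embed_dim=4, pretrained=None).eval()

x = torch.randn(1, 3, 256, 256)                 # image batch, values roughly in [-1, 1]
with torch.no_grad():
    posterior = ae.encode(x)                    # DiagonalGaussianDistribution over latents
    z = get_first_stage_encoding(posterior)     # sampled latent, scaled by 0.18215
    recon = ae.decode(z / 0.18215)              # assumption: undo the scale before decoding
print(z.shape, recon.shape)                     # (1, 4, 32, 32) and (1, 3, 256, 256)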
[GIT binary patch data omitted: compiled tools/modules/diffusions/__pycache__/*.pyc entries (diffusion_ddim, losses, schedules)]
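A minimal wiring sketch (illustrative, not from this patch): how the cfg.Diffusion block in tools/modules/config.py is expected to map onto the DiffusionDDIM class added below in tools/modules/diffusions/diffusion_ddim.py. The constructor keywords and schedule parameters are copied from that config; the import path and the 5-D latent shape are illustrative assumptions, and the config's ddim_timesteps key is omitted here (the constructor would only absorb it via **kwargs).

import torch
from tools.modules.diffusions.diffusion_ddim import DiffusionDDIM  # assumed import path

# keyword values copied from cfg.Diffusion in tools/modules/config.py
diffusion = DiffusionDDIM(
    schedule='cosine',
    schedule_param=dict(num_timesteps=1000, cosine_s=0.008, zero_terminal_snr=True),
    mean_type='v',
    var_type='fixed_small',
    loss_type='mse',
    rescale_timesteps=False,
    noise_strength=0.1,
)

x0 = torch.randn(2, 4, 16, 32, 32)                    # latent video batch: (B, C, F, H, W)
t = torch.randint(0, diffusion.num_timesteps, (2, ))  # one random timestep per sample
xt = diffusion.q_sample(x0, t)                        # closed-form sample from q(x_t | x_0)
print(xt.shape)                                       # torch.Size([2, 4, 16, 32, 32])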
z(e}UyK|flb$k?w&!{`7jXj_pT_h>UdaOTr{09`nBiOT4c5=Q5&#z>% literal 0 HcmV?d00001 diff --git a/tools/modules/diffusions/__pycache__/schedules.cpython-39.pyc b/tools/modules/diffusions/__pycache__/schedules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e87214ce4b29cf9cd9d02825ba07ca975b137a1 GIT binary patch literal 4653 zcmZ`+OK%(36~1>~9MO7MHY3ZHlhAJSU>rZ=q)lL0ZX7s$MT_{+lu*i!cqxf8MbdXh zjpGj60&?L1)L!2_uUzeL|PeQ&YXMZ z4(HtSo%5YL!Pr<);JE+bcZZHKL05Mb%Mf)r2|-i>oHpAy_yc9l;%6O{=4@jHqL32G^oGuAap8 zfI6?9Qcq*XsJfua>IAeg^^$rQ#FoAs zIkB_hZVF{hi*>2&+YW4A?5y*?gSR?!6({jpcwwNEk@Q*%RWfpCEl@g$nssgAO%HM6 z_|NRa(^rrG>+i1|o2?X*iT?EZazm}P8p)x)LDjOtF@@P+6bdYH%ulUo1J*K-wX6on79j# zCA}F$09l372QqR1 zrMnN50Pwb@&%{p8ZTpdH0CWNHEr9866acVo0CoWIt;mmE0PMyt0CwjPxKgXLgb>!U zxHsDfY*({JG?$?1jKNt?Q!7yT5S5YS$^9P91E5DZm8GZ;P>ez(6DH)fs8b7@ORGVs z$9QX;VgjP#>4VfKDGpH_rYJ!qu6e>>l?v3hNXCsde<^(ozYZb1Eb7u4jC;YCp5;pu zeZuqA(Z7hhO2%SFpde)TQ0yqHBziKox=%s3p}Ub-kufYWa9XjH^MHlp&Mnacapj{c zCR9s5E&bumzuo)o?3PSi6}{hWY}tt$b##3>@lh9HH>fv~!eXlvM6X;#47mJcq}gsp zHO33Y7lf&5;(<`G4yt}`WdiR8#YFxLOd)|1(wC$78S6GHcAMyKf3~x;v(2E_iF5xT z+{D=AALGvai$NSQVh~&F$e!zBOU_lSWHN<6+(k2aqPAa3|D}RhWerI;+)CW_4{qz4mp4I2^PeTst3$zV|jrGT-kcJKU;7e1J55L&@ z>@ROE{9?9l9LEeEOd6hiDqMhA5-YU$ zA+|Trdz2FwVsBkSB57vCv{3H4wJD{zGaA_|PHaP>*FLa-Kw)lM;!Smm{z09oNeVDE z9f=L3;;+OJlqKbi2~r6JI3qUkb`vOjPBTLcgcUn+K;X-X4R%YcZnGj2xrQqA`(0+Y zXcGTYD_C7n!R5E`_LVazut2?S94JQxh*)cr#L%zwRlhhm!xWg?AQ_Hx+iXp!exM}bz-2%XpTe4qT z+_EU>nT9S#ZAL61g)rNSnKpfgS+CpRvz>~Sx@*^i2tu?W^>DRV4QQs;Z;;!xNwWX*Yqkza$<7 z{W9L2I6)M&6T5M@Yw*!0VAL$M#KsM57-d_ZrH$lRb)y^Kewg2m8$m z9?mc;$&wtGWvOq$RySdS!h?9RJ($1{5mKA!Z8oV5Stt=CE4GPSNqmFdA6;F4A5&2c z{6DiM<9@}}*J;Ro*UxZxlVXsBlc_;pYvFQpF={A{sTnVi$|q9RYS|*uWagKtf_enCrj{IJ&WyA{nFI^a{+3xWLwfdFD!f zMUE8W!k$7SGi21d2gqx%g3cHTbarm+h+nP(Lt8I*s*Vm;yVV1@Rcp3?QlserHun92 zDTCptP23K&4#Iqc_a5F6aw{nSFSj)%c}(gbVLY&nVdm$B`8oyzk#QqL$ZXr}b*Ja$ zBL^dNtJp(Nk+C~1;sSHM=f^%`(C;jU9aCCN^O%rTMurpNG_0DtFnt4}dLT<-li9li zY;^WX7X2a~`8m!so856Nb{tDJ_5lW>GxJcqwx(f9bC=K(q{n$8rO7j;sMaZ^%vT8)935wl^&3;0|SHx=S|s zrn!$?L^hHpUrpyb$VWVnJa!HFTlgZ+$RETCcEUZ}O&bH-0b2)&wueFnPTsw<#WIe! z)6d27^{jiBqfWVsPsj3&?9;J)eeP{DbCd>iREIWQ4(>#a_^k*552V8xb} zO$068bd+i_z19{2U7gs~YwViU$st4S+~cMjLy2!bhd#n9LVnT+*Ox`<+emoJw+`C; zewiZebN>!#*@v_Zacywmq%G$;-hQ5<0s+)t;>FryvUl}bB~dJ`Eg%9^!aBt1g49cewA`5tH%_uNp8(#Gt1f&5W}SAbI-q&euZCV v-iO>lU^8nAiQkvS!`XGp_l+jS=f$G$_+D`m()Zn>Uv!Hj*}X4{qrUTh`Ek#) literal 0 HcmV?d00001 diff --git a/tools/modules/diffusions/diffusion_ddim.py b/tools/modules/diffusions/diffusion_ddim.py new file mode 100644 index 0000000..43a17d3 --- /dev/null +++ b/tools/modules/diffusions/diffusion_ddim.py @@ -0,0 +1,1121 @@ +import torch +import math + +from ....utils.registry_class import DIFFUSION +from .schedules import beta_schedule, sigma_schedule +from .losses import kl_divergence, discretized_gaussian_log_likelihood +# from .dpm_solver import NoiseScheduleVP, model_wrapper_guided_diffusion, model_wrapper, DPM_Solver +from typing import Callable, List, Optional +import numpy as np + +def _i(tensor, t, x): + r"""Index tensor using t and format the output according to x. 
+ """ + if tensor.device != x.device: + tensor = tensor.to(x.device) + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + return tensor[t].view(shape).to(x) + +@DIFFUSION.register_class() +class DiffusionDDIMSR(object): + def __init__(self, reverse_diffusion, forward_diffusion, **kwargs): + from .diffusion_gauss import GaussianDiffusion + self.reverse_diffusion = GaussianDiffusion(sigmas=sigma_schedule(reverse_diffusion.schedule, **reverse_diffusion.schedule_param), + prediction_type=reverse_diffusion.mean_type) + self.forward_diffusion = GaussianDiffusion(sigmas=sigma_schedule(forward_diffusion.schedule, **forward_diffusion.schedule_param), + prediction_type=forward_diffusion.mean_type) + + +@DIFFUSION.register_class() +class DiffusionDPM(object): + def __init__(self, forward_diffusion, **kwargs): + from .diffusion_gauss import GaussianDiffusion + self.forward_diffusion = GaussianDiffusion(sigmas=sigma_schedule(forward_diffusion.schedule, **forward_diffusion.schedule_param), + prediction_type=forward_diffusion.mean_type) + + +@DIFFUSION.register_class() +class DiffusionDDIM(object): + def __init__(self, + schedule='linear_sd', + schedule_param={}, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + epsilon = 1e-12, + rescale_timesteps=False, + noise_strength=0.0, + **kwargs): + + assert mean_type in ['x0', 'x_{t-1}', 'eps', 'v'] + assert var_type in ['learned', 'learned_range', 'fixed_large', 'fixed_small'] + assert loss_type in ['mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1','charbonnier'] + + betas = beta_schedule(schedule, **schedule_param) + assert min(betas) > 0 and max(betas) <= 1 + + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type # eps + self.var_type = var_type # 'fixed_small' + self.loss_type = loss_type # mse + self.epsilon = epsilon # 1e-12 + self.rescale_timesteps = rescale_timesteps # False + self.noise_strength = noise_strength # 0.0 + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat([alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat([self.alphas_cumprod[1:], alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - self.alphas_cumprod) + + + def sample_loss(self, x0, noise=None): + if noise is None: + noise = torch.randn_like(x0) + if self.noise_strength > 0: + b, c, f, _, _= x0.shape + offset_noise = torch.randn(b, c, f, 1, 1, device=x0.device) + noise = noise + self.noise_strength * offset_noise + return noise + + + def q_sample(self, x0, t, noise=None): + r"""Sample from q(x_t | x_0). 
+ """ + # noise = torch.randn_like(x0) if noise is None else noise + noise = self.sample_loss(x0, noise) + return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + \ + _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise + + def q_mean_variance(self, x0, t): + r"""Distribution of q(x_t | x_0). + """ + mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + var = _i(1.0 - self.alphas_cumprod, t, x0) + log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) + return mu, var, log_var + + def q_posterior_mean_variance(self, x0, xt, t): + r"""Distribution of q(x_{t-1} | x_t, x_0). + """ + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def p_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None): + r"""Sample from p(x_{t-1} | x_t). + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + # predict distribution of p(x_{t-1} | x_t) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + + # random sample (with optional conditional function) + noise = torch.randn_like(xt) + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) # no noise when t == 0 + if condition_fn is not None: + grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + mu = mu.float() + var * grad.float() + xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise + return xt_1, x0 + + @torch.no_grad() + def p_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None): + r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). + """ + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale) + return xt + + def p_mean_variance(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None): + r"""Distribution of p(x_{t-1} | x_t). 
+ """ + # predict distribution + if guide_scale is None: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) + u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) + dim = y_out.size(1) if self.var_type.startswith('fixed') else y_out.size(1) // 2 + out = torch.cat([ + u_out[:, :dim] + guide_scale * (y_out[:, :dim] - u_out[:, :dim]), + y_out[:, dim:]], dim=1) # guide_scale=9.0 + + # compute variance + if self.var_type == 'learned': + out, log_var = out.chunk(2, dim=1) + var = torch.exp(log_var) + elif self.var_type == 'learned_range': + out, fraction = out.chunk(2, dim=1) + min_log_var = _i(self.posterior_log_variance_clipped, t, xt) + max_log_var = _i(torch.log(self.betas), t, xt) + fraction = (fraction + 1) / 2.0 + log_var = fraction * max_log_var + (1 - fraction) * min_log_var + var = torch.exp(log_var) + elif self.var_type == 'fixed_large': + var = _i(torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, xt) + log_var = torch.log(var) + elif self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'x_{t-1}': + mu = out # x_{t-1} + x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \ + _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt + elif self.mean_type == 'x0': + x0 = out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'eps': + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'v': + x0 = _i(self.sqrt_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 # e.g., 0.995 + s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + @torch.no_grad() + def ddim_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0): + r"""Sample from p(x_{t-1} | x_t) using DDIM. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). 
+ """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas ** 2) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise + return xt_1, x0 + + @torch.no_grad() + def ddim_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) + steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps)).clamp(0, self.num_timesteps - 1).flip(0) + from tqdm import tqdm + for step in tqdm(steps): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, ddim_timesteps, eta) + # from ipdb import set_trace; set_trace() + return xt + + @torch.no_grad() + def ddim_reverse_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20): + r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas_next = _i( + torch.cat([self.alphas_cumprod, self.alphas_cumprod.new_zeros([1])]), + (t + stride).clamp(0, self.num_timesteps), xt) + + # reverse sample + mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop(self, x0, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20): + # prepare input + b = x0.size(0) + xt = x0 + + # reconstruction steps + steps = torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, percentile, guide_scale, ddim_timesteps) + return xt + + @torch.no_grad() + def plms_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20): + r"""Sample from p(x_{t-1} | x_t) using PLMS. 
+ - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + stride = self.num_timesteps // plms_timesteps + + # function for compute eps + def compute_eps(xt, t): + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + + # condition + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive eps + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + return eps + + # function for compute x_0 and x_{t-1} + def compute_x0(eps, t): + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # deterministic sample + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + direction = torch.sqrt(1 - alphas_prev) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + return xt_1, x0 + + # PLMS sample + eps = compute_eps(xt, t) + if len(eps_cache) == 0: + # 2nd order pseudo improved Euler + xt_1, x0 = compute_x0(eps, t) + eps_next = compute_eps(xt_1, (t - stride).clamp(0)) + eps_prime = (eps + eps_next) / 2.0 + elif len(eps_cache) == 1: + # 2nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (3 * eps - eps_cache[-1]) / 2.0 + elif len(eps_cache) == 2: + # 3nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (23 * eps - 16 * eps_cache[-1] + 5 * eps_cache[-2]) / 12.0 + elif len(eps_cache) >= 3: + # 4nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] - 9 * eps_cache[-3]) / 24.0 + xt_1, x0 = compute_x0(eps_prime, t) + return xt_1, x0, eps + + @torch.no_grad() + def plms_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // plms_timesteps)).clamp(0, self.num_timesteps - 1).flip(0) + eps_cache = [] + for step in steps: + # PLMS sampling step + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, plms_timesteps, eps_cache) + + # update eps cache + eps_cache.append(eps) + if len(eps_cache) >= 4: + eps_cache.pop(0) + return xt + + def loss(self, x0, t, model, model_kwargs={}, noise=None, weight = None, use_div_loss= False, loss_mask=None): + + # noise = torch.randn_like(x0) if noise is None else noise # [80, 4, 8, 32, 32] + noise = self.sample_loss(x0, noise) + + xt = self.q_sample(x0, t, noise=noise) + + # compute loss + if self.loss_type in ['kl', 'rescaled_kl']: + loss, _ = self.variational_lower_bound(x0, xt, t, model, model_kwargs) + if self.loss_type == 'rescaled_kl': + loss = loss * self.num_timesteps + elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: # self.loss_type: mse + out = model(xt, self._scale_timesteps(t), 
**model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: # self.var_type: 'fixed_small' + out, var = out.chunk(2, dim=1) + frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + # target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type] + target = { + 'eps': noise, + 'x0': x0, + 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0], + 'v':_i(self.sqrt_alphas_cumprod, t, xt) * noise - _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * x0}[self.mean_type] + if loss_mask is not None: + loss_mask = loss_mask[:, :, 0, ...].unsqueeze(2) # just use one channel (all channels are same) + loss_mask = loss_mask.permute(0, 2, 1, 3, 4) # b,c,f,h,w + # use masked diffusion + loss = (out * loss_mask - target * loss_mask).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1) + else: + loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1) + if weight is not None: + loss = loss*weight + + # div loss + if use_div_loss and self.mean_type == 'eps' and x0.shape[2]>1: + + # derive x0 + x0_ = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out + + # # derive xt_1, set eta=0 as ddim + # alphas_prev = _i(self.alphas_cumprod, (t - 1).clamp(0), xt) + # direction = torch.sqrt(1 - alphas_prev) * out + # xt_1 = torch.sqrt(alphas_prev) * x0_ + direction + + # ncfhw, std on f + div_loss = 0.001/(x0_.std(dim=2).flatten(1).mean(dim=1)+1e-4) + # print(div_loss,loss) + loss = loss+div_loss + + # total loss + loss = loss + loss_vlb + elif self.loss_type in ['charbonnier']: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: + out, var = out.chunk(2, dim=1) + frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type] + loss = torch.sqrt((out - target)**2 + self.epsilon) + if weight is not None: + loss = loss*weight + loss = loss.flatten(1).mean(dim=1) + + # total loss + loss = loss + loss_vlb + return loss + + def variational_lower_bound(self, x0, xt, t, model, model_kwargs={}, clamp=None, percentile=None): + # compute groundtruth and predicted distributions + mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) + mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile) + + # compute KL loss + kl = kl_divergence(mu1, log_var1, mu2, log_var2) + kl = kl.flatten(1).mean(dim=1) / math.log(2.0) + + # compute discretized NLL loss (for p(x0 | x1) only) + nll = -discretized_gaussian_log_likelihood(x0, mean=mu2, log_scale=0.5 * log_var2) + nll = nll.flatten(1).mean(dim=1) / math.log(2.0) + + # NLL for p(x0 | x1) and KL otherwise + vlb = torch.where(t == 0, nll, kl) + return vlb, x0 + + @torch.no_grad() + def variational_lower_bound_loop(self, x0, model, 
model_kwargs={}, clamp=None, percentile=None): + r"""Compute the entire variational lower bound, measured in bits-per-dim. + """ + # prepare input and output + b = x0.size(0) + metrics = {'vlb': [], 'mse': [], 'x0_mse': []} + + # loop + for step in torch.arange(self.num_timesteps).flip(0): + # compute VLB + t = torch.full((b, ), step, dtype=torch.long, device=x0.device) + # noise = torch.randn_like(x0) + noise = self.sample_loss(x0) + xt = self.q_sample(x0, t, noise) + vlb, pred_x0 = self.variational_lower_bound(x0, xt, t, model, model_kwargs, clamp, percentile) + + # predict eps from x0 + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + + # collect metrics + metrics['vlb'].append(vlb) + metrics['x0_mse'].append((pred_x0 - x0).square().flatten(1).mean(dim=1)) + metrics['mse'].append((eps - noise).square().flatten(1).mean(dim=1)) + metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} + + # compute the prior KL term for VLB, measured in bits-per-dim + mu, _, log_var = self.q_mean_variance(x0, t) + kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), torch.zeros_like(log_var)) + kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) + + # update metrics + metrics['prior_bits_per_dim'] = kl_prior + metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior + return metrics + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t + #return t.float() + + + + + + +@DIFFUSION.register_class() +class DiffusionDDIMLong(object): + def __init__(self, + schedule='linear_sd', + schedule_param={}, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + epsilon = 1e-12, + rescale_timesteps=False, + noise_strength=0.0, + **kwargs): + + assert mean_type in ['x0', 'x_{t-1}', 'eps', 'v'] + assert var_type in ['learned', 'learned_range', 'fixed_large', 'fixed_small'] + assert loss_type in ['mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1','charbonnier'] + + betas = beta_schedule(schedule, **schedule_param) + assert min(betas) > 0 and max(betas) <= 1 + + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type # v + self.var_type = var_type # 'fixed_small' + self.loss_type = loss_type # mse + self.epsilon = epsilon # 1e-12 + self.rescale_timesteps = rescale_timesteps # False + self.noise_strength = noise_strength + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat([alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat([self.alphas_cumprod[1:], alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = 
(1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - self.alphas_cumprod) + + + def sample_loss(self, x0, noise=None): + if noise is None: + noise = torch.randn_like(x0) + if self.noise_strength > 0: + b, c, f, _, _= x0.shape + offset_noise = torch.randn(b, c, f, 1, 1, device=x0.device) + noise = noise + self.noise_strength * offset_noise + return noise + + + def q_sample(self, x0, t, noise=None): + r"""Sample from q(x_t | x_0). + """ + # noise = torch.randn_like(x0) if noise is None else noise + noise = self.sample_loss(x0, noise) + return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + \ + _i(self.sqrt_one_minus_alphas_cumprod, t, x0) * noise + + def q_mean_variance(self, x0, t): + r"""Distribution of q(x_t | x_0). + """ + mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + var = _i(1.0 - self.alphas_cumprod, t, x0) + log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) + return mu, var, log_var + + def q_posterior_mean_variance(self, x0, xt, t): + r"""Distribution of q(x_{t-1} | x_t, x_0). + """ + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i(self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def p_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None): + r"""Sample from p(x_{t-1} | x_t). + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + # predict distribution of p(x_{t-1} | x_t) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + + # random sample (with optional conditional function) + noise = torch.randn_like(xt) + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) # no noise when t == 0 + if condition_fn is not None: + grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + mu = mu.float() + var * grad.float() + xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise + return xt_1, x0 + + @torch.no_grad() + def p_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None): + r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). + """ + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale) + return xt + + def p_mean_variance(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, context_size=32, context_stride=1, context_overlap=4, context_batch_size=1): + r"""Distribution of p(x_{t-1} | x_t). 
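+
+        Long-video variant: the latent is split along the frame axis into overlapping
+        context windows produced by context_scheduler(); the model is run on each window
+        (grouped into batches of context_batch_size), per-frame predictions are accumulated
+        and averaged via a counter, and classifier-free guidance is applied to the merged
+        prediction.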
+ """ + noise = xt + context_queue = list( + context_scheduler( + 0, + 31, + noise.shape[2], + context_size=context_size, + context_stride=1, + context_overlap=4, + ) + ) + context_step = min( + context_stride, int(np.ceil(np.log2(noise.shape[2] / context_size))) + 1 + ) + # replace the final segment to improve temporal consistency + num_frames = noise.shape[2] + context_queue[-1] = [ + e % num_frames + for e in range(num_frames - context_size * context_step, num_frames, context_step) + ] + + import math + # context_batch_size = 1 + num_context_batches = math.ceil(len(context_queue) / context_batch_size) + global_context = [] + for i in range(num_context_batches): + global_context.append( + context_queue[ + i * context_batch_size : (i + 1) * context_batch_size + ] + ) + noise_pred = torch.zeros_like(noise) + noise_pred_uncond = torch.zeros_like(noise) + counter = torch.zeros( + (1, 1, xt.shape[2], 1, 1), + device=xt.device, + dtype=xt.dtype, + ) + + for i_index, context in enumerate(global_context): + + + latent_model_input = torch.cat([xt[:, :, c] for c in context]) + bs_context = len(context) + + model_kwargs_new = [{ + 'y': None, + "local_image": None if not model_kwargs[0].__contains__('local_image') else torch.cat([model_kwargs[0]["local_image"][:, :, c] for c in context]), + 'image': None if not model_kwargs[0].__contains__('image') else model_kwargs[0]["image"].repeat(bs_context, 1, 1), + 'dwpose': None if not model_kwargs[0].__contains__('dwpose') else torch.cat([model_kwargs[0]["dwpose"][:, :, [0]+[ii+1 for ii in c]] for c in context]), + 'randomref': None if not model_kwargs[0].__contains__('randomref') else torch.cat([model_kwargs[0]["randomref"][:, :, c] for c in context]), + }, + { + 'y': None, + "local_image": None, + 'image': None, + 'randomref': None, + 'dwpose': None, + }] + + if guide_scale is None: + out = model(latent_model_input, self._scale_timesteps(t), **model_kwargs) + for j, c in enumerate(context): + noise_pred[:, :, c] = noise_pred[:, :, c] + out + counter[:, :, c] = counter[:, :, c] + 1 + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + # assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(latent_model_input, self._scale_timesteps(t).repeat(bs_context), **model_kwargs_new[0]) + u_out = model(latent_model_input, self._scale_timesteps(t).repeat(bs_context), **model_kwargs_new[1]) + dim = y_out.size(1) if self.var_type.startswith('fixed') else y_out.size(1) // 2 + for j, c in enumerate(context): + noise_pred[:, :, c] = noise_pred[:, :, c] + y_out[j:j+1] + noise_pred_uncond[:, :, c] = noise_pred_uncond[:, :, c] + u_out[j:j+1] + counter[:, :, c] = counter[:, :, c] + 1 + + noise_pred = noise_pred / counter + noise_pred_uncond = noise_pred_uncond / counter + out = torch.cat([ + noise_pred_uncond[:, :dim] + guide_scale * (noise_pred[:, :dim] - noise_pred_uncond[:, :dim]), + noise_pred[:, dim:]], dim=1) # guide_scale=2.5 + + + # compute variance + if self.var_type == 'learned': + out, log_var = out.chunk(2, dim=1) + var = torch.exp(log_var) + elif self.var_type == 'learned_range': + out, fraction = out.chunk(2, dim=1) + min_log_var = _i(self.posterior_log_variance_clipped, t, xt) + max_log_var = _i(torch.log(self.betas), t, xt) + fraction = (fraction + 1) / 2.0 + log_var = fraction * max_log_var + (1 - fraction) * min_log_var + var = torch.exp(log_var) + elif self.var_type == 'fixed_large': + var = _i(torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, xt) + 
log_var = torch.log(var) + elif self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'x_{t-1}': + mu = out # x_{t-1} + x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - \ + _i(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, xt) * xt + elif self.mean_type == 'x0': + x0 = out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'eps': + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'v': + x0 = _i(self.sqrt_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 # e.g., 0.995 + s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + @torch.no_grad() + def ddim_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0, context_size=32, context_stride=1, context_overlap=4, context_batch_size=1): + r"""Sample from p(x_{t-1} | x_t) using DDIM. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale, context_size, context_stride, context_overlap, context_batch_size) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas ** 2) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise + return xt_1, x0 + + @torch.no_grad() + def ddim_sample_loop(self, noise, context_size, context_stride, context_overlap, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, ddim_timesteps=20, eta=0.0, context_batch_size=1): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process (TODO: clamp is inaccurate! 
Consider replacing the stride by explicit prev/next steps) + steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps)).clamp(0, self.num_timesteps - 1).flip(0) + from tqdm import tqdm + + for step in tqdm(steps): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, ddim_timesteps, eta, context_size=context_size, context_stride=context_stride, context_overlap=context_overlap, context_batch_size=context_batch_size) + return xt + + @torch.no_grad() + def ddim_reverse_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20): + r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas_next = _i( + torch.cat([self.alphas_cumprod, self.alphas_cumprod.new_zeros([1])]), + (t + stride).clamp(0, self.num_timesteps), xt) + + # reverse sample + mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop(self, x0, model, model_kwargs={}, clamp=None, percentile=None, guide_scale=None, ddim_timesteps=20): + # prepare input + b = x0.size(0) + xt = x0 + + # reconstruction steps + steps = torch.arange(0, self.num_timesteps, self.num_timesteps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, percentile, guide_scale, ddim_timesteps) + return xt + + @torch.no_grad() + def plms_sample(self, xt, t, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20): + r"""Sample from p(x_{t-1} | x_t) using PLMS. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). 
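+        - plms_timesteps: number of PLMS steps; the timestep stride is
+          num_timesteps // plms_timesteps.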
+ """ + stride = self.num_timesteps // plms_timesteps + + # function for compute eps + def compute_eps(xt, t): + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile, guide_scale) + + # condition + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive eps + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + return eps + + # function for compute x_0 and x_{t-1} + def compute_x0(eps, t): + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # deterministic sample + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + direction = torch.sqrt(1 - alphas_prev) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + return xt_1, x0 + + # PLMS sample + eps = compute_eps(xt, t) + if len(eps_cache) == 0: + # 2nd order pseudo improved Euler + xt_1, x0 = compute_x0(eps, t) + eps_next = compute_eps(xt_1, (t - stride).clamp(0)) + eps_prime = (eps + eps_next) / 2.0 + elif len(eps_cache) == 1: + # 2nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (3 * eps - eps_cache[-1]) / 2.0 + elif len(eps_cache) == 2: + # 3nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (23 * eps - 16 * eps_cache[-1] + 5 * eps_cache[-2]) / 12.0 + elif len(eps_cache) >= 3: + # 4nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] - 9 * eps_cache[-3]) / 24.0 + xt_1, x0 = compute_x0(eps_prime, t) + return xt_1, x0, eps + + @torch.no_grad() + def plms_sample_loop(self, noise, model, model_kwargs={}, clamp=None, percentile=None, condition_fn=None, guide_scale=None, plms_timesteps=20): + # prepare input + b = noise.size(0) + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, self.num_timesteps // plms_timesteps)).clamp(0, self.num_timesteps - 1).flip(0) + eps_cache = [] + for step in steps: + # PLMS sampling step + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, percentile, condition_fn, guide_scale, plms_timesteps, eps_cache) + + # update eps cache + eps_cache.append(eps) + if len(eps_cache) >= 4: + eps_cache.pop(0) + return xt + + def loss(self, x0, t, model, model_kwargs={}, noise=None, weight = None, use_div_loss= False, loss_mask=None): + + # noise = torch.randn_like(x0) if noise is None else noise # [80, 4, 8, 32, 32] + noise = self.sample_loss(x0, noise) + + xt = self.q_sample(x0, t, noise=noise) + + # compute loss + if self.loss_type in ['kl', 'rescaled_kl']: + loss, _ = self.variational_lower_bound(x0, xt, t, model, model_kwargs) + if self.loss_type == 'rescaled_kl': + loss = loss * self.num_timesteps + elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: # self.loss_type: mse + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: # self.var_type: 'fixed_small' + 
out, var = out.chunk(2, dim=1) + frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + # target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type] + target = { + 'eps': noise, + 'x0': x0, + 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0], + 'v':_i(self.sqrt_alphas_cumprod, t, xt) * noise - _i(self.sqrt_one_minus_alphas_cumprod, t, xt) * x0}[self.mean_type] + if loss_mask is not None: + loss_mask = loss_mask[:, :, 0, ...].unsqueeze(2) # just use one channel (all channels are same) + loss_mask = loss_mask.permute(0, 2, 1, 3, 4) # b,c,f,h,w + # use masked diffusion + loss = (out * loss_mask - target * loss_mask).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1) + else: + loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2).abs().flatten(1).mean(dim=1) + if weight is not None: + loss = loss*weight + + # div loss + if use_div_loss and self.mean_type == 'eps' and x0.shape[2]>1: + + # derive x0 + x0_ = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) * out + + + # ncfhw, std on f + div_loss = 0.001/(x0_.std(dim=2).flatten(1).mean(dim=1)+1e-4) + # print(div_loss,loss) + loss = loss+div_loss + + # total loss + loss = loss + loss_vlb + elif self.loss_type in ['charbonnier']: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: + out, var = out.chunk(2, dim=1) + frozen = torch.cat([out.detach(), var], dim=1) # learn var without affecting the prediction of mean + loss_vlb, _ = self.variational_lower_bound(x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + target = {'eps': noise, 'x0': x0, 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0]}[self.mean_type] + loss = torch.sqrt((out - target)**2 + self.epsilon) + if weight is not None: + loss = loss*weight + loss = loss.flatten(1).mean(dim=1) + + # total loss + loss = loss + loss_vlb + return loss + + def variational_lower_bound(self, x0, xt, t, model, model_kwargs={}, clamp=None, percentile=None): + # compute groundtruth and predicted distributions + mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) + mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, percentile) + + # compute KL loss + kl = kl_divergence(mu1, log_var1, mu2, log_var2) + kl = kl.flatten(1).mean(dim=1) / math.log(2.0) + + # compute discretized NLL loss (for p(x0 | x1) only) + nll = -discretized_gaussian_log_likelihood(x0, mean=mu2, log_scale=0.5 * log_var2) + nll = nll.flatten(1).mean(dim=1) / math.log(2.0) + + # NLL for p(x0 | x1) and KL otherwise + vlb = torch.where(t == 0, nll, kl) + return vlb, x0 + + @torch.no_grad() + def variational_lower_bound_loop(self, x0, model, model_kwargs={}, clamp=None, percentile=None): + r"""Compute the entire variational lower bound, measured in bits-per-dim. 
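+
+        Returns a dict with per-timestep 'vlb', 'mse' (on eps) and 'x0_mse' tensors of
+        shape [B, T], plus 'prior_bits_per_dim' and 'total_bits_per_dim' of shape [B].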
+ """ + # prepare input and output + b = x0.size(0) + metrics = {'vlb': [], 'mse': [], 'x0_mse': []} + + # loop + for step in torch.arange(self.num_timesteps).flip(0): + # compute VLB + t = torch.full((b, ), step, dtype=torch.long, device=x0.device) + # noise = torch.randn_like(x0) + noise = self.sample_loss(x0) + xt = self.q_sample(x0, t, noise) + vlb, pred_x0 = self.variational_lower_bound(x0, xt, t, model, model_kwargs, clamp, percentile) + + # predict eps from x0 + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / \ + _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + + # collect metrics + metrics['vlb'].append(vlb) + metrics['x0_mse'].append((pred_x0 - x0).square().flatten(1).mean(dim=1)) + metrics['mse'].append((eps - noise).square().flatten(1).mean(dim=1)) + metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} + + # compute the prior KL term for VLB, measured in bits-per-dim + mu, _, log_var = self.q_mean_variance(x0, t) + kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), torch.zeros_like(log_var)) + kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) + + # update metrics + metrics['prior_bits_per_dim'] = kl_prior + metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior + return metrics + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t + #return t.float() + + + +def ordered_halving(val): + bin_str = f"{val:064b}" + bin_flip = bin_str[::-1] + as_int = int(bin_flip, 2) + + return as_int / (1 << 64) + + +def context_scheduler( + step: int = ..., + num_steps: Optional[int] = None, + num_frames: int = ..., + context_size: Optional[int] = None, + context_stride: int = 3, + context_overlap: int = 4, + closed_loop: bool = False, +): + if num_frames <= context_size: + yield list(range(num_frames)) + return + + context_stride = min( + context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1 + ) + + for context_step in 1 << np.arange(context_stride): + pad = int(round(num_frames * ordered_halving(step))) + for j in range( + int(ordered_halving(step) * context_step) + pad, + num_frames + pad + (0 if closed_loop else -context_overlap), + (context_size * context_step - context_overlap), + ): + + yield [ + e % num_frames + for e in range(j, j + context_size * context_step, context_step) + ] + diff --git a/tools/modules/diffusions/diffusion_gauss.py b/tools/modules/diffusions/diffusion_gauss.py new file mode 100644 index 0000000..430ab3d --- /dev/null +++ b/tools/modules/diffusions/diffusion_gauss.py @@ -0,0 +1,498 @@ +""" +GaussianDiffusion wraps operators for denoising diffusion models, including the +diffusion and denoising processes, as well as the loss evaluation. +""" +import torch +import torchsde +import random +from tqdm.auto import trange + + +__all__ = ['GaussianDiffusion'] + + +def _i(tensor, t, x): + """ + Index tensor using t and format the output according to x. + """ + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + return tensor[t.to(tensor.device)].view(shape).to(x.device) + + +class BatchedBrownianTree: + """ + A wrapper around torchsde.BrownianTree that enables batches of entropy. 
+ """ + def __init__(self, x, t0, t1, seed=None, **kwargs): + t0, t1, self.sign = self.sort(t0, t1) + w0 = kwargs.get('w0', torch.zeros_like(x)) + if seed is None: + seed = torch.randint(0, 2 ** 63 - 1, []).item() + self.batched = True + try: + assert len(seed) == x.shape[0] + w0 = w0[0] + except TypeError: + seed = [seed] + self.batched = False + self.trees = [torchsde.BrownianTree( + t0, w0, t1, entropy=s, **kwargs + ) for s in seed] + + @staticmethod + def sort(a, b): + return (a, b, 1) if a < b else (b, a, -1) + + def __call__(self, t0, t1): + t0, t1, sign = self.sort(t0, t1) + w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + return w if self.batched else w[0] + + +class BrownianTreeNoiseSampler: + """ + A noise sampler backed by a torchsde.BrownianTree. + + Args: + x (Tensor): The tensor whose shape, device and dtype to use to generate + random samples. + sigma_min (float): The low end of the valid interval. + sigma_max (float): The high end of the valid interval. + seed (int or List[int]): The random seed. If a list of seeds is + supplied instead of a single integer, then the noise sampler will + use one BrownianTree per batch item, each with its own seed. + transform (callable): A function that maps sigma to the sampler's + internal timestep. + """ + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + self.transform = transform + t0 = self.transform(torch.as_tensor(sigma_min)) + t1 = self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed) + + def __call__(self, sigma, sigma_next): + t0 = self.transform(torch.as_tensor(sigma)) + t1 = self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() + + +def get_scalings(sigma): + c_out = -sigma + c_in = 1 / (sigma ** 2 + 1. ** 2) ** 0.5 + return c_out, c_in + + +@torch.no_grad() +def sample_dpmpp_2m_sde( + noise, + model, + sigmas, + eta=1., + s_noise=1., + solver_type='midpoint', + show_progress=True +): + """ + DPM-Solver++ (2M) SDE. 
+ """ + assert solver_type in {'heun', 'midpoint'} + + x = noise * sigmas[0] + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[sigmas < float('inf')].max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) + old_denoised = None + h_last = None + + for i in trange(len(sigmas) - 1, disable=not show_progress): + if sigmas[i] == float('inf'): + # Euler method + denoised = model(noise, sigmas[i]) + x = denoised + sigmas[i + 1] * noise + else: + _, c_in = get_scalings(sigmas[i]) + denoised = model(x * c_in, sigmas[i]) + if sigmas[i + 1] == 0: + # Denoising step + x = denoised + else: + # DPM-Solver++(2M) SDE + t, s = -sigmas[i].log(), -sigmas[i + 1].log() + h = s - t + eta_h = eta * h + + x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \ + (-h - eta_h).expm1().neg() * denoised + + if old_denoised is not None: + r = h_last / h + if solver_type == 'heun': + x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * \ + (1 / r) * (denoised - old_denoised) + elif solver_type == 'midpoint': + x = x + 0.5 * (-h - eta_h).expm1().neg() * \ + (1 / r) * (denoised - old_denoised) + + x = x + noise_sampler( + sigmas[i], + sigmas[i + 1] + ) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise + + old_denoised = denoised + h_last = h + return x + + +class GaussianDiffusion(object): + + def __init__(self, sigmas, prediction_type='eps'): + assert prediction_type in {'x0', 'eps', 'v'} + self.sigmas = sigmas.float() # noise coefficients + self.alphas = torch.sqrt(1 - sigmas ** 2).float() # signal coefficients + self.num_timesteps = len(sigmas) + self.prediction_type = prediction_type + + def diffuse(self, x0, t, noise=None): + """ + Add Gaussian noise to signal x0 according to: + q(x_t | x_0) = N(x_t | alpha_t x_0, sigma_t^2 I). + """ + noise = torch.randn_like(x0) if noise is None else noise + xt = _i(self.alphas, t, x0) * x0 + _i(self.sigmas, t, x0) * noise + return xt + + def denoise( + self, + xt, + t, + s, + model, + model_kwargs={}, + guide_scale=None, + guide_rescale=None, + clamp=None, + percentile=None + ): + """ + Apply one step of denoising from the posterior distribution q(x_s | x_t, x0). + Since x0 is not available, estimate the denoising results using the learned + distribution p(x_s | x_t, \hat{x}_0 == f(x_t)). + """ + s = t - 1 if s is None else s + + # hyperparams + sigmas = _i(self.sigmas, t, xt) + alphas = _i(self.alphas, t, xt) + alphas_s = _i(self.alphas, s.clamp(0), xt) + alphas_s[s < 0] = 1. 
+ sigmas_s = torch.sqrt(1 - alphas_s ** 2) + + # precompute variables + betas = 1 - (alphas / alphas_s) ** 2 + coef1 = betas * alphas_s / sigmas ** 2 + coef2 = (alphas * sigmas_s ** 2) / (alphas_s * sigmas ** 2) + var = betas * (sigmas_s / sigmas) ** 2 + log_var = torch.log(var).clamp_(-20, 20) + + # prediction + if guide_scale is None: + assert isinstance(model_kwargs, dict) + out = model(xt, t=t, **model_kwargs) + else: + # classifier-free guidance (arXiv:2207.12598) + # model_kwargs[0]: conditional kwargs + # model_kwargs[1]: non-conditional kwargs + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + y_out = model(xt, t=t, **model_kwargs[0]) + if guide_scale == 1.: + out = y_out + else: + u_out = model(xt, t=t, **model_kwargs[1]) + out = u_out + guide_scale * (y_out - u_out) + + # rescale the output according to arXiv:2305.08891 + if guide_rescale is not None: + assert guide_rescale >= 0 and guide_rescale <= 1 + ratio = (y_out.flatten(1).std(dim=1) / ( + out.flatten(1).std(dim=1) + 1e-12 + )).view((-1, ) + (1, ) * (y_out.ndim - 1)) + out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0 + + # compute x0 + if self.prediction_type == 'x0': + x0 = out + elif self.prediction_type == 'eps': + x0 = (xt - sigmas * out) / alphas + elif self.prediction_type == 'v': + x0 = alphas * xt - sigmas * out + else: + raise NotImplementedError( + f'prediction_type {self.prediction_type} not implemented' + ) + + # restrict the range of x0 + if percentile is not None: + # NOTE: percentile should only be used when data is within range [-1, 1] + assert percentile > 0 and percentile <= 1 + s = torch.quantile(x0.flatten(1).abs(), percentile, dim=1) + s = s.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1)) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + + # recompute eps using the restricted x0 + eps = (xt - alphas * x0) / sigmas + + # compute mu (mean of posterior distribution) using the restricted x0 + mu = coef1 * x0 + coef2 * xt + return mu, var, log_var, x0, eps + + @torch.no_grad() + def sample( + self, + noise, + model, + model_kwargs={}, + condition_fn=None, + guide_scale=None, + guide_rescale=None, + clamp=None, + percentile=None, + solver='euler_a', + steps=20, + t_max=None, + t_min=None, + discretization=None, + discard_penultimate_step=None, + return_intermediate=None, + show_progress=False, + seed=-1, + **kwargs + ): + # sanity check + assert isinstance(steps, (int, torch.LongTensor)) + assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1) + assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1) + assert discretization in (None, 'leading', 'linspace', 'trailing') + assert discard_penultimate_step in (None, True, False) + assert return_intermediate in (None, 'x0', 'xt') + + # function of diffusion solver + solver_fn = { + # 'heun': sample_heun, + 'dpmpp_2m_sde': sample_dpmpp_2m_sde + }[solver] + + # options + schedule = 'karras' if 'karras' in solver else None + discretization = discretization or 'linspace' + seed = seed if seed >= 0 else random.randint(0, 2 ** 31) + if isinstance(steps, torch.LongTensor): + discard_penultimate_step = False + if discard_penultimate_step is None: + discard_penultimate_step = True if solver in ( + 'dpm2', + 'dpm2_ancestral', + 'dpmpp_2m_sde', + 'dpm2_karras', + 'dpm2_ancestral_karras', + 'dpmpp_2m_sde_karras' + ) else False + + # function for denoising xt to get x0 + intermediates = [] + def model_fn(xt, sigma): + # denoising + t = 
self._sigma_to_t(sigma).repeat(len(xt)).round().long() + x0 = self.denoise( + xt, t, None, model, model_kwargs, guide_scale, guide_rescale, clamp, + percentile + )[-2] + + # collect intermediate outputs + if return_intermediate == 'xt': + intermediates.append(xt) + elif return_intermediate == 'x0': + intermediates.append(x0) + return x0 + + # get timesteps + if isinstance(steps, int): + steps += 1 if discard_penultimate_step else 0 + t_max = self.num_timesteps - 1 if t_max is None else t_max + t_min = 0 if t_min is None else t_min + + # discretize timesteps + if discretization == 'leading': + steps = torch.arange( + t_min, t_max + 1, (t_max - t_min + 1) / steps + ).flip(0) + elif discretization == 'linspace': + steps = torch.linspace(t_max, t_min, steps) + elif discretization == 'trailing': + steps = torch.arange(t_max, t_min - 1, -((t_max - t_min + 1) / steps)) + else: + raise NotImplementedError( + f'{discretization} discretization not implemented' + ) + steps = steps.clamp_(t_min, t_max) + steps = torch.as_tensor(steps, dtype=torch.float32, device=noise.device) + + # get sigmas + sigmas = self._t_to_sigma(steps) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + if schedule == 'karras': + if sigmas[0] == float('inf'): + sigmas = karras_schedule( + n=len(steps) - 1, + sigma_min=sigmas[sigmas > 0].min().item(), + sigma_max=sigmas[sigmas < float('inf')].max().item(), + rho=7. + ).to(sigmas) + sigmas = torch.cat([ + sigmas.new_tensor([float('inf')]), sigmas, sigmas.new_zeros([1]) + ]) + else: + sigmas = karras_schedule( + n=len(steps), + sigma_min=sigmas[sigmas > 0].min().item(), + sigma_max=sigmas.max().item(), + rho=7. + ).to(sigmas) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + if discard_penultimate_step: + sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) + + # sampling + x0 = solver_fn( + noise, + model_fn, + sigmas, + show_progress=show_progress, + **kwargs + ) + return (x0, intermediates) if return_intermediate is not None else x0 + + @torch.no_grad() + def ddim_reverse_sample( + self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + guide_rescale=None, + ddim_timesteps=20, + reverse_steps=600 + ): + r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). + """ + stride = reverse_steps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0, eps = self.denoise( + xt, t, None, model, model_kwargs, guide_scale, guide_rescale, clamp, + percentile + ) + # derive variables + s = (t + stride).clamp(0, reverse_steps-1) + # hyperparams + sigmas = _i(self.sigmas, t, xt) + alphas = _i(self.alphas, t, xt) + alphas_s = _i(self.alphas, s.clamp(0), xt) + alphas_s[s < 0] = 1. 
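+        # s = t + stride is the noisier timestep: DDIM inversion deterministically
+        # re-noises the current (x0, eps) estimate up to noise level s.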
+ sigmas_s = torch.sqrt(1 - alphas_s ** 2) + + # reverse sample + mu = alphas_s * x0 + sigmas_s * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop( + self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + guide_rescale=None, + ddim_timesteps=20, + reverse_steps=600 + ): + # prepare input + b = x0.size(0) + xt = x0 + + # reconstruction steps + steps = torch.arange(0, reverse_steps, reverse_steps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, percentile, guide_scale, guide_rescale, ddim_timesteps, reverse_steps) + return xt + + def _sigma_to_t(self, sigma): + if sigma == float('inf'): + t = torch.full_like(sigma, len(self.sigmas) - 1) + else: + log_sigmas = torch.sqrt( + self.sigmas ** 2 / (1 - self.sigmas ** 2) + ).log().to(sigma) + log_sigma = sigma.log() + dists = log_sigma - log_sigmas[:, None] + low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp( + max=log_sigmas.shape[0] - 2 + ) + high_idx = low_idx + 1 + low, high = log_sigmas[low_idx], log_sigmas[high_idx] + w = (low - log_sigma) / (low - high) + w = w.clamp(0, 1) + t = (1 - w) * low_idx + w * high_idx + t = t.view(sigma.shape) + if t.ndim == 0: + t = t.unsqueeze(0) + return t + + def _t_to_sigma(self, t): + t = t.float() + low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() + log_sigmas = torch.sqrt(self.sigmas ** 2 / (1 - self.sigmas ** 2)).log().to(t) + log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx] + log_sigma[torch.isnan(log_sigma) | torch.isinf(log_sigma)] = float('inf') + return log_sigma.exp() + + def prev_step(self, model_out, t, xt, inference_steps=50): + prev_t = t - self.num_timesteps // inference_steps + + sigmas = _i(self.sigmas, t, xt) + alphas = _i(self.alphas, t, xt) + alphas_prev = _i(self.alphas, prev_t.clamp(0), xt) + alphas_prev[prev_t < 0] = 1. + sigmas_prev = torch.sqrt(1 - alphas_prev ** 2) + + x0 = alphas * xt - sigmas * model_out + eps = (xt - alphas * x0) / sigmas + prev_sample = alphas_prev * x0 + sigmas_prev * eps + return prev_sample + + def next_step(self, model_out, t, xt, inference_steps=50): + t, next_t = min(t - self.num_timesteps // inference_steps, 999), t + + sigmas = _i(self.sigmas, t, xt) + alphas = _i(self.alphas, t, xt) + alphas_next = _i(self.alphas, next_t.clamp(0), xt) + alphas_next[next_t < 0] = 1. + sigmas_next = torch.sqrt(1 - alphas_next ** 2) + + x0 = alphas * xt - sigmas * model_out + eps = (xt - alphas * x0) / sigmas + next_sample = alphas_next * x0 + sigmas_next * eps + return next_sample + + def get_noise_pred_single(self, xt, t, model, model_kwargs): + assert isinstance(model_kwargs, dict) + out = model(xt, t=t, **model_kwargs) + return out + + diff --git a/tools/modules/diffusions/losses.py b/tools/modules/diffusions/losses.py new file mode 100644 index 0000000..d3188d8 --- /dev/null +++ b/tools/modules/diffusions/losses.py @@ -0,0 +1,28 @@ +import torch +import math + +__all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood'] + +def kl_divergence(mu1, logvar1, mu2, logvar2): + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mu1 - mu2) ** 2) * torch.exp(-logvar2)) + +def standard_normal_cdf(x): + r"""A fast approximation of the cumulative distribution function of the standard normal. 
+ """ + return 0.5 * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + +def discretized_gaussian_log_likelihood(x0, mean, log_scale): + assert x0.shape == mean.shape == log_scale.shape + cx = x0 - mean + inv_stdv = torch.exp(-log_scale) + cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0)) + cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0)) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + x0 < -0.999, + log_cdf_plus, + torch.where(x0 > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12)))) + assert log_probs.shape == x0.shape + return log_probs diff --git a/tools/modules/diffusions/schedules.py b/tools/modules/diffusions/schedules.py new file mode 100644 index 0000000..4e15870 --- /dev/null +++ b/tools/modules/diffusions/schedules.py @@ -0,0 +1,166 @@ +import math +import torch + + +def beta_schedule(schedule='cosine', + num_timesteps=1000, + zero_terminal_snr=False, + **kwargs): + # compute betas + betas = { + # 'logsnr_cosine_interp': logsnr_cosine_interp_schedule, + 'linear': linear_schedule, + 'linear_sd': linear_sd_schedule, + 'quadratic': quadratic_schedule, + 'cosine': cosine_schedule + }[schedule](num_timesteps, **kwargs) + + if zero_terminal_snr and abs(betas.max() - 1.0) > 0.0001: + betas = rescale_zero_terminal_snr(betas) + + return betas + + +def sigma_schedule(schedule='cosine', + num_timesteps=1000, + zero_terminal_snr=False, + **kwargs): + # compute betas + betas = { + 'logsnr_cosine_interp': logsnr_cosine_interp_schedule, + 'linear': linear_schedule, + 'linear_sd': linear_sd_schedule, + 'quadratic': quadratic_schedule, + 'cosine': cosine_schedule + }[schedule](num_timesteps, **kwargs) + if schedule == 'logsnr_cosine_interp': + sigma = betas + else: + sigma = betas_to_sigmas(betas) + if zero_terminal_snr and abs(sigma.max() - 1.0) > 0.0001: + sigma = rescale_zero_terminal_snr(sigma) + + return sigma + + +def linear_schedule(num_timesteps, init_beta, last_beta, **kwargs): + scale = 1000.0 / num_timesteps + init_beta = init_beta or scale * 0.0001 + ast_beta = last_beta or scale * 0.02 + return torch.linspace(init_beta, last_beta, num_timesteps, dtype=torch.float64) + +def logsnr_cosine_interp_schedule( + num_timesteps, + scale_min=2, + scale_max=4, + logsnr_min=-15, + logsnr_max=15, + **kwargs): + return logsnrs_to_sigmas( + _logsnr_cosine_interp(num_timesteps, logsnr_min, logsnr_max, scale_min, scale_max)) + +def linear_sd_schedule(num_timesteps, init_beta, last_beta, **kwargs): + return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2 + + +def quadratic_schedule(num_timesteps, init_beta, last_beta, **kwargs): + init_beta = init_beta or 0.0015 + last_beta = last_beta or 0.0195 + return torch.linspace(init_beta ** 0.5, last_beta ** 0.5, num_timesteps, dtype=torch.float64) ** 2 + + +def cosine_schedule(num_timesteps, cosine_s=0.008, **kwargs): + betas = [] + for step in range(num_timesteps): + t1 = step / num_timesteps + t2 = (step + 1) / num_timesteps + fn = lambda u: math.cos((u + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2 + betas.append(min(1.0 - fn(t2) / fn(t1), 0.999)) + return torch.tensor(betas, dtype=torch.float64) + + +# def cosine_schedule(n, cosine_s=0.008, **kwargs): +# ramp = torch.linspace(0, 1, n + 1) +# square_alphas = torch.cos((ramp + cosine_s) / (1 + cosine_s) * torch.pi / 2) ** 2 +# betas = (1 - 
square_alphas[1:] / square_alphas[:-1]).clamp(max=0.999) +# return betas_to_sigmas(betas) + + +def betas_to_sigmas(betas): + return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0)) + + +def sigmas_to_betas(sigmas): + square_alphas = 1 - sigmas**2 + betas = 1 - torch.cat( + [square_alphas[:1], square_alphas[1:] / square_alphas[:-1]]) + return betas + + + +def sigmas_to_logsnrs(sigmas): + square_sigmas = sigmas**2 + return torch.log(square_sigmas / (1 - square_sigmas)) + + +def _logsnr_cosine(n, logsnr_min=-15, logsnr_max=15): + t_min = math.atan(math.exp(-0.5 * logsnr_min)) + t_max = math.atan(math.exp(-0.5 * logsnr_max)) + t = torch.linspace(1, 0, n) + logsnrs = -2 * torch.log(torch.tan(t_min + t * (t_max - t_min))) + return logsnrs + + +def _logsnr_cosine_shifted(n, logsnr_min=-15, logsnr_max=15, scale=2): + logsnrs = _logsnr_cosine(n, logsnr_min, logsnr_max) + logsnrs += 2 * math.log(1 / scale) + return logsnrs + +def karras_schedule(n, sigma_min=0.002, sigma_max=80.0, rho=7.0): + ramp = torch.linspace(1, 0, n) + min_inv_rho = sigma_min**(1 / rho) + max_inv_rho = sigma_max**(1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho))**rho + sigmas = torch.sqrt(sigmas**2 / (1 + sigmas**2)) + return sigmas + +def _logsnr_cosine_interp(n, + logsnr_min=-15, + logsnr_max=15, + scale_min=2, + scale_max=4): + t = torch.linspace(1, 0, n) + logsnrs_min = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_min) + logsnrs_max = _logsnr_cosine_shifted(n, logsnr_min, logsnr_max, scale_max) + logsnrs = t * logsnrs_min + (1 - t) * logsnrs_max + return logsnrs + + +def logsnrs_to_sigmas(logsnrs): + return torch.sqrt(torch.sigmoid(-logsnrs)) + + +def rescale_zero_terminal_snr(betas): + """ + Rescale Schedule to Zero Terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1 - betas + alphas_bar = alphas.cumprod(0) + alphas_bar_sqrt = alphas_bar.sqrt() + + # Store old values. 8 alphas_bar_sqrt_0 = a + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + # Shift so last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + # Scale so first timestep is back to old value. 
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt ** 2 + alphas = alphas_bar[1:] / alphas_bar[:-1] + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + return betas + diff --git a/tools/modules/embedding_manager.py b/tools/modules/embedding_manager.py new file mode 100644 index 0000000..763f3dd --- /dev/null +++ b/tools/modules/embedding_manager.py @@ -0,0 +1,179 @@ +import torch +from torch import nn +import torch.nn.functional as F +import open_clip + +from functools import partial +from ...utils.registry_class import EMBEDMANAGER + +DEFAULT_PLACEHOLDER_TOKEN = ["*"] + +PROGRESSIVE_SCALE = 2000 + +per_img_token_list = [ + 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', +] + +def get_clip_token_for_string(string): + tokens = open_clip.tokenize(string) + + return tokens[0, 1] + +def get_embedding_for_clip_token(embedder, token): + return embedder(token.unsqueeze(0))[0] + + +@EMBEDMANAGER.register_class() +class EmbeddingManager(nn.Module): + def __init__( + self, + embedder, + placeholder_strings=None, + initializer_words=None, + per_image_tokens=False, + num_vectors_per_token=1, + progressive_words=False, + temporal_prompt_length=1, + token_dim=1024, + **kwargs + ): + super().__init__() + + self.string_to_token_dict = {} + + self.string_to_param_dict = nn.ParameterDict() + + self.initial_embeddings = nn.ParameterDict() # These should not be optimized + + self.progressive_words = progressive_words + self.progressive_counter = 0 + + self.max_vectors_per_token = num_vectors_per_token + + get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.model.token_embedding.cpu()) + + if per_image_tokens: + placeholder_strings.extend(per_img_token_list) + + for idx, placeholder_string in enumerate(placeholder_strings): + + token = get_clip_token_for_string(placeholder_string) + + if initializer_words and idx < len(initializer_words): + init_word_token = get_clip_token_for_string(initializer_words[idx]) + + with torch.no_grad(): + init_word_embedding = get_embedding_for_tkn(init_word_token) + + token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True) + self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False) + else: + token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True)) + + self.string_to_token_dict[placeholder_string] = token + self.string_to_param_dict[placeholder_string] = token_params + + + def forward( + self, + tokenized_text, + embedded_text, + ): + b, n, device = *tokenized_text.shape, tokenized_text.device + + for placeholder_string, placeholder_token in self.string_to_token_dict.items(): + + placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) + + if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement + placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) + embedded_text[placeholder_idx] = placeholder_embedding + else: # otherwise, need to insert and keep track of changing indices + if self.progressive_words: + self.progressive_counter += 1 + max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE + else: + max_step_tokens = self.max_vectors_per_token + + 
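+                # cap the number of inserted vectors at what this placeholder's embedding provides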
num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens) + + placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device)) + + if placeholder_rows.nelement() == 0: + continue + + sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True) + sorted_rows = placeholder_rows[sort_idx] + + for idx in range(len(sorted_rows)): + row = sorted_rows[idx] + col = sorted_cols[idx] + + new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n] + new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n] + + embedded_text[row] = new_embed_row + tokenized_text[row] = new_token_row + + return embedded_text + + def forward_with_text_img( + self, + tokenized_text, + embedded_text, + embedded_img, + ): + device = tokenized_text.device + for placeholder_string, placeholder_token in self.string_to_token_dict.items(): + placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) + placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) + embedded_text[placeholder_idx] = embedded_text[placeholder_idx] + embedded_img + placeholder_embedding + return embedded_text + + def forward_with_text( + self, + tokenized_text, + embedded_text + ): + device = tokenized_text.device + for placeholder_string, placeholder_token in self.string_to_token_dict.items(): + placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) + placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) + embedded_text[placeholder_idx] = embedded_text[placeholder_idx] + placeholder_embedding + return embedded_text + + def save(self, ckpt_path): + torch.save({"string_to_token": self.string_to_token_dict, + "string_to_param": self.string_to_param_dict}, ckpt_path) + + def load(self, ckpt_path): + ckpt = torch.load(ckpt_path, map_location='cpu') + + string_to_token = ckpt["string_to_token"] + string_to_param = ckpt["string_to_param"] + for string, token in string_to_token.items(): + self.string_to_token_dict[string] = token + for string, param in string_to_param.items(): + self.string_to_param_dict[string] = param + + def get_embedding_norms_squared(self): + all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim + param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders + + return param_norm_squared + + def embedding_parameters(self): + return self.string_to_param_dict.parameters() + + def embedding_to_coarse_loss(self): + + loss = 0. 
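+        # penalize drift of each optimized placeholder embedding away from its initial (coarse) embedding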
+        num_embeddings = len(self.initial_embeddings)
+
+        for key in self.initial_embeddings:
+            optimized = self.string_to_param_dict[key]
+            coarse = self.initial_embeddings[key].clone().to(optimized.device)
+
+            loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings
+
+        return loss
\ No newline at end of file
diff --git a/tools/modules/unet/__init__.py b/tools/modules/unet/__init__.py
new file mode 100644
index 0000000..3d755e9
--- /dev/null
+++ b/tools/modules/unet/__init__.py
@@ -0,0 +1,2 @@
+from .unet_unianimate import *
+
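For reference, a minimal usage sketch of the EmbeddingManager added above (illustrative only, not part of this patch). It assumes open_clip is installed, that EmbeddingManager is importable under whatever package layout these files land in, and that the wrapper class, model name, and checkpoint path below are placeholders.

    import torch
    import open_clip
    # adjust this import to the package layout these modules are installed under (hypothetical path)
    from tools.modules.embedding_manager import EmbeddingManager

    class TextEmbedder(torch.nn.Module):
        # minimal wrapper exposing .model.token_embedding, which EmbeddingManager.__init__ reads
        def __init__(self, clip_model):
            super().__init__()
            self.model = clip_model

    clip_model, _, _ = open_clip.create_model_and_transforms('ViT-H-14')  # text width 1024 matches the default token_dim
    manager = EmbeddingManager(TextEmbedder(clip_model),
                               placeholder_strings=['*'],
                               initializer_words=['person'])
    manager.save('placeholder_embeddings.pt')  # hypothetical path
    manager.load('placeholder_embeddings.pt')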
diff --git a/tools/modules/unet/mha_flash.py b/tools/modules/unet/mha_flash.py
new file mode 100644
index 0000000..5edfe0e
--- /dev/null
+++ b/tools/modules/unet/mha_flash.py
@@ -0,0 +1,103 @@
+import torch
+import torch.nn as nn
+import torch.cuda.amp as amp
+import torch.nn.functional as F
+import math
+import os
+import time
+import numpy as np
+import random
+
+from flash_attn.flash_attention import FlashAttention  # flash-attn (v1.x API) is required by FlashAttentionBlock below
+class FlashAttentionBlock(nn.Module):
+
+    def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None, batch_size=4):
+        # consider head_dim first, then num_heads
+        num_heads = dim // head_dim if head_dim else num_heads
+        head_dim = dim // num_heads
+        assert num_heads * head_dim == dim
+        super(FlashAttentionBlock, self).__init__()
+        self.dim = dim
+        self.context_dim = context_dim
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.scale = math.pow(head_dim, -0.25)
+
+        # layers
+        self.norm = nn.GroupNorm(32, dim)
+        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
+        if context_dim is not None:
+            self.context_kv = nn.Linear(context_dim, dim * 2)
+        self.proj = nn.Conv2d(dim, dim, 1)
+
+        if self.head_dim <= 128 and (self.head_dim % 8) == 0:
+            new_scale = math.pow(head_dim, -0.5)
+            self.flash_attn = FlashAttention(softmax_scale=None, attention_dropout=0.0)
+
+        # zero out the last layer params
+        nn.init.zeros_(self.proj.weight)
+        # self.apply(self._init_weight)
+
+
+    def _init_weight(self, module):
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=0.15)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Conv2d):
+            module.weight.data.normal_(mean=0.0, std=0.15)
+            if module.bias is not None:
+                module.bias.data.zero_()
+
+    def forward(self, x, context=None):
+        r"""x:       [B, C, H, W].
+            context: [B, L, C] or None.
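+            Prepends projected context keys/values to k/v and zero-pads q with 4 extra slots before flash attention.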
+ """ + identity = x + b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) + if context is not None: + ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1) + k = torch.cat([ck, k], dim=-1) + v = torch.cat([cv, v], dim=-1) + cq = torch.zeros([b, n, d, 4], dtype=q.dtype, device=q.device) + q = torch.cat([q, cq], dim=-1) + + qkv = torch.cat([q,k,v], dim=1) + origin_dtype = qkv.dtype + qkv = qkv.permute(0, 3, 1, 2).reshape(b, -1, 3, n, d).half().contiguous() + out, _ = self.flash_attn(qkv) + out.to(origin_dtype) + + if context is not None: + out = out[:, :-4, :, :] + out = out.permute(0, 2, 3, 1).reshape(b, c, h, w) + + # output + x = self.proj(out) + return x + identity + +if __name__ == '__main__': + batch_size = 8 + flash_net = FlashAttentionBlock(dim=1280, context_dim=512, num_heads=None, head_dim=64, batch_size=batch_size).cuda() + + x = torch.randn([batch_size, 1280, 32, 32], dtype=torch.float32).cuda() + context = torch.randn([batch_size, 4, 512], dtype=torch.float32).cuda() + # context = None + flash_net.eval() + + with amp.autocast(enabled=True): + # warm up + for i in range(5): + y = flash_net(x, context) + torch.cuda.synchronize() + s1 = time.time() + for i in range(10): + y = flash_net(x, context) + torch.cuda.synchronize() + s2 = time.time() + + print(f'Average cost time {(s2-s1)*1000/10} ms') \ No newline at end of file diff --git a/tools/modules/unet/unet_unianimate.py b/tools/modules/unet/unet_unianimate.py new file mode 100644 index 0000000..097b52f --- /dev/null +++ b/tools/modules/unet/unet_unianimate.py @@ -0,0 +1,659 @@ +import math +import torch +# import xformers +# import xformers.ops +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F +from ....lib.rotary_embedding_torch import RotaryEmbedding +from fairscale.nn.checkpoint import checkpoint_wrapper + +from .util import * +# from .mha_flash import FlashAttentionBlock +from ....utils.registry_class import MODEL + + +USE_TEMPORAL_TRANSFORMER = True + + + +class PreNormattention(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + x + +class PreNormattention_qkv(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, q, k, v, **kwargs): + return self.fn(self.norm(q), self.norm(k), self.norm(v), **kwargs) + q + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + b, n, _, h = *x.shape, self.heads + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + + attn = self.attend(dots) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class Attention_qkv(nn.Module): + 
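# Editorial sketch (not part of the patch): the Attention block above and the
# Attention_qkv block below both implement standard multi-head attention with einsum.
# A minimal, self-contained version of the same computation on dummy tensors, assuming
# only torch and einops; the sizes are illustrative, not taken from a config:
import torch
from torch import einsum
from einops import rearrange

b, n, heads, dim_head = 2, 16, 8, 64
x = torch.randn(b, n, heads * dim_head)
to_qkv = torch.nn.Linear(heads * dim_head, heads * dim_head * 3, bias=False)
q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h=heads) for t in to_qkv(x).chunk(3, dim=-1))
sim = einsum('b h i d, b h j d -> b h i j', q, k) * dim_head ** -0.5   # scaled dot-product scores
attn = sim.softmax(dim=-1)                                             # attention weights over keys
out = rearrange(einsum('b h i j, b h j d -> b h i d', attn, v), 'b h n d -> b n (h d)')
assert out.shape == (b, n, heads * dim_head)                           # same shape as the input tokens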
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.to_q = nn.Linear(dim, inner_dim, bias = False) + self.to_k = nn.Linear(dim, inner_dim, bias = False) + self.to_v = nn.Linear(dim, inner_dim, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, q, k, v): + b, n, _, h = *q.shape, self.heads + bk = k.shape[0] + + q = self.to_q(q) + k = self.to_k(k) + v = self.to_v(v) + q = rearrange(q, 'b n (h d) -> b h n d', h = h) + k = rearrange(k, 'b n (h d) -> b h n d', b=bk, h = h) + v = rearrange(v, 'b n (h d) -> b h n d', b=bk, h = h) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + + attn = self.attend(dots) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class PostNormattention(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.norm(self.fn(x, **kwargs) + x) + + + + +class Transformer_v2(nn.Module): + def __init__(self, heads=8, dim=2048, dim_head_k=256, dim_head_v=256, dropout_atte = 0.05, mlp_dim=2048, dropout_ffn = 0.05, depth=1): + super().__init__() + self.layers = nn.ModuleList([]) + self.depth = depth + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNormattention(dim, Attention(dim, heads = heads, dim_head = dim_head_k, dropout = dropout_atte)), + FeedForward(dim, mlp_dim, dropout = dropout_ffn), + ])) + def forward(self, x): + for attn, ff in self.layers[:1]: + x = attn(x) + x = ff(x) + x + if self.depth > 1: + for attn, ff in self.layers[1:]: + x = attn(x) + x = ff(x) + x + return x + + +class DropPath(nn.Module): + r"""DropPath but without rescaling and supports optional all-zero and/or all-keep. 
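    Example (editorial illustration, not part of the original docstring):

        >>> import torch
        >>> drop = DropPath(p=0.5)
        >>> x = torch.randn(4, 8)
        >>> _ = drop.eval()
        >>> torch.equal(drop(x), x)   # identity at inference time
        True

    During training, roughly a fraction `p` of the samples in a batch are zeroed out
    (with no 1/(1-p) rescaling); indices passed via `zero` are always dropped and
    indices passed via `keep` are never dropped. UNetSD_UniAnimate below wraps this
    as `misc_dropout` to randomly drop whole conditioning signals per sample.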
+ """ + def __init__(self, p): + super(DropPath, self).__init__() + self.p = p + + def forward(self, *args, zero=None, keep=None): + if not self.training: + return args[0] if len(args) == 1 else args + + # params + x = args[0] + b = x.size(0) + n = (torch.rand(b) < self.p).sum() + + # non-zero and non-keep mask + mask = x.new_ones(b, dtype=torch.bool) + if keep is not None: + mask[keep] = False + if zero is not None: + mask[zero] = False + + # drop-path index + index = torch.where(mask)[0] + index = index[torch.randperm(len(index))[:n]] + if zero is not None: + index = torch.cat([index, torch.where(zero)[0]], dim=0) + + # drop-path multiplier + multiplier = x.new_ones(b) + multiplier[index] = 0.0 + output = tuple(u * self.broadcast(multiplier, u) for u in args) + return output[0] if len(args) == 1 else output + + def broadcast(self, src, dst): + assert src.size(0) == dst.size(0) + shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1) + return src.view(shape) + + + + +@MODEL.register_class() +class UNetSD_UniAnimate(nn.Module): + + def __init__(self, + config=None, + in_dim=4, + dim=512, + y_dim=512, + context_dim=1024, + hist_dim = 156, + concat_dim = 8, + out_dim=6, + dim_mult=[1, 2, 3, 4], + num_heads=None, + head_dim=64, + num_res_blocks=3, + attn_scales=[1 / 2, 1 / 4, 1 / 8], + use_scale_shift_norm=True, + dropout=0.1, + temporal_attn_times=1, + temporal_attention = True, + use_checkpoint=False, + use_image_dataset=False, + use_fps_condition= False, + use_sim_mask = False, + misc_dropout = 0.5, + training=True, + inpainting=True, + p_all_zero=0.1, + p_all_keep=0.1, + zero_y = None, + black_image_feature = None, + adapter_transformer_layers = 1, + num_tokens=4, + **kwargs + ): + embed_dim = dim * 4 + num_heads=num_heads if num_heads else dim//32 + super(UNetSD_UniAnimate, self).__init__() + self.zero_y = zero_y + self.black_image_feature = black_image_feature + self.cfg = config + self.in_dim = in_dim + self.dim = dim + self.y_dim = y_dim + self.context_dim = context_dim + self.num_tokens = num_tokens + self.hist_dim = hist_dim + self.concat_dim = concat_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + + self.num_heads = num_heads + + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.use_scale_shift_norm = use_scale_shift_norm + self.temporal_attn_times = temporal_attn_times + self.temporal_attention = temporal_attention + self.use_checkpoint = use_checkpoint + self.use_image_dataset = use_image_dataset + self.use_fps_condition = use_fps_condition + self.use_sim_mask = use_sim_mask + self.training=training + self.inpainting = inpainting + self.video_compositions = self.cfg.video_compositions + self.misc_dropout = misc_dropout + self.p_all_zero = p_all_zero + self.p_all_keep = p_all_keep + + use_linear_in_temporal = False + transformer_depth = 1 + disabled_sa = False + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + self.resolution = config.resolution + + + # embeddings + self.time_embed = nn.Sequential( + nn.Linear(dim, embed_dim), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + if 'image' in self.video_compositions: + self.pre_image_condition = nn.Sequential( + nn.Linear(self.context_dim, self.context_dim), + nn.SiLU(), + nn.Linear(self.context_dim, self.context_dim*self.num_tokens)) + + + if 'local_image' in self.video_compositions: + self.local_image_embedding = nn.Sequential( + 
nn.Conv2d(3, concat_dim * 4, 3, padding=1), + nn.SiLU(), + nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)), + nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1)) + self.local_image_embedding_after = Transformer_v2(heads=2, dim=concat_dim, dim_head_k=concat_dim, dim_head_v=concat_dim, dropout_atte = 0.05, mlp_dim=concat_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers) + + if 'dwpose' in self.video_compositions: + self.dwpose_embedding = nn.Sequential( + nn.Conv2d(3, concat_dim * 4, 3, padding=1), + nn.SiLU(), + nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)), + nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1)) + self.dwpose_embedding_after = Transformer_v2(heads=2, dim=concat_dim, dim_head_k=concat_dim, dim_head_v=concat_dim, dropout_atte = 0.05, mlp_dim=concat_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers) + + if 'randomref_pose' in self.video_compositions: + randomref_dim = 4 + self.randomref_pose2_embedding = nn.Sequential( + nn.Conv2d(3, concat_dim * 4, 3, padding=1), + nn.SiLU(), + nn.AdaptiveAvgPool2d((self.resolution[1]//2, self.resolution[0]//2)), + nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim+randomref_dim, 3, stride=2, padding=1)) + self.randomref_pose2_embedding_after = Transformer_v2(heads=2, dim=concat_dim+randomref_dim, dim_head_k=concat_dim+randomref_dim, dim_head_v=concat_dim+randomref_dim, dropout_atte = 0.05, mlp_dim=concat_dim+randomref_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers) + + if 'randomref' in self.video_compositions: + randomref_dim = 4 + self.randomref_embedding2 = nn.Sequential( + nn.Conv2d(randomref_dim, concat_dim * 4, 3, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=1, padding=1), + nn.SiLU(), + nn.Conv2d(concat_dim * 4, concat_dim+randomref_dim, 3, stride=1, padding=1)) + self.randomref_embedding_after2 = Transformer_v2(heads=2, dim=concat_dim+randomref_dim, dim_head_k=concat_dim+randomref_dim, dim_head_v=concat_dim+randomref_dim, dropout_atte = 0.05, mlp_dim=concat_dim+randomref_dim, dropout_ffn = 0.05, depth=adapter_transformer_layers) + + ### Condition Dropout + self.misc_dropout = DropPath(misc_dropout) + + + if temporal_attention and not USE_TEMPORAL_TRANSFORMER: + self.rotary_emb = RotaryEmbedding(min(32, head_dim)) + self.time_rel_pos_bias = RelativePositionBias(heads = num_heads, max_distance = 32) # realistically will not be able to generate that many frames of video... yet + + if self.use_fps_condition: + self.fps_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + nn.init.zeros_(self.fps_embedding[-1].weight) + nn.init.zeros_(self.fps_embedding[-1].bias) + + # encoder + self.input_blocks = nn.ModuleList() + self.pre_image = nn.Sequential() + init_block = nn.ModuleList([nn.Conv2d(self.in_dim + concat_dim, dim, 3, padding=1)]) + + #### need an initial temporal attention? 
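# Editorial sketch (not part of the patch): the local_image / dwpose branches above map a
# 3-channel map down to concat_dim-channel features at 1/8 of the training resolution, i.e.
# the spatial size of the latents this UNet operates on, before a small Transformer_v2
# refines them along the frame axis; the randomref_pose branch uses the same stack with
# concat_dim + 4 output channels. A standalone shape check of the shared conv recipe
# (the 768x512 size here is an assumption for illustration, not read from a config):
import torch
import torch.nn as nn

concat_dim, H, W = 8, 768, 512
embed = nn.Sequential(
    nn.Conv2d(3, concat_dim * 4, 3, padding=1),
    nn.SiLU(),
    nn.AdaptiveAvgPool2d((H // 2, W // 2)),                              # -> 1/2 resolution
    nn.Conv2d(concat_dim * 4, concat_dim * 4, 3, stride=2, padding=1),   # -> 1/4 resolution
    nn.SiLU(),
    nn.Conv2d(concat_dim * 4, concat_dim, 3, stride=2, padding=1))       # -> 1/8 resolution
print(embed(torch.randn(1, 3, H, W)).shape)   # torch.Size([1, 8, 96, 64])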
+ if temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + init_block.append(TemporalTransformer(dim, num_heads, head_dim, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset)) + else: + init_block.append(TemporalAttentionMultiBlock(dim, num_heads, head_dim, rotary_emb=self.rotary_emb, temporal_attn_times=temporal_attn_times, use_image_dataset=use_image_dataset)) + + self.input_blocks.append(init_block) + shortcut_dims.append(dim) + for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + + block = nn.ModuleList([ResBlock(in_dim, embed_dim, dropout, out_channels=out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,)]) + + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim, + disable_self_attn=False, use_linear=True + ) + ) + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + block.append(TemporalTransformer(out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, multiply_zero=use_image_dataset)) + else: + block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times)) + in_dim = out_dim + self.input_blocks.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + downsample = Downsample( + out_dim, True, dims=2, out_channels=out_dim + ) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.input_blocks.append(downsample) + + # middle + self.middle_block = nn.ModuleList([ + ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False, use_image_dataset=use_image_dataset,), + SpatialTransformer( + out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=self.context_dim, + disable_self_attn=False, use_linear=True + )]) + + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + self.middle_block.append( + TemporalTransformer( + out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset, + ) + ) + else: + self.middle_block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times)) + + self.middle_block.append(ResBlock(out_dim, embed_dim, dropout, use_scale_shift_norm=False)) + + + # decoder + self.output_blocks = nn.ModuleList() + for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + + block = nn.ModuleList([ResBlock(in_dim + shortcut_dims.pop(), embed_dim, dropout, out_dim, use_scale_shift_norm=False, use_image_dataset=use_image_dataset, )]) + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, out_dim // head_dim, head_dim, depth=1, context_dim=1024, + disable_self_attn=False, use_linear=True + ) + ) + if self.temporal_attention: + if USE_TEMPORAL_TRANSFORMER: + block.append( + TemporalTransformer( + out_dim, out_dim // head_dim, head_dim, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_temporal, 
multiply_zero=use_image_dataset + ) + ) + else: + block.append(TemporalAttentionMultiBlock(out_dim, num_heads, head_dim, rotary_emb =self.rotary_emb, use_image_dataset=use_image_dataset, use_sim_mask=use_sim_mask, temporal_attn_times=temporal_attn_times)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + upsample = Upsample(out_dim, True, dims=2.0, out_channels=out_dim) + scale *= 2.0 + block.append(upsample) + self.output_blocks.append(block) + + # head + self.out = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.out[-1].weight) + + def forward(self, + x, + t, + y = None, + depth = None, + image = None, + motion = None, + local_image = None, + single_sketch = None, + masked = None, + canny = None, + sketch = None, + dwpose = None, + randomref = None, + histogram = None, + fps = None, + video_mask = None, + focus_present_mask = None, + prob_focus_present = 0., # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time) + mask_last_frame_num = 0 # mask last frame num + ): + + + assert self.inpainting or masked is None, 'inpainting is not supported' + + batch, c, f, h, w= x.shape + frames = f + device = x.device + self.batch = batch + + #### image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored + if mask_last_frame_num > 0: + focus_present_mask = None + video_mask[-mask_last_frame_num:] = False + else: + focus_present_mask = default(focus_present_mask, lambda: prob_mask_like((batch,), prob_focus_present, device = device)) + + if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER: + time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device = x.device) + else: + time_rel_pos_bias = None + + + # all-zero and all-keep masks + zero = torch.zeros(batch, dtype=torch.bool).to(x.device) + keep = torch.zeros(batch, dtype=torch.bool).to(x.device) + if self.training: + nzero = (torch.rand(batch) < self.p_all_zero).sum() + nkeep = (torch.rand(batch) < self.p_all_keep).sum() + index = torch.randperm(batch) + zero[index[0:nzero]] = True + keep[index[nzero:nzero + nkeep]] = True + assert not (zero & keep).any() + misc_dropout = partial(self.misc_dropout, zero = zero, keep = keep) + + + concat = x.new_zeros(batch, self.concat_dim, f, h, w) + + + # local_image_embedding (first frame) + if local_image is not None: + local_image = rearrange(local_image, 'b c f h w -> (b f) c h w') + local_image = self.local_image_embedding(local_image) + + h = local_image.shape[2] + local_image = self.local_image_embedding_after(rearrange(local_image, '(b f) c h w -> (b h w) f c', b = batch)) + local_image = rearrange(local_image, '(b h w) f c -> b c f h w', b = batch, h = h) + + concat = concat + misc_dropout(local_image) + + if dwpose is not None: + if 'randomref_pose' in self.video_compositions: + dwpose_random_ref = dwpose[:,:,:1].clone() + dwpose = dwpose[:,:,1:] + dwpose = rearrange(dwpose, 'b c f h w -> (b f) c h w') + dwpose = self.dwpose_embedding(dwpose) + + h = dwpose.shape[2] + dwpose = self.dwpose_embedding_after(rearrange(dwpose, '(b f) c h w -> (b h w) f c', b = batch)) + dwpose = rearrange(dwpose, '(b h w) f c -> b c f h w', b = batch, h = h) + concat = concat + misc_dropout(dwpose) + + randomref_b = x.new_zeros(batch, self.concat_dim+4, 1, h, w) + if randomref is not None: + randomref = rearrange(randomref[:,:,:1,], 'b c f h w -> 
(b f) c h w') + randomref = self.randomref_embedding2(randomref) + + h = randomref.shape[2] + randomref = self.randomref_embedding_after2(rearrange(randomref, '(b f) c h w -> (b h w) f c', b = batch)) + if 'randomref_pose' in self.video_compositions: + dwpose_random_ref = rearrange(dwpose_random_ref, 'b c f h w -> (b f) c h w') + dwpose_random_ref = self.randomref_pose2_embedding(dwpose_random_ref) + dwpose_random_ref = self.randomref_pose2_embedding_after(rearrange(dwpose_random_ref, '(b f) c h w -> (b h w) f c', b = batch)) + randomref = randomref + dwpose_random_ref + + randomref_a = rearrange(randomref, '(b h w) f c -> b c f h w', b = batch, h = h) + randomref_b = randomref_b + randomref_a + + + x = torch.cat([randomref_b, torch.cat([x, concat], dim=1)], dim=2) + x = rearrange(x, 'b c f h w -> (b f) c h w') + x = self.pre_image(x) + x = rearrange(x, '(b f) c h w -> b c f h w', b = batch) + + # embeddings + if self.use_fps_condition and fps is not None: + e = self.time_embed(sinusoidal_embedding(t, self.dim)) + self.fps_embedding(sinusoidal_embedding(fps, self.dim)) + else: + e = self.time_embed(sinusoidal_embedding(t, self.dim)) + + context = x.new_zeros(batch, 0, self.context_dim) + + + if image is not None: + y_context = self.zero_y.repeat(batch, 1, 1) + context = torch.cat([context, y_context], dim=1) + + image_context = misc_dropout(self.pre_image_condition(image).view(-1, self.num_tokens, self.context_dim)) # torch.cat([y[:,:-1,:], self.pre_image_condition(y[:,-1:,:]) ], dim=1) + context = torch.cat([context, image_context], dim=1) + else: + y_context = self.zero_y.repeat(batch, 1, 1) + context = torch.cat([context, y_context], dim=1) + image_context = torch.zeros_like(self.zero_y.repeat(batch, 1, 1))[:,:self.num_tokens] + context = torch.cat([context, image_context], dim=1) + + # repeat f times for spatial e and context + e = e.repeat_interleave(repeats=f+1, dim=0) + context = context.repeat_interleave(repeats=f+1, dim=0) + + + + ## always in shape (b f) c h w, except for temporal layer + x = rearrange(x, 'b c f h w -> (b f) c h w') + # encoder + xs = [] + for block in self.input_blocks: + x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask) + xs.append(x) + + # middle + for block in self.middle_block: + x = self._forward_single(block, x, e, context, time_rel_pos_bias,focus_present_mask, video_mask) + + # decoder + for block in self.output_blocks: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, e, context, time_rel_pos_bias,focus_present_mask, video_mask, reference=xs[-1] if len(xs) > 0 else None) + + # head + x = self.out(x) + + # reshape back to (b c f h w) + x = rearrange(x, '(b f) c h w -> b c f h w', b = batch) + return x[:,:,1:] + + def _forward_single(self, module, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference=None): + if isinstance(module, ResidualBlock): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = x.contiguous() + x = module(x, e, reference) + elif isinstance(module, ResBlock): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = x.contiguous() + x = module(x, e, self.batch) + elif isinstance(module, SpatialTransformer): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, TemporalTransformer): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch) + x = 
module(x, context) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, CrossAttention): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, MemoryEfficientCrossAttention): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, BasicTransformerBlock): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = module(x, context) + elif isinstance(module, FeedForward): + x = module(x, context) + elif isinstance(module, Upsample): + x = module(x) + elif isinstance(module, Downsample): + x = module(x) + elif isinstance(module, Resample): + x = module(x, reference) + elif isinstance(module, TemporalAttentionBlock): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch) + x = module(x, time_rel_pos_bias, focus_present_mask, video_mask) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, TemporalAttentionMultiBlock): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch) + x = module(x, time_rel_pos_bias, focus_present_mask, video_mask) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, InitTemporalConvBlock): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch) + x = module(x) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, TemporalConvBlock): + module = checkpoint_wrapper(module) if self.use_checkpoint else module + x = rearrange(x, '(b f) c h w -> b c f h w', b = self.batch) + x = module(x) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e, context, time_rel_pos_bias, focus_present_mask, video_mask, reference) + else: + x = module(x) + return x + + + diff --git a/tools/modules/unet/util.py b/tools/modules/unet/util.py new file mode 100644 index 0000000..1d6c6e6 --- /dev/null +++ b/tools/modules/unet/util.py @@ -0,0 +1,1741 @@ +import math +import torch +import xformers +# # import open_clip +# import xformers.ops +import torch.nn as nn +from torch import einsum +from einops import rearrange +from functools import partial +import torch.nn.functional as F +import torch.nn.init as init +from ....lib.rotary_embedding_torch import RotaryEmbedding +from fairscale.nn.checkpoint import checkpoint_wrapper + +# from .mha_flash import FlashAttentionBlock +# from utils.registry_class import MODEL + + +### load all keys started with prefix and replace them with new_prefix +def load_Block(state, prefix, new_prefix=None): + if new_prefix is None: + new_prefix = prefix + + state_dict = {} + state = {key:value for key,value in state.items() if prefix in key} + for key,value in state.items(): + new_key = key.replace(prefix, new_prefix) + state_dict[new_key]=value + return state_dict + + +def load_2d_pretrained_state_dict(state,cfg): + + new_state_dict = {} + + dim = cfg.unet_dim + num_res_blocks = cfg.unet_res_blocks + temporal_attention = cfg.temporal_attention + temporal_conv = cfg.temporal_conv + dim_mult = cfg.unet_dim_mult + attn_scales = cfg.unet_attn_scales + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + 
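# Editorial sketch (not part of the patch): load_Block above keeps only the checkpoint
# entries whose key contains `prefix` and renames that prefix to `new_prefix`; the function
# below uses it to remap a 2D UNet checkpoint onto this UNet's module layout. A toy
# illustration with made-up keys:
state = {'encoder.1.weight': 'w', 'encoder.1.bias': 'b', 'middle.0.weight': 'm'}
print(load_Block(state, prefix='encoder.1', new_prefix='encoder.1.0'))
# {'encoder.1.0.weight': 'w', 'encoder.1.0.bias': 'b'}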
#embeddings + state_dict = load_Block(state,prefix=f'time_embedding') + new_state_dict.update(state_dict) + state_dict = load_Block(state,prefix=f'y_embedding') + new_state_dict.update(state_dict) + state_dict = load_Block(state,prefix=f'context_embedding') + new_state_dict.update(state_dict) + + encoder_idx = 0 + ### init block + state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}',new_prefix=f'encoder.{encoder_idx}.0') + new_state_dict.update(state_dict) + encoder_idx += 1 + + shortcut_dims.append(dim) + for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + idx = 0 + idx_ = 0 + # residual (+attention) blocks + state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}.{idx}',new_prefix=f'encoder.{encoder_idx}.{idx_}') + new_state_dict.update(state_dict) + idx += 1 + idx_ = 2 + + if scale in attn_scales: + # block.append(AttentionBlock(out_dim, context_dim, num_heads, head_dim)) + state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}.{idx}',new_prefix=f'encoder.{encoder_idx}.{idx_}') + new_state_dict.update(state_dict) + # if temporal_attention: + # block.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb)) + in_dim = out_dim + encoder_idx += 1 + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + # downsample = ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 0.5, dropout) + state_dict = load_Block(state,prefix=f'encoder.{encoder_idx}',new_prefix=f'encoder.{encoder_idx}.0') + new_state_dict.update(state_dict) + + shortcut_dims.append(out_dim) + scale /= 2.0 + encoder_idx += 1 + + # middle + # self.middle = nn.ModuleList([ + # ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 'none'), + # TemporalConvBlock(out_dim), + # AttentionBlock(out_dim, context_dim, num_heads, head_dim)]) + # if temporal_attention: + # self.middle.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb)) + # elif temporal_conv: + # self.middle.append(TemporalConvBlock(out_dim,dropout=dropout)) + # self.middle.append(ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 'none')) + # self.middle.append(TemporalConvBlock(out_dim)) + + + # middle + middle_idx = 0 + # self.middle = nn.ModuleList([ + # ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout), + # AttentionBlock(out_dim, context_dim, num_heads, head_dim)]) + state_dict = load_Block(state,prefix=f'middle.{middle_idx}') + new_state_dict.update(state_dict) + middle_idx += 2 + + state_dict = load_Block(state,prefix=f'middle.1',new_prefix=f'middle.{middle_idx}') + new_state_dict.update(state_dict) + middle_idx += 1 + + for _ in range(cfg.temporal_attn_times): + # self.middle.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb)) + middle_idx += 1 + + # self.middle.append(ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout)) + state_dict = load_Block(state,prefix=f'middle.2',new_prefix=f'middle.{middle_idx}') + new_state_dict.update(state_dict) + middle_idx += 2 + + + decoder_idx = 0 + for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + idx = 0 + idx_ = 0 + # residual (+attention) blocks + # block = nn.ModuleList([ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, out_dim, use_scale_shift_norm, 1.0, dropout)]) + state_dict = 
load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}') + new_state_dict.update(state_dict) + idx += 1 + idx_ += 2 + if scale in attn_scales: + # block.append(AttentionBlock(out_dim, context_dim, num_heads, head_dim)) + state_dict = load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}') + new_state_dict.update(state_dict) + idx += 1 + idx_ += 1 + for _ in range(cfg.temporal_attn_times): + # block.append(TemporalAttentionBlock(out_dim, num_heads, head_dim, rotary_emb = self.rotary_emb)) + idx_ +=1 + + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + + # upsample = ResidualBlock(out_dim, embed_dim, out_dim, use_scale_shift_norm, 2.0, dropout) + state_dict = load_Block(state,prefix=f'decoder.{decoder_idx}.{idx}',new_prefix=f'decoder.{decoder_idx}.{idx_}') + new_state_dict.update(state_dict) + idx += 1 + idx_ += 2 + + scale *= 2.0 + # block.append(upsample) + # self.decoder.append(block) + decoder_idx += 1 + + # head + # self.head = nn.Sequential( + # nn.GroupNorm(32, out_dim), + # nn.SiLU(), + # nn.Conv3d(out_dim, self.out_dim, (1,3,3), padding=(0,1,1))) + state_dict = load_Block(state,prefix=f'head') + new_state_dict.update(state_dict) + + return new_state_dict + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, + torch.pow(10000, -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + +def exists(x): + return x is not None + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + +def prob_mask_like(shape, prob, device): + if prob == 1: + return torch.ones(shape, device = device, dtype = torch.bool) + elif prob == 0: + return torch.zeros(shape, device = device, dtype = torch.bool) + else: + mask = torch.zeros(shape, device = device).float().uniform_(0, 1) < prob + ### aviod mask all, which will cause find_unused_parameters error + if mask.all(): + mask[0]=False + return mask + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, max_bs=4096, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.max_bs = max_bs + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + if q.shape[0] > self.max_bs: + q_list = torch.chunk(q, q.shape[0] // self.max_bs, dim=0) + k_list 
= torch.chunk(k, k.shape[0] // self.max_bs, dim=0) + v_list = torch.chunk(v, v.shape[0] // self.max_bs, dim=0) + out_list = [] + for q_1, k_1, v_1 in zip(q_list, k_list, v_list): + out = xformers.ops.memory_efficient_attention( + q_1, k_1, v_1, attn_bias=None, op=self.attention_op) + out_list.append(out) + out = torch.cat(out_list, dim=0) + else: + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads = 8, + num_buckets = 32, + max_distance = 128 + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets = 32, max_distance = 128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype = torch.long, device = device) + k_pos = torch.arange(n, dtype = torch.long, device = device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. 
+ Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +class SpatialTransformerWithAdapter(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. 
+ Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True, + adapter_list=[], adapter_position_list=['', 'parallel', ''], + adapter_hidden_dim=None): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlockWithAdapter(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, + adapter_list=adapter_list, adapter_position_list=adapter_position_list, + adapter_hidden_dim=adapter_hidden_dim) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + +import os +_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION =="fp32": + with torch.autocast(enabled=False, device_type = 'cuda'): + q, k = q.float(), k.float() + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + else: + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... 
-> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = torch.einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class Adapter(nn.Module): + def __init__(self, in_dim, hidden_dim, condition_dim=None): + super().__init__() + self.down_linear = nn.Linear(in_dim, hidden_dim) + self.up_linear = nn.Linear(hidden_dim, in_dim) + self.condition_dim = condition_dim + if condition_dim is not None: + self.condition_linear = nn.Linear(condition_dim, in_dim) + + init.zeros_(self.up_linear.weight) + init.zeros_(self.up_linear.bias) + + def forward(self, x, condition=None, condition_lam=1): + x_in = x + if self.condition_dim is not None and condition is not None: + x = x + condition_lam * self.condition_linear(condition) + x = self.down_linear(x) + x = F.gelu(x) + x = self.up_linear(x) + x += x_in + return x + + +class MemoryEfficientCrossAttention_attemask(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=xformers.ops.LowerTriangularMask(), op=self.attention_op) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + + +class BasicTransformerBlock_attemask(nn.Module): + # ATTENTION_MODES = { + # "softmax": CrossAttention, # vanilla attention + # "softmax-xformers": MemoryEfficientCrossAttention + # } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + # attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + # assert attn_mode in self.ATTENTION_MODES + # attn_cls = CrossAttention + attn_cls = MemoryEfficientCrossAttention_attemask + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, 
dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward_(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class BasicTransformerBlockWithAdapter(nn.Module): + # ATTENTION_MODES = { + # "softmax": CrossAttention, # vanilla attention + # "softmax-xformers": MemoryEfficientCrossAttention + # } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False, + adapter_list=[], adapter_position_list=['parallel', 'parallel', 'parallel'], adapter_hidden_dim=None, adapter_condition_dim=None + ): + super().__init__() + # attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + # assert attn_mode in self.ATTENTION_MODES + # attn_cls = CrossAttention + attn_cls = MemoryEfficientCrossAttention + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + # adapter + self.adapter_list = adapter_list + self.adapter_position_list = adapter_position_list + hidden_dim = dim//2 if not adapter_hidden_dim else adapter_hidden_dim + if "self_attention" in adapter_list: + self.attn_adapter = Adapter(dim, hidden_dim, adapter_condition_dim) + if "cross_attention" in adapter_list: + self.cross_attn_adapter = Adapter(dim, hidden_dim, adapter_condition_dim) + if "feedforward" in adapter_list: + self.ff_adapter = Adapter(dim, hidden_dim, adapter_condition_dim) + + + def forward_(self, x, context=None, adapter_condition=None, adapter_condition_lam=1): + return checkpoint(self._forward, (x, context, adapter_condition, adapter_condition_lam), self.parameters(), self.checkpoint) + + def forward(self, x, context=None, adapter_condition=None, adapter_condition_lam=1): + if "self_attention" in self.adapter_list: + if self.adapter_position_list[0] == 'parallel': + # parallel + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + self.attn_adapter(x, adapter_condition, adapter_condition_lam) + elif self.adapter_position_list[0] == 'serial': + # serial + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn_adapter(x, adapter_condition, adapter_condition_lam) + else: + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + + if "cross_attention" in self.adapter_list: + if self.adapter_position_list[1] == 'parallel': + # parallel + x = self.attn2(self.norm2(x), context=context) + self.cross_attn_adapter(x, adapter_condition, adapter_condition_lam) + elif self.adapter_position_list[1] == 'serial': + x = self.attn2(self.norm2(x), context=context) + x + x = self.cross_attn_adapter(x, 
adapter_condition, adapter_condition_lam) + else: + x = self.attn2(self.norm2(x), context=context) + x + + if "feedforward" in self.adapter_list: + if self.adapter_position_list[2] == 'parallel': + x = self.ff(self.norm3(x)) + self.ff_adapter(x, adapter_condition, adapter_condition_lam) + elif self.adapter_position_list[2] == 'serial': + x = self.ff(self.norm3(x)) + x + x = self.ff_adapter(x, adapter_condition, adapter_condition_lam) + else: + x = self.ff(self.norm3(x)) + x + + return x + +class BasicTransformerBlock(nn.Module): + # ATTENTION_MODES = { + # "softmax": CrossAttention, # vanilla attention + # "softmax-xformers": MemoryEfficientCrossAttention + # } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + # attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + # assert attn_mode in self.ATTENTION_MODES + # attn_cls = CrossAttention + attn_cls = MemoryEfficientCrossAttention + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward_(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
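    Example (editorial illustration, values assumed):

        >>> import torch
        >>> up = Upsample(channels=64, use_conv=True, dims=2)
        >>> up(torch.randn(2, 64, 16, 16)).shape   # nearest-neighbour x2 upsample, then 3x3 conv
        torch.Size([2, 64, 32, 32])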
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class UpsampleSR600(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + # TODO: to match input_blocks, remove elements of two sides + x = x[..., 1:-1, :] + if self.use_conv: + x = self.conv(x) + return x + + +class ResBlock(nn.Module): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
+ """ + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + use_temporal_conv=True, + use_image_dataset=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + self.use_temporal_conv = use_temporal_conv + + self.in_layers = nn.Sequential( + nn.GroupNorm(32, channels), + nn.SiLU(), + nn.Conv2d(channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + nn.GroupNorm(32, self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = nn.Conv2d(channels, self.out_channels, 1) + + if self.use_temporal_conv: + self.temopral_conv = TemporalConvBlock_v2(self.out_channels, self.out_channels, dropout=0.1, use_image_dataset=use_image_dataset) + # self.temopral_conv_2 = TemporalConvBlock(self.out_channels, self.out_channels, dropout=0.1, use_image_dataset=use_image_dataset) + + def forward(self, x, emb, batch_size): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return self._forward(x, emb, batch_size) + + def _forward(self, x, emb, batch_size): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + h = self.skip_connection(x) + h + + if self.use_temporal_conv: + h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size) + h = self.temopral_conv(h) + # h = self.temopral_conv_2(h) + h = rearrange(h, 'b c f h w -> (b f) c h w') + return h + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + +class Resample(nn.Module): + + def __init__(self, in_dim, out_dim, mode): + assert mode in ['none', 'upsample', 'downsample'] + super(Resample, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.mode = mode + + def forward(self, x, reference=None): + if self.mode == 'upsample': + assert reference is not None + x = F.interpolate(x, size=reference.shape[-2:], mode='nearest') + elif self.mode == 'downsample': + x = F.adaptive_avg_pool2d(x, output_size=tuple(u // 2 for u in x.shape[-2:])) + return x + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, embed_dim, out_dim, use_scale_shift_norm=True, + mode='none', dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.use_scale_shift_norm = use_scale_shift_norm + self.mode = mode + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.resample = Resample(in_dim, in_dim, mode) + self.embedding = nn.Sequential( + nn.SiLU(), + nn.Linear(embed_dim, out_dim * 2 if use_scale_shift_norm else out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d(in_dim, out_dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, e, reference=None): + identity = self.resample(x, reference) + x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference)) + e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) + if self.use_scale_shift_norm: + scale, shift = e.chunk(2, dim=1) + x = self.layer2[0](x) * (1 + scale) + shift + x = self.layer2[1:](x) + else: + x = x + e + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + +class AttentionBlock(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): + # consider head_dim first, then num_heads + num_heads = dim // head_dim if head_dim else num_heads + head_dim = dim // num_heads + assert num_heads * head_dim == dim + super(AttentionBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = math.pow(head_dim, -0.25) + + # layers + self.norm = nn.GroupNorm(32, dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + if context_dim is not None: + self.context_kv = nn.Linear(context_dim, dim * 2) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x, context=None): + r"""x: [B, C, H, W]. + context: [B, L, C] or None. 
+ """ + identity = x + b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) + if context is not None: + ck, cv = self.context_kv(context).reshape(b, -1, n * 2, d).permute(0, 2, 3, 1).chunk(2, dim=1) + k = torch.cat([ck, k], dim=-1) + v = torch.cat([cv, v], dim=-1) + + # compute attention + attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.matmul(v, attn.transpose(-1, -2)) + x = x.reshape(b, c, h, w) + + # output + x = self.proj(x) + return x + identity + + +class TemporalAttentionBlock(nn.Module): + def __init__( + self, + dim, + heads = 4, + dim_head = 32, + rotary_emb = None, + use_image_dataset = False, + use_sim_mask = False + ): + super().__init__() + # consider num_heads first, as pos_bias needs fixed num_heads + # heads = dim // dim_head if dim_head else heads + dim_head = dim // heads + assert heads * dim_head == dim + self.use_image_dataset = use_image_dataset + self.use_sim_mask = use_sim_mask + + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + + self.norm = nn.GroupNorm(32, dim) + self.rotary_emb = rotary_emb + self.to_qkv = nn.Linear(dim, hidden_dim * 3)#, bias = False) + self.to_out = nn.Linear(hidden_dim, dim)#, bias = False) + + # nn.init.zeros_(self.to_out.weight) + # nn.init.zeros_(self.to_out.bias) + + def forward( + self, + x, + pos_bias = None, + focus_present_mask = None, + video_mask = None + ): + + identity = x + n, height, device = x.shape[2], x.shape[-2], x.device + + x = self.norm(x) + x = rearrange(x, 'b c f h w -> b (h w) f c') + + qkv = self.to_qkv(x).chunk(3, dim = -1) + + if exists(focus_present_mask) and focus_present_mask.all(): + # if all batch samples are focusing on present + # it would be equivalent to passing that token's values (v=qkv[-1]) through to the output + values = qkv[-1] + out = self.to_out(values) + out = rearrange(out, 'b (h w) f c -> b c f h w', h = height) + + return out + identity + + # split out heads + # q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h = self.heads) + # shape [b (hw) h n c/h], n=f + q= rearrange(qkv[0], '... n (h d) -> ... h n d', h = self.heads) + k= rearrange(qkv[1], '... n (h d) -> ... h n d', h = self.heads) + v= rearrange(qkv[2], '... n (h d) -> ... h n d', h = self.heads) + + + # scale + + q = q * self.scale + + # rotate positions into queries and keys for time attention + if exists(self.rotary_emb): + q = self.rotary_emb.rotate_queries_or_keys(q) + k = self.rotary_emb.rotate_queries_or_keys(k) + + # similarity + # shape [b (hw) h n n], n=f + sim = torch.einsum('... h i d, ... h j d -> ... 
h i j', q, k) + + # relative positional bias + + if exists(pos_bias): + # print(sim.shape,pos_bias.shape) + sim = sim + pos_bias + + if (focus_present_mask is None and video_mask is not None): + #video_mask: [B, n] + mask = video_mask[:, None, :] * video_mask[:, :, None] # [b,n,n] + mask = mask.unsqueeze(1).unsqueeze(1) #[b,1,1,n,n] + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + elif exists(focus_present_mask) and not (~focus_present_mask).all(): + attend_all_mask = torch.ones((n, n), device = device, dtype = torch.bool) + attend_self_mask = torch.eye(n, device = device, dtype = torch.bool) + + mask = torch.where( + rearrange(focus_present_mask, 'b -> b 1 1 1 1'), + rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), + rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), + ) + + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + if self.use_sim_mask: + sim_mask = torch.tril(torch.ones((n, n), device = device, dtype = torch.bool), diagonal=0) + sim = sim.masked_fill(~sim_mask, -torch.finfo(sim.dtype).max) + + # numerical stability + sim = sim - sim.amax(dim = -1, keepdim = True).detach() + attn = sim.softmax(dim = -1) + + # aggregate values + + out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v) + out = rearrange(out, '... h n d -> ... n (h d)') + out = self.to_out(out) + + out = rearrange(out, 'b (h w) f c -> b c f h w', h = height) + + if self.use_image_dataset: + out = identity + 0*out + else: + out = identity + out + return out + +class TemporalTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True, only_self_att=True, multiply_zero=False): + super().__init__() + self.multiply_zero = multiply_zero + self.only_self_att = only_self_att + self.use_adaptor = False + if self.only_self_att: + context_dim = None + if not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv1d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + if self.use_adaptor: + self.adaptor_in = nn.Linear(frames, frames) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv1d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + if self.use_adaptor: + self.adaptor_out = nn.Linear(frames, frames) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if self.only_self_att: + context = None + if not isinstance(context, list): + context = [context] + b, c, f, h, w = x.shape + x_in = x + x = self.norm(x) + + if not self.use_linear: + x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous() + x = self.proj_in(x) + # [16384, 16, 320] + if self.use_linear: + x = rearrange(x, '(b f) c h w -> b (h w) f c', 
f=self.frames).contiguous() + x = self.proj_in(x) + + if self.only_self_att: + x = rearrange(x, 'bhw c f -> bhw f c').contiguous() + for i, block in enumerate(self.transformer_blocks): + x = block(x) + x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous() + else: + x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous() + for i, block in enumerate(self.transformer_blocks): + # context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous() + # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) + for j in range(b): + context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + x[j] = block(x[j], context=context_i_j) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous() + if not self.use_linear: + # x = rearrange(x, 'bhw f c -> bhw c f').contiguous() + x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous() + x = self.proj_out(x) + x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous() + + if self.multiply_zero: + x = 0.0 * x + x_in + else: + x = x + x_in + return x + + +class TemporalTransformerWithAdapter(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True, only_self_att=True, multiply_zero=False, + adapter_list=[], adapter_position_list=['parallel', 'parallel', 'parallel'], + adapter_hidden_dim=None, adapter_condition_dim=None): + super().__init__() + self.multiply_zero = multiply_zero + self.only_self_att = only_self_att + self.use_adaptor = False + if self.only_self_att: + context_dim = None + if not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv1d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + if self.use_adaptor: + self.adaptor_in = nn.Linear(frames, frames) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlockWithAdapter(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + checkpoint=use_checkpoint, adapter_list=adapter_list, adapter_position_list=adapter_position_list, + adapter_hidden_dim=adapter_hidden_dim, adapter_condition_dim=adapter_condition_dim) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv1d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + if self.use_adaptor: + self.adaptor_out = nn.Linear(frames, frames) + self.use_linear = use_linear + + def forward(self, x, context=None, adapter_condition=None, adapter_condition_lam=1): + # note: if no context is given, cross-attention defaults to self-attention + if self.only_self_att: + context = None + if not isinstance(context, list): + context = [context] + b, c, f, h, w = x.shape + x_in = x + x = self.norm(x) + + if not 
self.use_linear: + x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous() + x = self.proj_in(x) + # [16384, 16, 320] + if self.use_linear: + x = rearrange(x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous() + x = self.proj_in(x) + + if adapter_condition is not None: + b_cond, f_cond, c_cond = adapter_condition.shape + adapter_condition = adapter_condition.unsqueeze(1).unsqueeze(1).repeat(1, h, w, 1, 1) + adapter_condition = adapter_condition.reshape(b_cond*h*w, f_cond, c_cond) + + if self.only_self_att: + x = rearrange(x, 'bhw c f -> bhw f c').contiguous() + for i, block in enumerate(self.transformer_blocks): + x = block(x, adapter_condition=adapter_condition, adapter_condition_lam=adapter_condition_lam) + x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous() + else: + x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous() + for i, block in enumerate(self.transformer_blocks): + # context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous() + # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) + for j in range(b): + context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + x[j] = block(x[j], context=context_i_j) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous() + if not self.use_linear: + # x = rearrange(x, 'bhw f c -> bhw c f').contiguous() + x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous() + x = self.proj_out(x) + x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous() + + if self.multiply_zero: + x = 0.0 * x + x_in + else: + x = x + x_in + return x + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x): + b, n, _, h = *x.shape, self.heads + qkv = self.to_qkv(x).chunk(3, dim = -1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) + + dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + + attn = self.attend(dots) + + out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class PreNormattention(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + x + +class TransformerV2(nn.Module): + def __init__(self, heads=8, dim=2048, dim_head_k=256, dim_head_v=256, dropout_atte = 0.05, mlp_dim=2048, dropout_ffn = 0.05, depth=1): + super().__init__() + self.layers = nn.ModuleList([]) + self.depth = depth + for _ in range(depth): + self.layers.append(nn.ModuleList([ + PreNormattention(dim, Attention(dim, heads = heads, dim_head = dim_head_k, dropout = dropout_atte)), + FeedForward(dim, mlp_dim, dropout = dropout_ffn), + ])) + def forward(self, x): + # if self.depth + for attn, ff in self.layers[:1]: + x = attn(x) + x = ff(x) + x + if self.depth 
> 1: + for attn, ff in self.layers[1:]: + x = attn(x) + x = ff(x) + x + return x + +class TemporalTransformer_attemask(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True, only_self_att=True, multiply_zero=False): + super().__init__() + self.multiply_zero = multiply_zero + self.only_self_att = only_self_att + self.use_adaptor = False + if self.only_self_att: + context_dim = None + if not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv1d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + if self.use_adaptor: + self.adaptor_in = nn.Linear(frames, frames) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock_attemask(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv1d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + if self.use_adaptor: + self.adaptor_out = nn.Linear(frames, frames) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if self.only_self_att: + context = None + if not isinstance(context, list): + context = [context] + b, c, f, h, w = x.shape + x_in = x + x = self.norm(x) + + if not self.use_linear: + x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous() + x = self.proj_in(x) + # [16384, 16, 320] + if self.use_linear: + x = rearrange(x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous() + x = self.proj_in(x) + + if self.only_self_att: + x = rearrange(x, 'bhw c f -> bhw f c').contiguous() + for i, block in enumerate(self.transformer_blocks): + x = block(x) + x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous() + else: + x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous() + for i, block in enumerate(self.transformer_blocks): + # context[i] = repeat(context[i], '(b f) l con -> b (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + context[i] = rearrange(context[i], '(b f) l con -> b f l con', f=self.frames).contiguous() + # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) + for j in range(b): + context_i_j = repeat(context[i][j], 'f l con -> (f r) l con', r=(h*w)//self.frames, f=self.frames).contiguous() + x[j] = block(x[j], context=context_i_j) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous() + if not self.use_linear: + # x = rearrange(x, 'bhw f c -> bhw c f').contiguous() + x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous() + x = self.proj_out(x) + x = rearrange(x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous() + + if self.multiply_zero: + x = 0.0 * x + x_in + else: + x = x + x_in + return x + +class TemporalAttentionMultiBlock(nn.Module): + def __init__( + self, + 
dim, + heads=4, + dim_head=32, + rotary_emb=None, + use_image_dataset=False, + use_sim_mask=False, + temporal_attn_times=1, + ): + super().__init__() + self.att_layers = nn.ModuleList( + [TemporalAttentionBlock(dim, heads, dim_head, rotary_emb, use_image_dataset, use_sim_mask) + for _ in range(temporal_attn_times)] + ) + + def forward( + self, + x, + pos_bias = None, + focus_present_mask = None, + video_mask = None + ): + for layer in self.att_layers: + x = layer(x, pos_bias, focus_present_mask, video_mask) + return x + + +class InitTemporalConvBlock(nn.Module): + + def __init__(self, in_dim, out_dim=None, dropout=0.0,use_image_dataset=False): + super(InitTemporalConvBlock, self).__init__() + if out_dim is None: + out_dim = in_dim#int(1.5*in_dim) + self.in_dim = in_dim + self.out_dim = out_dim + self.use_image_dataset = use_image_dataset + + # conv layers + self.conv = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0))) + + # zero out the last layer params,so the conv block is identity + # nn.init.zeros_(self.conv1[-1].weight) + # nn.init.zeros_(self.conv1[-1].bias) + nn.init.zeros_(self.conv[-1].weight) + nn.init.zeros_(self.conv[-1].bias) + + def forward(self, x): + identity = x + x = self.conv(x) + if self.use_image_dataset: + x = identity + 0*x + else: + x = identity + x + return x + +class TemporalConvBlock(nn.Module): + + def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset= False): + super(TemporalConvBlock, self).__init__() + if out_dim is None: + out_dim = in_dim#int(1.5*in_dim) + self.in_dim = in_dim + self.out_dim = out_dim + self.use_image_dataset = use_image_dataset + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding = (1, 0, 0))) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0))) + + # zero out the last layer params,so the conv block is identity + # nn.init.zeros_(self.conv1[-1].weight) + # nn.init.zeros_(self.conv1[-1].bias) + nn.init.zeros_(self.conv2[-1].weight) + nn.init.zeros_(self.conv2[-1].bias) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.conv2(x) + if self.use_image_dataset: + x = identity + 0*x + else: + x = identity + x + return x + +class TemporalConvBlock_v2(nn.Module): + def __init__(self, in_dim, out_dim=None, dropout=0.0, use_image_dataset=False): + super(TemporalConvBlock_v2, self).__init__() + if out_dim is None: + out_dim = in_dim # int(1.5*in_dim) + self.in_dim = in_dim + self.out_dim = out_dim + self.use_image_dataset = use_image_dataset + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding = (1, 0, 0))) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0))) + self.conv3 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0))) + self.conv4 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding = (1, 0, 0))) + + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv4[-1].weight) + nn.init.zeros_(self.conv4[-1].bias) + + def forward(self, x): + identity = x + x 
= self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + + if self.use_image_dataset: + x = identity + 0.0 * x + else: + x = identity + x + return x + + +class DropPath(nn.Module): + r"""Per-sample DropPath (stochastic depth) without rescaling; samples listed in `zero` are always dropped and samples listed in `keep` are never randomly dropped. + """ + def __init__(self, p): + super(DropPath, self).__init__() + self.p = p + + def forward(self, *args, zero=None, keep=None): + if not self.training: + return args[0] if len(args) == 1 else args + + # number of samples to drop at random in this batch + x = args[0] + b = x.size(0) + n = (torch.rand(b) < self.p).sum() + + # candidates for random dropping (exclude forced-keep and forced-zero samples) + mask = x.new_ones(b, dtype=torch.bool) + if keep is not None: + mask[keep] = False + if zero is not None: + mask[zero] = False + + # indices of samples to zero out + index = torch.where(mask)[0] + index = index[torch.randperm(len(index))[:n]] + if zero is not None: + index = torch.cat([index, torch.where(zero)[0]], dim=0) + + # per-sample multiplier: 0 for dropped samples, 1 otherwise + multiplier = x.new_ones(b) + multiplier[index] = 0.0 + output = tuple(u * self.broadcast(multiplier, u) for u in args) + return output[0] if len(args) == 1 else output + + def broadcast(self, src, dst): + assert src.size(0) == dst.size(0) + shape = (dst.size(0), ) + (1, ) * (dst.ndim - 1) + return src.view(shape) + + +
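
For orientation, here is a small self-contained sketch of the shape behaviour implemented by the Upsample and Downsample layers above: nearest-neighbour 2x interpolation on the spatial axes going up, and a stride-2 convolution or average pool going down. The tensor sizes are illustrative only and the snippet assumes nothing beyond an installed torch.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical feature map: batch 2, 320 channels, 32x32 spatial grid.
x = torch.randn(2, 320, 32, 32)

# Upsample path (dims=2): nearest-neighbour interpolation doubles H and W.
up = F.interpolate(x, scale_factor=2, mode="nearest")
print(up.shape)   # torch.Size([2, 320, 64, 64])

# For dims=3 the frame axis is kept fixed and only H and W are doubled.
v = torch.randn(2, 320, 16, 32, 32)   # b c f h w
up3 = F.interpolate(v, (v.shape[2], v.shape[3] * 2, v.shape[4] * 2), mode="nearest")
print(up3.shape)  # torch.Size([2, 320, 16, 64, 64])

# Downsample with use_conv=True: a 3x3 convolution with stride 2 halves H and W.
down_conv = nn.Conv2d(320, 320, 3, stride=2, padding=1)
print(down_conv(x).shape)  # torch.Size([2, 320, 16, 16])

# Downsample with use_conv=False: average pooling with kernel = stride = 2.
print(F.avg_pool2d(x, kernel_size=2, stride=2).shape)  # torch.Size([2, 320, 16, 16])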
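
The use_scale_shift_norm branch in ResBlock and ResidualBlock is a FiLM-style modulation: the timestep embedding is projected to a per-channel scale and shift that are applied right after group normalisation. A minimal sketch of just that step, with made-up channel sizes:

import torch
import torch.nn as nn

out_channels, emb_channels = 320, 1280          # illustrative sizes only
h = torch.randn(4, out_channels, 32, 32)        # features after in_layers
emb = torch.randn(4, emb_channels)              # timestep embedding

norm = nn.GroupNorm(32, out_channels)
emb_proj = nn.Sequential(nn.SiLU(), nn.Linear(emb_channels, 2 * out_channels))

e = emb_proj(emb)[..., None, None]              # broadcast over H and W
scale, shift = e.chunk(2, dim=1)                # each [4, 320, 1, 1]
h = norm(h) * (1 + scale) + shift               # FiLM-style conditioning
print(h.shape)                                  # torch.Size([4, 320, 32, 32])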
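
The temporal blocks above all rely on the same einops trick: fold the spatial grid into the batch axis and attend only over the frame axis, so every spatial location sees its own short sequence of frames. The sketch below shows that data movement with illustrative sizes; it uses PyTorch 2.x's built-in scaled_dot_product_attention as a stand-in for the hand-written attention, and omits the rotary embeddings, position bias and masking that TemporalAttentionBlock also handles.

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

b, c, f, h, w, heads = 1, 64, 16, 8, 8, 4        # illustrative sizes only
x = torch.randn(b, c, f, h, w)

to_qkv = nn.Linear(c, 3 * c)
to_out = nn.Linear(c, c)

# fold the spatial grid into the batch: each location attends over its own frames
tokens = rearrange(x, 'b c f h w -> (b h w) f c')
q, k, v = to_qkv(tokens).chunk(3, dim=-1)
q, k, v = (rearrange(t, 'n f (h d) -> n h f d', h=heads) for t in (q, k, v))

out = F.scaled_dot_product_attention(q, k, v)    # attention along the frame axis only
out = to_out(rearrange(out, 'n h f d -> n f (h d)'))

# unfold back to the video layout used throughout the UNet
out = rearrange(out, '(b h w) f c -> b c f h w', b=b, h=h, w=w)
print(out.shape)                                 # torch.Size([1, 64, 16, 8, 8])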
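
The TemporalTransformer variants project with nn.Conv1d(..., kernel_size=1) after flattening to (b h w) c f; a kernel-1 Conv1d over the frame axis is simply a per-frame linear map on the channel dimension. The following sketch (illustrative sizes, einops assumed; proj_conv and proj_lin are hypothetical names) makes that equivalence explicit:

import torch
import torch.nn as nn
from einops import rearrange

b, c, f, h, w, inner = 1, 320, 16, 8, 8, 512     # illustrative sizes only
x = torch.randn(b, c, f, h, w)

proj_conv = nn.Conv1d(c, inner, kernel_size=1)
proj_lin = nn.Linear(c, inner)
# share weights so the two formulations can be compared directly
proj_lin.weight.data.copy_(proj_conv.weight.data.squeeze(-1))
proj_lin.bias.data.copy_(proj_conv.bias.data)

seq = rearrange(x, 'b c f h w -> (b h w) c f')   # Conv1d expects (N, C, L)
y_conv = proj_conv(seq)                          # shape: (b h w) inner f

y_lin = proj_lin(rearrange(seq, 'n c f -> n f c'))
y_lin = rearrange(y_lin, 'n f c -> n c f')

print(torch.allclose(y_conv, y_lin, atol=1e-5))  # True: kernel-1 Conv1d == per-frame Linear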
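
Both zero_module(...) in ResBlock and the nn.init.zeros_ calls in the temporal conv blocks use the same idea: the last layer of every residual branch starts at zero, so a freshly constructed block is an exact identity and training can only gradually switch the new path on. The use_image_dataset flag goes one step further and multiplies the residual by zero at run time, disabling the temporal update for image-only data while keeping the module in the graph. A quick check of the identity-at-init property, with illustrative sizes:

import torch
import torch.nn as nn

dim = 64                                         # illustrative channel count
block = nn.Sequential(
    nn.GroupNorm(32, dim),
    nn.SiLU(),
    nn.Conv3d(dim, dim, (3, 1, 1), padding=(1, 0, 0)),
)
nn.init.zeros_(block[-1].weight)                 # zero the last layer, as in the patch
nn.init.zeros_(block[-1].bias)

x = torch.randn(2, dim, 8, 16, 16)               # b c f h w
out = x + block(x)                               # residual form used by the temporal blocks
print(torch.allclose(out, x))                    # True: the block is an identity at init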
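
Finally, DropPath above is a per-sample stochastic-depth gate without the usual 1/(1-p) rescaling: in training mode it zeroes the contribution of roughly p*B samples in the batch, with optional forced-drop (zero) and forced-keep (keep) index sets. A minimal usage sketch, assuming the package added by this patch is importable from the repository root:

import torch
from tools.modules.unet.util import DropPath    # module path as added by this patch

drop = DropPath(p=0.5)
drop.train()                                     # in eval mode the input passes through untouched

residual = torch.randn(8, 64)                    # illustrative per-sample residuals
out = drop(residual)

# roughly half of the rows are multiplied by zero; the rest pass through unchanged
zeroed = (out == 0).all(dim=-1)
print(zeroed.sum().item(), "of", out.size(0), "samples dropped")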