From cc646a3247016a621a826c4aea13c8466d9b8496 Mon Sep 17 00:00:00 2001 From: Pbihao <1435343052@qq.com> Date: Wed, 21 Aug 2024 18:26:00 +0800 Subject: [PATCH] add sd1.5 training --- ControlNeXt-SD1.5-Training/README.md | 60 + .../examples/conditioning_image_1.png | Bin 0 -> 26716 bytes .../examples/conditioning_image_2.png | Bin 0 -> 7218 bytes .../models/controlnext.py | 136 ++ ControlNeXt-SD1.5-Training/models/unet.py | 1314 +++++++++++++++++ .../pipeline/pipeline_controlnext.py | 1020 +++++++++++++ ControlNeXt-SD1.5-Training/run_controlnext.py | 350 +++++ ControlNeXt-SD1.5-Training/scripts.sh | 39 + .../train_controlnext.py | 1164 +++++++++++++++ README.md | 2 +- 10 files changed, 4084 insertions(+), 1 deletion(-) create mode 100644 ControlNeXt-SD1.5-Training/README.md create mode 100644 ControlNeXt-SD1.5-Training/examples/conditioning_image_1.png create mode 100644 ControlNeXt-SD1.5-Training/examples/conditioning_image_2.png create mode 100644 ControlNeXt-SD1.5-Training/models/controlnext.py create mode 100644 ControlNeXt-SD1.5-Training/models/unet.py create mode 100644 ControlNeXt-SD1.5-Training/pipeline/pipeline_controlnext.py create mode 100644 ControlNeXt-SD1.5-Training/run_controlnext.py create mode 100644 ControlNeXt-SD1.5-Training/scripts.sh create mode 100644 ControlNeXt-SD1.5-Training/train_controlnext.py diff --git a/ControlNeXt-SD1.5-Training/README.md b/ControlNeXt-SD1.5-Training/README.md new file mode 100644 index 0000000..6b5302a --- /dev/null +++ b/ControlNeXt-SD1.5-Training/README.md @@ -0,0 +1,60 @@ +# 🌀 ControlNeXt-SD1.5 + + +This is the training script for our ControlNeXt model, based on Stable Diffusion 1.5. + +Our training and inference code has undergone some updates compared to the original version. Please refer to this version as the standard. + +We provide an example using an open dataset, where our method achieves convergence in just a thousand training steps. + +## Train + + +``` +CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --main_process_port 1234 train_controlnext.py \ + --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ + --output_dir="checkpoints" \ + --dataset_name=fusing/fill50k \ + --resolution=512 \ + --learning_rate=1e-5 \ + --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \ + --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ + --checkpoints_total_limit 3 \ + --checkpointing_steps 400 \ + --validation_steps 400 \ + --num_train_epochs 4 \ + --train_batch_size=6 \ + --controlnext_scale 0.35 \ + --save_load_weights_increaments +``` + +> --controlnext_scale: Set between [0, 1]; controls the strength of ControlNeXt. A larger value indicates stronger control. For tasks requiring dense conditional controls, such as depth, setting it larger (such as 1.) will provide better control. Increasing this number will lead to faster convergence and stronger control, but it can sometimes overly influence the final generation. + + +> --save_load_weights_increments: Choose whether to save the trainable parameters directly or just the weight increments, i.e., $W_{finetune} - W_{pretrained}$. This is useful when adapting to various backbones. + +## Inference + + +``` +python run_controlnext.py \ + --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ + --output_dir="test" \ + --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \ + --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ + --controlnet_model_name_or_path checkpoints/checkpoint-1400/controlnext.bin \ + --unet_model_name_or_path checkpoints/checkpoint-1200/unet.bin \ + --controlnext_scale 0.35 +``` + +``` +python run_controlnext.py \ + --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ + --output_dir="test" \ + --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \ + --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ + --controlnet_model_name_or_path checkpoints/checkpoint-800/controlnext.bin \ + --unet_model_name_or_path checkpoints/checkpoint-1200/unet_weight_increasements.bin \ + --controlnext_scale 0.35 \ + --save_load_weights_increaments +``` \ No newline at end of file diff --git a/ControlNeXt-SD1.5-Training/examples/conditioning_image_1.png b/ControlNeXt-SD1.5-Training/examples/conditioning_image_1.png new file mode 100644 index 0000000000000000000000000000000000000000..a6231677ca35a9029326d1695eae9b46ada13789 GIT binary patch literal 26716 zcmeFZc|4Tu`#*fmWri6`#-J!=>?)NdAu@wx&04m?$XXPVkY#32T7)7?_BNHZw8}Cl z%WZ43C5(vdJA*OvJLkTC-`DqZf1dxJf1c+rHFI9)b)Lt0Ebrqu-sjw~vM>_h-OdX^ zkbsHtQ5y(?!9QUL%?|cnZ`~@k5x#a%7XOOMjaZ7o51hLciiTu$qkX$X3W4D#w5l?-Lca-cyfOj&gW7{ zXa9o7Kzviv^(NuM&WZ8Aq!AcnB8*ES+9EBULI*ns?Cik*C`j4QLRN?}nZLiIAh*lr z!JkPFPkH_`n8Hhm8H9Ju3688IErzK(Y`;p!3#l@BbHANO#nqMMU7dr^015q2(oo@%D`gYcjS?20G zSFQZ;?SC;yW=F2B6_!^2a2^vRGYmp3qJo&KMEWa8743h7$`l1|x!&`$GnI ze!o3=x;i2(cyoPuurqOSciGipFzpZLxEeRlJHQB(4S_}2`rR|g!ARY*8dHcCq}C4# z9k#}5a_7h5&4$RjiC-yY25wL54)f}{-e?NGIOSF+ii-8t^SPMoEGoh`u?rEtFwz`z zrE|=XWOU;huV5X+5VgfRkKSx4a{qhz9((JH?b+IIE6D69g{eXqdbm0e?!P=YRPQV` zcbH?L+){b4GP}`yLG%z0sE^Erp;C^S`;`0=;qFcXGYuT z?Zg(nmxMFN#-h+uXnOeSOtBTVE`BiUZAn1njIdkO{I5@)orU(=2?4{}GbUQW6W`V9 zs#Mz|AEbZ3?6ArobpTj6X%J}k%5i4;szHy`64u9Y3*82NIyTCDXtsJ|W+KzSdK@VQ z3*={u?`!lvv?J&cTV~24SGTGnXsrEa#=fgT))4m0PhlysZ`!=I4~-J-eoofL@bte= zaS&sC;2m|hTcaPG$@YG>6_Fwudp@UdAa0nEGe)+~zu?`THSV?O-I3$mUlqRkxqi}y zOO2_h^Fo=Hd6|Ohe!%})p!0B>Ap?CHUg!m`w zqV<+{+tfzRn<{NPEmJ4?nA_I~NnqQ))!CLxSq17FqO&D@%Y?AplRf*CZ>yu+a-%6@ z3bFx4HFdUMjBg{oUE3X|d~ z9p*c!f3C*cy&3Px%n`fA z6q?Np4u};{o3hkh^H@$KGlxl4+LKFY=@g0xQfKUhVYK_1H%0*sr*^uf;`G|tr!yVZ zx=oTD`Os+ApA=Rk(CGg_$phmZ}Afg(eJXw{Emj3T;o3Q^Os4ewK96e0y98uoSxy zcXv!*^C7Zm5n(Js%#wTNFTy*K2q#&Yr95*`XJiI3_M6%($-nZw-P-dg`^Yg6FKN{v zc>0+z!}L|kYxXD(`!(#FoGRNDQ)c)FUnxWsD#8RB5Y_6(KD173ocxp7=zgf9=*?_i zBU@!Ws5q*cvQOytl`A5#nnuU+^^=6ZNh+8|HG1b?yp4jWl6)qt5{dGk*vEWV19GHZ?(7@|JDQi?AkZ+K^cVI~fGKr03{IHZQ7| z_GYX+IzQx+rZy2Yr`1R@xsJnrD$R|!F}wC`TI29(NdP+1<-l3(Jh%hCs;8x|R<7-p z&Hue9u~9aH#ay!pSiOo!tR*Gp4TZ<5lt;3RgsvxVZ(QzPF0|60lQBk*nej)wK7@Bj zH#{cZthU;=9H%8Y19c7@u=f74JnDwyvrzA?PY9qfPRg#?EFeypLFX+`edRvQWtdds zUNw~SdfMk(qVD=1eC8R|9d?D#*!RZ?BN2ZN>|F{vI<>PzBZY$WI!;89c-ri*dLA$P zR&z-S&NzE4@agDKY4!SZ;Q^C3wx!m`+k*uNH=^F{ACvlt+uF3|_-$18WEo!K{kHtr z;N3C9Dv_lv58iTJigA6gzZAQ@$#3UlenZ8s?F=^oqE+`{cXfKYhsKCgV0!#Tt7Iv| z&)tf*UdiyHEK)?WG>6@5BfrVN54pHSLVoUHyi)#oZvO~>0F){LYNDm?^vI?(A3lFv zwtk;@lMu4`P@{J%?E6z^*KDW3k@uqyHAlW|wD4rK1)XniRwC%4g`(8+twgTM@umm! zI(2;;y62&Lk{W&dwXCeP%rwIPhicz zIKak#4QL9#6rrk8NkjgzqJy7QE{w~M9#8f8D(~oSl&-*w-zub}sE)8ANasJnSiJdo zRKSBv?nmNfvzxyi(03zlywO&QT(C*=c@wdz;|$|^qQTy`&P_Z!a8cDj>}Wb3Ek^hK z(cg*|;#I_`h$MVI!v1#@w-^YWJjiwCY^OWVst`j+!pAYv>9#wf1c`D%R&o91zJZ14z z8L)RvxTt=M^ER*4-qZc9*Cv;Qr{^MHj(rL%ZJ)B{Uk_s!<~=tPa?~#2H6Xj+XGk5@ zay*ybTcO;h)OPd^Z@HpHGs3D^aWm?3$RBg<$H&)w*P7LnCQm3>q`)O{sy4PF2h`1j zTDm*F$>FK{&<9JQqeSID3o*NMaJI_&>A}qVNd=N*=i!ZH$I^y@@{+8nh)#pb5{x#@ z|I01!h;uKMyk=pMdOxPfSo>0ylH_Cm2L0Nx<{d(2Rs>PTmf*ejbYwmS5c;}m%lYVUL!st=ql@rt!j(ny1IcpN z&`To$PtRUY4lBz&C7pTncM$dwo7?D+@-Ol?;f`0yj2@gAg}BH|^xW?KE}|jtrg&P0 zt@E&V!zX&gWLNS2f++p^2gnmV)%>F>YmJp1{w2BU?%9>xfBMzlc_2fOYP`~Qx^@bw z$Ul}#FDg2?c=uYwHM=zN6TWQ88nlOBo)9fCC>b0yKF9qo;wXR(aR4^h6<)yKM&T_2 z=f9SScb{?bd}23Tb2f3Jsqwwg5=Ok`N_x?7i++MNh`-|(D6UyJ#=`n&8enGJdcaYa?_VF_lwGL@B=Gk9| z=QJj62uO;IRm=K4f6pB)mq1y1f8R{{FXE?0m578nBO$pxRNXk~`)#ozEP)Ws!u?b- z=S38Q;JXtFK_-?`h6n1hZFu@Cnk{Z~lPI@)bLFUhpKkq1a_;|SM$x}>ROQ0T*RBuZ zdhwqyh&F$(rIacD<}12e6{`)juAec<&G(-de-afN@y4Qr1)tvAw|y}r%Sq7o;ExCX z=jA>h0uG-na2+#krbxeF>i!Yg})ucQDVbQCVPhn^Sc z3Y^dwcB7d|Z>ydZd>)ZD<)4wE=5x%!d-Oz5vv*SayxEjU*I5_j8RwhBdn{ew#yqhc zb8f;xkve~!Uym4El`}ntvyjuWVZRIevTzm{A~wAWPP*w~w^O0QqvDp@$^Ij^se-k+ z&HzECbi2BKTFli``ILUn_0bE=2{tu-#oFctH=U0rxS7IU8>)KR!d1zq&U{^?^x1Mp zoA#Z@#Zi}=UpR)6F~)Qsq*p1lw0bmgw{DCR_g?NpUWV4lGH%&P*S2k>N%>8CS8J%%N|kv8ZXRG@8Qq`7(abV)Xu$j zuf0C)#`5Ti^i>!kyNk80d%5@YeX0D2-i+IhgmM{au8Y~BtM-!c-6DF%`8%|9Zmm8h5RZc;9z5Nx#PPvZZSW3%0&8J{wzn24ReQ;CD>K@NoXi;Gl;L)c2!~ z?ic3<0Ji#qX>~z7-E!qw6rp&(8a@?L>`W$%hm|j%`&ufmq}*+ydhZdohig)EjZA$BkG&(m_lW+fal9B|aXeDcxYd-_mG5i15;H%{z016yRYM zyz|nZnepXrq3gtBVhXQH@3S8Ad-zirr5|~&Ra?3m-D!XG5goCv+5Kd0;QM1J;dc3S z#n>Xi&7=y>`F}ru#r>rR<>q@rIrE^NM7|h}vE>uRS9zoPiq|0%w(PF+&$T^qHI#I-t|J_)*RSjwevkFA(5abG8VNI--|podA2L$EK(js53yAq z$?eg#w>pyb)pI_lFQ~bfj$cu$9|~v5+u-j&X1)%SGSuVf6KcA%hACr88J=e;zHa+y zyH}JO3GFQxYvos$#lSiv;ADG<7~6Z7$~#u+_qv62&)l2YHm0fnDJui9x36$P+rpxC zpkL@bvwck1JcdZnvnz5(N2E+_tj>naingE~_#CzqRV@GdwZ4BoA@GRE>{Ob{ZKrh+ zP`=8gjfhgto%3Eja(Sp_ijdJUM>@vveiMC$)|!L3bvJ`hzTG`N5+ggOL}l>Ion?8Y zCe9ViX_$uE@GgqaZcm`D9-Qed8?51uPm~ZoHq`_B(T0MetYRNa*+skM z4io~!(CFk)7yPwGruFuVgLe;MEJ2k^iq^n=^Z~<8}zqHI`k`$w> z4J2nVTVk7cT1`6dL(IPzI@0pwn3=mrq-@yc=A9ZMRDmBct zJ(ob9+mbKl!S#SgaMz_S_4d*(4H5qcj03l_Ej_C5>jx9MB1d=dQ!2}9J509Q?rHK4 zuF9Qd3^DwE2xBE~J)qG(1)ni_n)yTmb-&S+v>SPM`sJR+KxC0rGF9Nk zm+?5*eo{>x#)udFMZD|ZDOcfVeR^=QawAxw-h=)q$7%dY*ALx}uju#^2&t%sjZ$9h z-q|K*{?gYq;nSl;M_joKm+4oQYSPEEyPK)0U8F7}*j9$NVd4Rq5}6o>=kAz*v$7Ju*2QB>%aG#LA{R3 zrV&b)b;4bT5_7lkR{gaha;n?gNg9X{%ELWrPk)9?In6$+d=qWWkhJpQ3r0GEk%=u*=scUD-c2Z=XqWW;1f- zVAC$vTJEC5_NP+D4U!!%wAWjMWUbVog)Qrg_667`vIQ+TgN z_ZVA==jfeM(@52jml70QQdz&l%K8Tpi|nf%3e@*PVy6*LZ$wN~92u<(MRkjv;fCLA z|0GRQDf!(|Y_?X3V-R5V$Pj{~RYF~F+V?u^t_{|a4CI{FFx;3SLvxX*PGk5>pb??+ zH4BkwPAoKl4o=^61&N{*{!=Jt4&ohktOFxx2>~Jct*xr|u z3O8d6VVveJYON^c`7`g)L@jK&WGW(S{ntk$(dW&xp(huVfaBSN4ESSU>|WwMnxICy zOO}fakPZi@qR|p;hrmc= zgQ=jz*wlos1h2%=oMQd8>{!Xlxz|sT5>+Zk41ldCO=U3~V5}v(Vl$vk;(9Pv4xtA^ zfD=tjNa4&xqr=_n(a$=&$L`Rw(BKe{G89C8=pTFI*vWt57*Q^v+#GfgRs;&EU#KkuFBYrZ<6VzLlYCQMb8?^wBiIF{9lPFQ)j3#?n|2Be=g5C4{wO1z36B`aN1L`39d6WwYw@juy);e?H*$E`1@orVR4Qz=ZB# z5Auy2^pv>V@cUDJRwG}FS$4!`WuC~Zpj5uQUxWOh0aq9&mPkA_cjh$dY*L_%Jh_sU z(JwP6Zr@{k9fQpMWTI*5vI^JUPkP~1c|Xh>W1dU15P`ZD!|teTE-uJ*Ds(n76KRcM zaoe{2#><7(6(lw*aKIRz_Ks%06P4HgOK1rl?6uCE+sPi5kdTyacbggtinztd_R!by zCC5CF${F-ECAD+b`RiF?UgN9W4BZ53L|Q;jEL&wU;4%puMU@BDs+mUoBfY!=qHU*+3o_x3{nm1BD~?!@)m zz~_api((51@>mBQO;Ou2yw_W+6hs7+?)Sa7S~&jAM#}5KEN}qCUrgbO z~H$Q@0M&F)>Arx zs(vu+##B)<@{L@X9YVRCeClHksFX{j(JFKh=rviKcKFKF#7zH({^I9z8kV%c@$iAK zn4*&q1Tw3yAE?X8v`d(x7U3k%=qWiXg?;KjnsgA~&b#OONx;6M&@i*{tk~=x1coU| z9s7-;=_UP!>u&Gv)xc{f+*hEv$Wq6I9SWg<=Yg!tQo!Q*;mHLg-9FYG&q6Q|a|BZ& z0r^=6;r=RTHR~|-wg{N8U?5&M7ny28O4dPq*9lLE)vtrr26YfEQ#-oOh5z%?w{cS( z=Qu3vkfwnDf;cKIHT7btYWTKP=gM__WbEw-%I~YY2+^hiivGF&WygBN1>$XtXH*15m?>pLHOq8R^iWn zt1o`}qA&3h>)b=Bb5G=wC9ejni$D>OVx<)`pb{8C60Ha6@ZjT;wBxW^@=76-QT=;9RrC{Ho#}19?~*+qa`0#|7FBXKXfH37y>w z`WO)G06DCBf7LhHeo>mM3_Zd&4|sq!wuJWy|51W`zO*hlTAH;->H9t~bhnThNF5xR zD$!nl&EHW0kmfKBZz3)%R+P3|{6i0c5{Y0)qNG|dFSforMSpjGtiMqq;8z=HKW$GY z?@~Bm9I4}?8cP_Qm-g%%oG+vcbK0X0#h@VZt|DyFOC!+a0@U!_&$8VHk(_-L<%25g zeq}+YVUm9KR^o%TAT@%1Zry5nj}>(0n6^^oli<8AV8uO>)JL1e`O!Bdo(hZ0Yl#nE z=O^x_Yrf?ha6J`PFfJ5&3_*WPQWJ2I*Y~w~8FegOz?53f9#?;4DND7(ct~Fy|N6Kc zraZj(2I)19y9K%1T~;dYOL_ZO>lHHyR&cp8{q?Jeo#D zK|_h7=0hQM9-mW_1+dAEFa3`H8jnO{jPuUI;+g4X%7;GXKD(yo?Kd1QuNrvnSfHj7 z{QPX5UG;9*SqwkG?ln#R;XPXVtms)a$y2V(7>;)jz8E8&q#`)!mBp1oPuhz&H9RT3 z71LmhX+X-sb$-{*@)vC9J#hN~adFcC!lfUf#%_EPDm$!UtJR<*0xNN8^8Vtx z7Xg=-R*ELk_|RFQPN%>7dH?z!E`X;>-26){!E98?72Zey8}^1JX40<;SQIxV z>HX3ynH|3E_fb8@LouuLwP9w~)DwBkxw4pp4!^yXphv8>1Sg+*pJpe-_9;=&+~<>` zDeYQ}>RNubRi5=|%dtPh{7pzj7UYP#JxJ5HauC2U6Nm-S>8dNi%R_OJyGB~t&14R6 z+C4!EL{+t}%@+ZxlZ{_y%lrzyC%aqoVp)n9CP?yaFjvjmi^O}VbX%Zo6Bt`QFu2s$K> z32V-SJgOk=fa{E^r(E1+RbM&TV*mqEkg3M_d}{Nq>!q~%I6-Zr+6j%M-;lhKB9BFU zYR`@zh2L%@D_H6pGVDvcN4{JTj8g|h1}EvHDGKu#29LZC@M{qSiZdx)5(srvpjm6t z~6}U8fs9X4k6Bkx#5+Ya`>34Ufo5y zi{i?Y_h5|^g{8WlE0%Gi8rj-fQ2f>G@Wpi|@I$xI_z`R9I*uj?Et2)~-BC^LPb1_0 zRG?m@$SE7aJedfosc|~j;WX3$@7W9FXC zrzuCG&H<;jG4sLFB*f939mxGCX7Z2f>(PZUpfS$4r!QFQ8qgOAeWhzHSq$w8MLn-y z+MoN$*95dy$ua*l(Y~TT!n9UNU&X(Osd^kJoj2#`ytOUNY}hUG z2WV=4wm5al7MQsa{uCth6SFBJ2=NlL;57+bLZci8n?{;VA+r{9_ak~9{@umtY8$Hu z>*&LglJQ9LK6MQu6I!S`8~s*pX3dtzh%elb3cyT_5GWy4gfBv;)!iM zpF?;*@+RpZn8(gGeXVWZMdLFeAwlO)hPavv)G7j9rb9EA%uiHJYNM9%lSJtu7tvOG zS@O2{8?j(IRWRM|w+S8c2u!?bp8UqMv*dXcY3#cu*UNZ{AVM5o%bdQ-W;7BmzlXu_ zJz)5=H}M^Epb!`hsgb3TJ$p2H5p3Hz0aAUQhhV&Io+GUA<&y5l?O+X;g*Yx_2exU2 zhThcNa#vn#FBHW!B1GdQYLQvY3E%^EobEJ`Ro(@ro$m?v2EzS?R30~?t9?-5rSwMD zGh&{D3;^`!b&uX&Z< z3Qs{ElLkkN3-t$~xiyC9n%*dm^@Z#ZI;`Ceh=~_Y(CC-yeT2un#Fk?Jeo#N!hI={` zi4Kh=SWq#1BOXdz5pnuSCu$>HIz-EnPXKX=bhG%Enh6n3##5Rl^I3oe65LVGPG6fH zsOpP!(ct;D!Qg*q{aN;YN1THJ;5qH*khDkwh|7E)*DUqKjgLa((QBEHp&~9>_WRKJ zUpN0i4>$t-m;5P?Q8oxr)q5Qo%i7#nwQpTY7Eu(uKSESQ%!n$VTMP@@_kU?pjSAQf z7GS=7$lO$$wA5n^GgbkWBc=&bLMeg*k$E;{E9KeSqz(metj8!Bx@uP>*e_-9xhmz2 zhv(`A2@;@&7mLn09Q=G^i=k8t_rsI49cFYCM|pQ~TxJ)Rh#;ohcd#PtR?(zP`nLkh z5Yrep%H_gm;0)5Q@U2Av<$8*g{&Qp1@*t7Wc_3FmdwxjNB}{}rE0NV+xdv~mi1@UwfIFV zoDxp#7H|v>BUPn9C(&?5r&?c43^_I_iuf0Bsv~t`B1zgPw zwCo;yO|Lk%5c)k(70#1D0VI^^2&B#2PTFqe1giUyn$`1}fFIX4?LQ5qvJXxAp}+(A zt_98*VDNkJENDuc0#i=n@4%0+Z;VNPz|`E|QdR*|#p^i!S+vS@c}b!gNuO&nwN^h0diK+;ZT9SCSN7{G?9EksC7wJH`nG`4){Omv(BnOPjUNiVO-SLe z6ZFRW?-ED@CJ(7R=`9ZBTbg7MwnP_g4$IJra&tY_} zr37iuYzwwr#>9UOXgw8vANfW;4=z*JKLGUOl?DD;lnUDx%Zh0DLEuH?Qx$n1vE+DZ zBDCVw^t0bxQKaqDbX(xn)MS&e#9dYGe-=Yh5e|zTcC|_vtP~}fCQVfVJdk*k!Eg|E zZYcl+B2Ku8@MX*LzWmowZ9SfMVa?2mw1A*%6Q7fcp<5T$o9!)7q=n-nT%cqVBuB|p zm~_3&^AVR+Oi30GC<$JCKV}K!#yamK0MzIPjy5mILzIT*5HzOcb^IN~9!U87t82T= zPKGVwsBT2>Zvh(NKd?TDCrKchx8;e_L(Px)YJjgV*MBJlGQer3(0a?gKZp~2 z*&gVRLP>*)NGE9OpgW*_iX~UC0vxsVID|-~7kq@1iYO5cyAVpKBo*bWt<;DM&*_FC zw@prf4ddGA24FTbdE#gdAd&5KKs>;f8Q`N3N|-z)(oUKY?L@8FjKH^NnY^@1Kk5{K z25&tEAv<6WE zp|ZTF^#yKSca+E5;!UFsVFTdJ#O{}M1EsnzM>$I<4`nrMPg<(+@!FZ@Lw2U%qi;2ZLrt~hhz8sOsob#HH>Z+FSgGQ1?D7L`I35uvz}H(!@;A)!JS zMs7^WDr-piTn=bPpENZe{S@Li?BZ1R+Jud_5{a_=l*#U> z>s)v{DE3G@MI@Dr5pRsh;@^G}F4YcH9_GKwt4%SL7D)V#^9hpZ=^!8U4`i z0s%cim1_V2JIPXLiF1kWk}mAf)x|Iy#IRWC`qKD?QOhCilsFF|zdPB%yK;b>=i;kk zfSi@7Bd})3%+CoxOYCyqmJ{&mg9Zdg1dBl((j)fLs}&Gmh{?#i9g^iy zZ@4UWK?_`XhOay^!EyiuFD23c67p3jKT?USY$h?!=$-dUhZj-7GG{73)4LAvZ;w+`5mH~I!Eient%r`i}^Ti5`wIK-l#cXKk$ zxze5}eb!xEx}&Zdxs_>xj(zBwg~(VgSE3{mmxZxUKm`iKCFIG{-5`+XePk&-A(USc zp-y&}hD7JSjuL@m$(Iq|nf|N2|FN%c$cYABlcgA3(oq8F1k2D~ctL&@AEt^I_CVnV z)3<@ucA-BO5@6kBh4Bs!b9EnA1Z!=qkgz#!$#V*+wxJ;q;CsX~YyCMCU*G=$oUZ`3 z$uIXg4n>6V670F9F%2ndusu;42|9%cNk|x>{@LVM(F0!E^|`9oAOuCgLUZZM$>Pf5 zY%7pW(8)!JPVu?ow*wbX(SL^|TKhhQmG?yXgYuwC)4WYT z$fn1)yt<^@<*x$-JN#m`9UPVn63z7yM>8kfB-$Vw2~aK+!DY<5RLdiIsf8d)2Bp-f?WcB%MQT7qde9!-P16@d|f zHJSYt!G(Qrdi1blH5`Z&Nkv`*taW8s=TZP~Iq<2KIl@I#z&j!YP_O!?Vt~HsDuxSk zjHyG)$`*D<7^6u7f+9KpIxP|x+$#ggQ6xdB2M)l>?x9UnG$?XEGg~s#%u(TC^&eSqO`0bQGltTN`l9#slCu~35QNYbWUgJPY5yw(a zwTEN?$9>w!DF-3Tb;kd{pA|ruZ^p9b zW5#RR*Iz?Dk=MYwln=UnE0S9Z~;=|SI%>!d;pI{(Qj)EI2;LE@c6V& zc#mMTeT6Kj46eCI!fvco%r(r9sy6@}mI)mBhwh4^`mNo1&;IhzD7%{weT0`b{;j}X zP*QhF1d(7e&aU|+6MC*5EN=USh(e&?@s@M|)18IbuB%)$NpPTwXa_E1p?Fnd*!1h} zily4E7(W2jNjt9O{C%W+aM8TcTrykI6Vxtv_Ij~IVs%X+UVYvcZYAL|57HglAu#!y z{IU0cAGB=JCS-+l_~wnF7vx4wNJCawD`>%4ebfH8TZ_@$(lO3eiXU@!r$7$OP1S^5 zYND!IpzH5*muTs0y+U(c#iw{IC?-2wlZ^t?kL`=NsRg2Dx_jYTD+t`uX!=sz?p#Oq z@0K+Cze&Y_z?%0raKs^WtgFS!l~*QF+IZqzMIN zS;OWBR*k_Q0asI0-cZBa39y=1Al#T29!X^#+yj~UyG??uqV*cEMZf#!gan_%W!Jr4 z;lMB)7^ajjI%dlwN?Ch#&l1#_m9oIFx*efsI&Bf@4lxW(;7>jPDgHRWp>nPSR^|?# z<@i@R7A)rStwAiLxk0w)>S0REuN^2dM~2aJL%EF@|0vHVx8?lsPp5|9p2 z|MMXt-~SsU@{H)MIDZ?PvJs={@D>&4EK3{|rX1_D_HU@f1rp@Ym7RQFM_H@F7junAC3At#gEU~`Ny65xBD zlScqme`2UUE~vPo;|Kzrjj#_=QV+)Yq!EiCEK&Sj`09nXNZ>Y-*<@xTxroJDt&xQ4 z;5m?lwyV-{#1ETZ+tlq5Jt)IcDb8E@U~qMyeFE53c90~>h~UKr;gO3rh>l9(sxCJs zak$ETxTOe=d>Un~+SZFEW&TJKqACvujhcrTif~D=<^-Q!U3Az_v@SEIzXM8;VAj6K z2_sV!s65JRNoZ6Pwa^P{8 zBC}X8kx5vB?px#ka-ZPR(XDtBU`^)h1HgcN?t^UXTN~lZ?~~9mf0#pI_PGZ>D%?l7 zeWegGf4DP(7odcHE<&xxJ}-im^Bd-@ffc?1Hu0)+z_zfMm-gN1;~0Qe_kmCQY;J4$ z^1&V`*>9_)DVU53CgUsC=|B7&GJktMLKTQ?2>G(jzaOQ1QeFoC0pPKJVb2{{B8K!$ z^76V&pQIfJl$5g8yK7kcQFGjBi&%`GIas>;iCuq3y0^$Crcnv|?;9K}0-L#5%S}{- zw_^A?`!lX^0sG^Wsz35jm@L0VGq+($b9QOfij%iE@Dl$SRSw}nVQ!I-iEtChD?0#b zvm9lW|I?%UqVa!vr)PoQ^ta{;P#kzQ!&m&?%H#o<&AB9Sv#9&;TMJT%2art^bH34H z&b2aho^*>YaA4RCr-=hLA~UoAQaU31L56nt?-keaP--~#6tj;S4FBBeU3v*1YABpz z{nR^X(u~oche(~Ug%SLWWG)0yZ~MQ)WDf<>-(^UN9%qA4FntkGxjs>N4ByE2Icc$i53aY|^EwH=JY#A+hypmj-X9-I1$P>r{L>ybl|Ew6b1BCXjZ@|uSrh^gNnr{2LE%ob(OU-TMGJ= z-73%KIa$D;1h>_iSa!ZxI`_M^88uDle^R47Pk0RE4aO76a)>&1HFKe*eJZhD@;4{0 zDj|0(as{W+QpKHC-8gAaePbWwr2DE@~3UCmnwf9_^hO}SNDGtwY-z?{D}AE#H7Z> zT$|$?M^SV2@dRB+5p@&^l9T^cc+hFWY)p3iv4*VB0XP73-z3hB#nl{#zh8RWdkT1p zm1{qCfRbUc5NWhyods$Rc8@k6(-|Sl6P+{VDIQR31)*kM+?YY~{aG{{1@<--R~`Z! z#C8RWI;59KGTKd_omlIbBDlruM;sh*TsaXbN;yJZIRt`eII?NB!*5p{;VVLB?#71q zie+kzB9}xQOmre8JpPW zMii*TD1$AN2&k{(ipbEI~k_X zx*sEXVT`~ITOKdSz08&m!kdAPUAj#)`pt3Crw0OzA&;~^BX=HboIG=KUfAMjl&KE@UM#m}!H4;?BwMJfhR5T<8 z1ksQ;a|8(D30ig~*YQYuln|>|Nl`g{%{!6u1aj=P2HFxI-Jc1j-G?;}m3%_>T4s*& z1Q9;KUYohSQK{%Ux%QZS0WTAVIZ1EahY;X;w-WJ_8?X1ot_uCt7{W~IObm{UvSCaC zY@h88ob*38xhhvVBlYm77*AB!a(A*FHWcy&CzxZ$~W68#=Y5dd!TZE{k;2^1dgFwLr4Fv9zC-R|dD&kY|3r?*wR?Tnerg_B(PjcA*eFG!8>D@qESL z?pH2-9biqDH%@VAyJoQF(u-QbtG!36p`z(VW=91fAsX~zojEy{Et*m7B={x*X_-cQ z0S-ulFH-^7KZjQ4u`1Jgp!`Y%1jw*G-DpZIyR!+QWlB;40mS{zYBi9hG}>dDQ1eWr z92Ce~m#nvx!Ij`rr}H04@-W`#zqi~Tnb9fJ6UcN=BOoJ+pb?buLT)Q_ZHu?wLxfk~ z(yh<{Q0zVU@E*to8)PZO#-UlzP;j1bjX<0m+KA?+yKb{-Y4|rhpKKG z=(>l&;MOQZlO;9hETG9AEFu+4I1B&XD}?d&WlwX`NcH^Pm(4D=1#)%4oK{_KJY@tK z%Bk+J&h+<<9NHyA{SPM-B>lN#(6x%?-(S(W8!3{-cHnNZpaeaC*ztqYyIcA3b34$W z9k?bZaNjk3d437-wDof`bhnTlxYIWkmIQqfSE~S8`F8`6sQz1qF2s5P%@Ne1q#;IM z23z)iYg`ZKmfN*F9di&e5~09NkmBn+)l}O6w{F1pS{Wwn9g-AHb|8NyVm`MQ?$oYG?{(fE4VK8CQh=9oh+iw#;99M(hX5>W@SZ-b!kGmUvf| zXboc-1m>KPE_$IgOXsiKt^3zX}MS}O;#+;BkUP*Wo9!u~d zN0m!Z&d}nk53^o0NGcF7Gz2?rrx)xI*+IC@WBQ-H<%4K!MTkRXRO@=iS~4HyF|~E2 zBj-q8+xqZankUz}t5@-6+~1x)wv?jQ18MZAfftQpPBd0Ipn4E<4YzmnZzsRpM2T!g z+(qQ;c#IK{>>B zwfj7KU0evzzczdnRxWH@B7;KwOt{y2^D5}-ml15Ba(>cHSM9I=xBgNL>hS?eFovsW zaxHYzePt}J;{`Sk%9A}^ap_r$0Ocg*APs}k>1sP-0?TmvAMX*BRls@eLLjy4vdU^4 zV1|-4VuXj(<(kY&l27lr1WT0mS1w-QFzMcWFvs!lFV67Qpj3hNgNMAvi-y(0z0KX* zboO)QB! z@x5-%`@L#Wb!!t>bS)^cC{^wU$_}z+9&Tcz=INffpWBlW$P{!7b(6E6sURYR6%RE*k&&voH37JDy63L{Vnzln^66 zM3%V*xHpGyL082IU6u{i4&3p}t}f)VV;wlndGN~#1-;$(2mL~8Z|T?K@<6eY<6rL) zMJTO|awKh?Ff0KRo>3gR<*IXg30epTb~g{HWZ!I%zkXgh^tt)V&GDse2V73~SB3>8 z2mM238Np;}1$hO7tm;BWa)vG}SwL*iY0yofC~}8xF0O(fE0A|RAZZ1-DVk-_oT}^) zYMW=D*((oq`gx)7Y9QRLRKLZRO?Ls~CXnPi9A4AjYNZ<_ebDY!*kB|MunK{vK2^nL zZDf}j3DMT%?m!z6Vt*IUKwre%C%v=#sk|2p&3d$c{JZTTp65c=20F>zDF8hkMp5Eu zN3T6olA9!@|D~vM$J0GNLiaMNvIc~={RppR#-7dyR{U|tMaw4yw95B)R#pB)<6k$N(6f4aqz55jtBRMVcybCc(O%x z^Z9x#aGYP!9-BGyH%<6+0O=uQdNPycKHr*Jx?4h$P_=@yR~AaE<8RW35*+Mj>Q$q|XR6nr(zTL2oLe(5`1A?6y;ksM zP}$|X!j$U(7o?ifKYjo|adEJjT+LdawCL$hN$=ev{2Q9SFIj}#{C?jI-+k2Tzr5Of*#N9D zuP9uKDKj#kIz7hIbiN-Kq#-~*3Fe!|DBFEHohS;SAZpH3)O&+NYqg0tFqYZBz1sjI z0>kr@HC`Ewd}2cXXm+6LyCJ+FZ0l6#!`SHxy*k&$3_CmV;#*OwB`DFpB2t#G$SXPC z93eNTNTdx5LU9ZplALh_S&7?mc1$OszT0GIP=EGCLN*HT6c8oDNETK0^(?yaf+g9Q zF#-MI^Dw&0VR~Dk0=-6m795j1S?FsK$}zdzw{ zSe2au&z$vbPg4o zXv;fMs-z**_*;;`EbE(!SK+IjEh$)>;?rH80Srx0jLngz+n$@+I~sL&Q}(2?9ps{) z<#lerFWcvGH{Qhr6dO_j*GYb-P|u4Ks)tl=dBb&~pkJuNGMR{pBcSrhzyG=ybHntt zY?IUfY46Jap?cf@%orMLgR*DeiY!T%m>EmSmL)1B`_>0BWXm>Vi)1Nc2xDnMmdQlQ zU@RH3Oe;;v9+hnvyTSL&=Xw5x?{CjPa9-zipL1XLb-lM6z*cluR{PbF=fsBwkbl%y zaNfswC$rwvw%MYefQ#HupNx>2SHn(nezA_mm3ry4kBa`ncuswl!Yd*?M`cx{jq%ii z8u+{J2%lZ?wnb9cE6&@NAl?3Mt`EY$NzQ&Wk~2*;H6K>HO9fkbzT)5}% zZ2_D(R{y$h*QuI^HH%vAlRDG(P0QZq4Tbys^sN(U=LTHQ9SkIU#wgTtiece)1t4>iqE@lvcN+_;K`pXL&fpH2ad@?Id>N8UC zNROhTGAuUJe8awo6Ac?U@X>6Lmw7@0;7e*1Asg4+PSf^A+$x9V;44jTT(92rk?ZDT@+X;zKt$faSrAIC5!GQ|r z^-h)_e(j7g{zH3QBv0WPbunVB3pt>xU=b^LCbefNyfD=74&XrgtdzYNG;c5rQ$^V^NYCdQ+!x%- zJ2%6anve0Z7a{73ga;6^LbZ-s$twU3m=A$IF$KG|*2WEO&THG*UcYIxGcX1WMOTIAq=cO6DXyGK$ev?E)^A-i`vwd!UV5ErfjXpMii*_5#{-_`Sq1X|c#xAxFSR z*cs4(aa$K7h^0J1Z3wSl0dgDjou%p~e#)jH--D*P--85sefG7zg3|^7D@6D?;7Rb4 zg~BIkpr~zIXr$=J<{UW;H4?1%e3lz{t-8Yg`3r6ZP?sZfcxU-91{W#HY^qAnoSTEo z`Yo=V;){T*bGG@(@N zTc>69wY3`PIuAn%B)oD;te_98>WB9E_xm*7Z*}Op7abt3^+mtB;N}k$>}Vgrv?l>d zS}JYz4FYF&SYR$uAE{h1H~wxZBw?DizD8KV-Z$N+Ksi( zHps={6@DhC*0cSdZECCV78CGuG(u4CstI+*29N5*Z+VS_j1Lv?Yy3d6ID(fHMs8YF zC*uupKnV->_P8tWn~MlZmEGr_5t#g0-)8E3lv{dE^P&A#!UO45_!tq`3Zn-9wf6P$Sw??n)9VX~C;y$d8#eUt4BKqC1`MEp+ea_57(o zI0Dfw)(yM&D0`)}XWf&%5@8bNDK*W^|0G^R$%O^ff^0Rz}qn6g&{4ufl21s{uQ&nFY04OFwAn~9CP`8 zg2rUe$#ivFgMA`TcDIK%?HfUx{uf~F$Sy}fbcDJMXnax%0L)D6*G*bk&M`$ z`Z60}EM_?>0rw3Co7i}eyI{}@BC948u=|)Io;#Bw(P02|yc876o;M;>`;i3eetXw` z=#y*bo%xF~q;?oXdCQ!hC_o*D#daoCe2~gJe%r+XaIXy{0#?6&txj7&TSrQVp(d2p zed#NPcnUhZz<3zq2#9z4V&4A}sO%3T&BNAl%T?OMy}tuy0GOQ8uy3;zR}O*i#H6^w z`Hh`bM8Wx%)IxJNXGo$7$RX%P35WyAk)X=nt&027N!vY}nZP_C1)5Uc{rRczK?n>z z)7hd%@lM3!m~snpOp{b9C+|#b4}Z2mI)n=}Qk&NKUDiKPLBjysstf{*Q5tN`fZX%> z%BAWn8WtcnBjYrWYgWAz^i&csg%L7Q(5Mis{C6dvxS_eAq}V3zIoy$&m@QuUy@(Kw znwI);05zc~X*5Fsw-c%p^Y^^8!}D1fiW3EPBa@x7ogitZsz7hr42_Jpn*hLTFG2A3q$=`*1+j{BeLu}7#jU3Y%rzc zPlP=_*crmvF?cO5cO_{~c47Pe=7aW}{>0})dZfGN%jxzuRptpjwPg|csK?OrS3n8J zE=jjw>}0#`AQ~sjsz%71U;?Po&@id5A8b$%!khIQCjLtM6YfX3! zf0A;8Z>62Alk5%kKaIZv*_F2)w;SuyU3}Mt9NXRc`^3Sl{Vh-aQ^?r4nnR zO)0XQ2BGm|PPr@#(*9~mxd#1gGBu5h90K<*W&(u}4YM)^r~1Z+#E-0%yU7$OyPNW$ zc%-$fU>6fq!SVyC9jX44QrQJ@J~AE7k2298qOY)o^Ado4ygXP*%QqxEdyruhzjJAm zZ`Lq#l^osshLld4ckcT2I_g!c&6@y7jx7i^fXOa8*aIgm`|s1QLI?t0^_bIw?;(ji zvlr^R!jRiDIbV_cD_OUqq{SaJf9!2I+tm1j7hDsBV@4f3YYi<5Y+m48NhFJ9=IS%{C4j4=gP!0fSf)8UVU^{gd zZuaEKkxXrA;$gvzfbZ4@b&imMQO^yVPij7EoZHjO-3xc^jH4OxH4xhSUBY6*^;0RsQwn=)k^p$MQoB$a_Fr@{*Pmo{N*;xBHE0mPUt0!vc?Rp}3%ca2^DYP4nd z(*_VEOH+BE@!%s^7;{0b_dweWlV*`Q{?Z^+4C>)-dGCg-&y?q-lee@-4^P_scKP}2 zY~!S%h0jInt4y& zrnl~f6xp-Kk7=Lunr^I$$Rn%4fT$Lft>;v|%bsYd$d*X&Uj(S{+T}-(81{XZlZ05h zfE?R&8)FZsm%+vgrpo@HPTdXmeMX(5HLAZy0V6X}%HPOfmw$&zjeZItEDx{634A-r zTFMT)O@ZIhn_69Ip>BC#-TEC44YsLxRrZ<(51U>g{g#AUujG*9U|z9OuK6qqXJrdL zQ$fwCybCqVt>#%PW)dW4Vv-Nyz7Mrr0ICC!k|WZL>$4Rp-Fj?vlFNwWL=Bkd_CP9* z=EeuW@^hz@LsW%8P~>H3ooDqYnHK!Wm@rl~<70b{2KNim_aUxy-ydTE6J0BN5i+~} z3-jBpSwPi~W<3{`K=5%hNgcUCnd-h`vTq9wzGdO=HtmeCF*54aTW9>7-n=Iul84Fz z!KO`s>eauz@ZK!`IX;w!GgcYkM%P_m56td$T4d+8JDn155^@0|&_oh|>>Z3SHjT}J zR7y0!IOPr|Yiv>^3bBX#L6{{{`c>jue7+e|Wl7Ez^Laf|UeZJ7n zyzbWHCa6%g@pY4353q*aaKq01#s|O#6RHP6P1pHT;X~j3I`Jo7 z;VeW?i=sTwoWoUUIJ_`+&1N8pFP_*AN)LuL>sTalrkq(U50nySoZZWJAT>$|RSB?T zpAsV1M`o^aMr3^Den)no1#W8kkE>zdsepJ=`wmk(HMY%<+^&L;(%MQ*-6b}=)Q z!;52BpKX5mpFElvf^$?n=cZxG@-*(-tr6nvV}aC!G=k6Kv8TC+b7e@`33VWP%X_y! zyoIKzMPw$uQ{+nVUwge=dynwgJMg={q!M;p;3^4XYrPa{BzQ0&od&0tfVIuJ5(9h| zOd+Mx(gaad&L*kQKIf33;+4979{hNi_!qHzzbEv^g?(z2vWWgWp^HT2inU*`+dER7 z1_dV>11}9e>UbVUnTNG3yu9{x*8=D$(yxHX*V%8KV(V#sO!suzNiUON;n%~bDeBt}JJoi1lHoWA~B!F|@ zL$J*ON^=F5xlnF7y z+>UwRe}s^^u-tpfPn2J9{+Xw2P^!Oj`OhB*B+94DKp!f3e0YL+BJH3Jf?{m4$wp^p z54JeydiOeiwnVhh99200`h%k!HlstVo3YJuEIR@HD)X|D&o?&93{#+|6+OUQ>v;Rw zd+%&nv@&$-jr@~Go;&``mm(j6rayN!`k(GtDH$dL{k|TY^{EWk4nKaRBx(x->_vONlkL?sYJ>Ys zJrBstJl7FavU+I-o=C6oF6U}8a%?Z)_(?N|9*Mc|$%m8-9Bu@Rk99f~nw>%>Z z6Lu}htWn}i;5~jZT50mUW54N9#S-B1~uX25A95GTq6j>FL+ey_MMBY@bWgV zx+QgE&(g1d;7Q+;{Mu(PI`a_S%eX~`?pA4bSr(%?h_qW_B}NHi>8dhWQdZu^uqf!T z-pw))Jtqw>2CFdSRZWOi2 zIjgFoi_(pSylYB|Ves6Pq*=%1JKygT95VaG3?DDX0A_`s-WLYNgZ7<;z1a?qJ|S*VNgcIA0**{%<#_syh~lkcl> zI>UJ+(&1DDbO6K;$$cw%<5A+iBvc~$o3FUTR}(%N)GkA?L#iJzaa?)3M$&GOe|}~E zmA50rkK(MwFX!AcF5GNWWdGv%K0|_WfBN+Axv=C44O6OT5|Lt)o9a zgYEq#T`^g24jU)A>Xf`cWhbe*Xz}TI!zCY4*}Wm%eO~d;UBLeE1z72TZ!#tMX3~?F z)W6p#%=bQ#9ZFH-+0q1is`Vtmi9s93F2p6AxbUqbo3uRo`pzSY`2(5)gbiD9)UB)M z=LBJ)((Jm&gS6>m&lTID@xD6u$b&ebG+OR(4uJ<#LwW!a-fa!-A!zs?6Qt`^wIan*F|e zxV_*?%H^QF;AQ+3c9oIR{n)11oVdmF?%e6GT&;!PF@`sADB(4JV8K}#`1ePD^S7CA zb>457m8kUZL(8X=vlKuwiKsMV457n=90t(#8bfNSnp#qMmoU?w&t~7gL^)81u)P6? zDz30@_V0s7F+k2f*^ zdf8gINPEJIksdtF(1rS>kHEH`oH$Q!^~6*ii_8+>0+tmzLwOGCfFifC0+z7>@b<{lFk zUMK<5EnnZ%L>Wy)$0-?O%+im_l=1KbpYo9M6@Ri>c&NvAbOlx zLW>+;nGYn^20!ef-;m{Q;K=yUCx9BVzHQ{w0K8vJpQUZ_NC@9-u(5b(EaG5+NWCl- z&s@=R;@#b*XQdyHk^RQsq^x1;;B~{0+HE*-H)NG$Tx{!y`y*Fbb%+=N4LbroqS8iv zXxZH3LDtgjpDTz=*}oGvZX;vKF?z^5X~QD?LM|@=y0EPEz%OUiI;!<-q6?~M40BH? zt&`ju{1ia7#)=Ux~@?wytYsWa4}N57x1>Nk$qh+cR0IqkU(bGJts(NMEp|UHg464 z`uvM(mXo5<%7p$G{Cx28TX#Xy8`nIm9Z~kz-oLh4W`|5;KrE~Lio@VcCwqsz&gNtu zn#+r$K>OftAaxTKjOfqA7Eun&9R~s)6)@bjfYy(1S51)tyOSh@RF=_3g6BML)u)B> z5v3)Ly4U4v^co!AoP9$)-Em-^Jlbx?TUPVlT!K@Oog7uL=zBQTS!F0`k3b&S$jOo6 zi$fD?7Qc-15|%)b4r-GBX2$wlgI?|YR1L8u?s`w-C4|-2{yCP4s-9ob zt}g9)!1zt`t87PtwVaj`AIYQzY^i1jI5ty=2;%Hj)rTFy;--6ZQ`a074z}8yH#{}Jvd1t!T2NaN?3MOZea23r8Hv=kFKhCQ&%-xWmf?TUSN^q5PykI? zuyaU85v2yf7pM8s_xv%lW~OVtXu|Wt_gRTN$%+YQxq9qWxwY{f=%r;n)^g@q0B2>@ z$(F7=3Wp5y7ks--j=Pz?WpP!+n@L*4*(Ws_obf*>X$m0hlFOkVpZ=qZfANS5(>jq1 zZFp8(8<=CBO^ffGy=qyn80ayvGB6&#fj{)3Uld*5wXHt3u3k-MgbxvEERW7B;AAdc zmn69Fu$8dv5K%NA4WyGu z1HmqUeEs{X`0!(z4l&kef6u=7r^ZLYPa-x<`A!=B4ArnE#jFRbN#o|N_>osP_pi7Z zP%OMz*Qtq~8r>eWLcXLk@iS-MIGmC82$wv(#~T#-+Z{(oo665*&9CT`-RT=|AL^VC zAfH(?1h(o`mLXo@^-|{F^9kYxD?}x5zPC?i@Hx2BAKvzyKYZyOYB3;zXZfcoqtwCM z;a>e>cXk7U91Q!2jPRI1>kK@z&WNykE5~1J?zC%+FmktuuB@ F_&=BlYdrt} literal 0 HcmV?d00001 diff --git a/ControlNeXt-SD1.5-Training/examples/conditioning_image_2.png b/ControlNeXt-SD1.5-Training/examples/conditioning_image_2.png new file mode 100644 index 0000000000000000000000000000000000000000..7b0f14cb8e9321fb0c6fca6d1cdf953306a96b50 GIT binary patch literal 7218 zcmeHM`CC(G7QVRz2p}YEf`$?h3Mg7AP{Se#EefJ+7OS+fsi0CK2`CT>A&HJg1w;f< z2nd!DbgBYU4H!Z$D#|8Htr{k54gw-uWC=?knU6d(&olD}%=63-{E#Pi$$h_b-gC}- zzI$>p(BE5IbE75%LE1jMcYY5+FgS!Ej2ifwrT@7EzEG#W_x=ug&~7pcLHKc>o!{=K zhfVZ~ zNLXu;|GRy{Vhu9YFT@HL1_x(&+;Ff%x#8)7x?!PSeq;AQ5obC3YlE{EKZkj;5I(Gl z@Xv$^>yTdT@^K{h%<S3kt=^i_*k(U;el;Y zLSrcH7>ur89|%!iB*NaNcalgF0=ptj?me?OkH)l_mq2aVXM+mOB%jWnD&xol289mq zwm25yczbbXk}?)7u`0QqWoYAW5}rTjXc!eS9{a?N-3>Nw-TLwMRMnddvq`C0YGUk{ z2-Orur|AVk<1gdk4ic1eP&Mbe{n~MKDVrxd8?NholpAsu zLa8D;TP-BMs5|b0+OlspF3Zoj1!Vm^RMy!^?+yB}RRIKBma0BjNu_@10` zOgT|RkOY+NK^FTS?2b@HsYaa3{qFdzMEDiu@w@IP4s1i1wvBW&5BDp2v>s;R$J=2b2y!%fyJ!3 zq$Jn^pk*(8PDxp4tAgTJ)4`0!W5G0Djci3%(wQHl`{C<@rv9D-Iw;ivov8T*GJDF@ zS)7YZxlnD|0i~-!E+_U?B&=_FYipi-cS9Go){|T0b9%C0fw902H4v!zl_4QR>WLCr@6Gug#3fN)9xyfb{>Gjhnv?_rLy}IkKkc}pX z>}JA-t@`$QphUJEhhHx^IcKDO%p4-+_P^4o`w{>)jx;h=cVG(mX7k4;%lbf^L>4b^ z1ubl$+bBA@t!b-<(!t~O^p140YN{q>ioGmeOE-Z5@`J351 zU^!1iY$P$HxhbaVQA^7s@^^#@HFhnyQ9F|_KTTTkBXn(3|6QGD&(~lFfy+TIDMm9d z`dgQOyEwsiH>?#iv4lx+%nvay{2BEb+l|C&S+g{u|mui!if!E;erD| z-rH42`;Q0pJMHt0!^!M-t<(8-AQ~iV^cM1h5bl#qJSVG;pJkX;<^j+HVL?T2@p1r4 z&R`guvQ=<&AK=)e5F#7m&isv5WIesoLqQ<1>&2-UoGcH0Y(|=Z{y?#~Y=(c$x*E zs?>hUq!GVT^!Krz*06gy*K$^62r)-dXE~zB%xr;(bVMZ<1ejtlTvSRg6$YggrF5q! zJdFQ-^|jO4av)QV6HDzL8i~CrNG7l`^J(2F1L|zj&lRWaQmXt9`0w#RRjP2%bzOub zV)-uV5!nUpnrz_pz8guG|B)&8pGcUDvU}Uy;-dlbd*BL-xa3I@3!B0ygenvfShHPo zrhQ*sSN~l>yI{Q0^+%V@tHNiZC-GulI4LvgPM@&YKHvqs_B!0>a*4R)_UK*hbwRMT ztu;s^5!(~)`-rVNso0cK#dxmP0}4}Q)8^*=*^5~&b)A1`G!{AZSI`*TXu$|~Gu{N^L0Ivel(`dql8W_VV@uM*_v;Elup-oTwCK;szPur=iD!9L8-G3uNF1b9 zn7+cx-#m>8?I+#4cKx-I#KS1+0%!Db>Me%I;ehAKt$;+6$v36R&xAkG?~jQ&kEY{) zQj)k;@&+#-X(W+I?^-IZvO;ZcpL`n%s0_Vo;=~DHoE@;{-*>;U75CCy0(4D+zDigO zL-4|h1%pa;DGT`p`2I`2U&0aUa!}*fzS0zJe?vfGUTV2*ph03d|8^Td@mdA0NWM&Hx$bCvjJy3+9< zz`EC7H>PS;&g;+@riv;?0nHw^Q{u8Y;&W0r=_MoxEXGX zX^K0STbPZE0XN5vg?&E53bU3IyPQRbxJ~du&}>!E{ImA`^K zJMuHPxcls4*Zw+-d;YGUUc zW*lq{$Fq;dUph`;@fn{OiLqGu2db8*KK&;k>jf>?cALOojO!hK<0vx3pM257pdyD` zo<;ZJM02RcmJmKe{P~3|eKQKi0%fv&ScKSO+l0?LW(Pu6mc`q0*%%5)UCKrZ45OkgAi^5BT<-L0`Zn0dimER)FXt6O>?BajNveXh4eV+HVxcCS z;a=38skKJI^FVgzf8@X33EN&nVOK7$f4L7sQ3`(9zRMbI!D>5Zv!X_SKx1~;l|VbF zL)I2>k`auRH67WIgG2(pBH7OEJ@8bn3U;ZsI>8nG2Blh(SCgz9LLy+AX#mA2hm+_t zpbG6zTj(oil|aH*hm#PygazZ)^){^zCp}g>h4R46rUXKJ5*?JYO3;v$Q%W2LBxOTq zrmu2110;J5c8vs(A!{V8k+7x-Yr^<{&>avVl=s5Yw>69A0{#Sqe0KToeDGbwnSTNX C$=dV) literal 0 HcmV?d00001 diff --git a/ControlNeXt-SD1.5-Training/models/controlnext.py b/ControlNeXt-SD1.5-Training/models/controlnext.py new file mode 100644 index 0000000..999189a --- /dev/null +++ b/ControlNeXt-SD1.5-Training/models/controlnext.py @@ -0,0 +1,136 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.resnet import Downsample2D, ResnetBlock2D + + +class ControlNeXtModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + time_embed_dim = 256, + in_channels = [128, 128], + out_channels = [128, 256], + groups = [4, 8], + controlnext_scale=1. + ): + super().__init__() + + self.time_proj = Timesteps(128, True, downscale_freq_shift=0) + self.time_embedding = TimestepEmbedding(128, time_embed_dim) + self.embedding = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(2, 64), + nn.ReLU(), + nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(2, 64), + nn.ReLU(), + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(2, 128), + nn.ReLU(), + ) + + self.down_res = nn.ModuleList() + self.down_sample = nn.ModuleList() + for i in range(len(in_channels)): + self.down_res.append( + ResnetBlock2D( + in_channels=in_channels[i], + out_channels=out_channels[i], + temb_channels=time_embed_dim, + groups=groups[i] + ), + ) + self.down_sample.append( + Downsample2D( + out_channels[i], + use_conv=True, + out_channels=out_channels[i], + padding=1, + name="op", + ) + ) + + self.mid_convs = nn.ModuleList() + self.mid_convs.append(nn.Sequential( + nn.Conv2d( + in_channels=out_channels[-1], + out_channels=out_channels[-1], + kernel_size=3, + stride=1, + padding=1 + ), + nn.ReLU(), + nn.GroupNorm(8, out_channels[-1]), + nn.Conv2d( + in_channels=out_channels[-1], + out_channels=out_channels[-1], + kernel_size=3, + stride=1, + padding=1 + ), + nn.GroupNorm(8, out_channels[-1]), + )) + self.mid_convs.append( + nn.Conv2d( + in_channels=out_channels[-1], + out_channels=320, + kernel_size=1, + stride=1, + )) + + self.scale = controlnext_scale + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + ): + + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + batch_size = sample.shape[0] + timesteps = timesteps.expand(batch_size) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb) + + sample = self.embedding(sample) + + for res, downsample in zip(self.down_res, self.down_sample): + sample = res(sample, emb) + sample = downsample(sample, emb) + + sample = self.mid_convs[0](sample) + sample + sample = self.mid_convs[1](sample) + + return { + 'output': sample, + 'scale': self.scale, + } + diff --git a/ControlNeXt-SD1.5-Training/models/unet.py b/ControlNeXt-SD1.5-Training/models/unet.py new file mode 100644 index 0000000..3fad26c --- /dev/null +++ b/ControlNeXt-SD1.5-Training/models/unet.py @@ -0,0 +1,1314 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.embeddings import ( + GaussianFourierProjection, + GLIGENTextBoundingboxProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unets.unet_2d_blocks import ( + get_down_block, + get_mid_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.Tensor = None + + +class UNet2DConditionModel( + ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin +): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for + [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads: int = 64, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + self._check_config( + down_block_types=down_block_types, + up_block_types=up_block_types, + only_cross_attention=only_cross_attention, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + cross_attention_dim=cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim, timestep_input_dim = self._set_time_proj( + time_embedding_type, + block_out_channels=block_out_channels, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + time_embedding_dim=time_embedding_dim, + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + self._set_encoder_hid_proj( + encoder_hid_dim_type, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + ) + + # class embedding + self._set_class_embedding( + class_embed_type, + act_fn=act_fn, + num_class_embeds=num_class_embeds, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + timestep_input_dim=timestep_input_dim, + ) + + self._set_add_embedding( + addition_embed_type, + addition_embed_type_num_heads=addition_embed_type_num_heads, + addition_time_embed_dim=addition_time_embed_dim, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + ) + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = get_mid_block( + mid_block_type, + temb_channels=blocks_time_embed_dim, + in_channels=block_out_channels[-1], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + output_scale_factor=mid_block_scale_factor, + transformer_layers_per_block=transformer_layers_per_block[-1], + num_attention_heads=num_attention_heads[-1], + cross_attention_dim=cross_attention_dim[-1], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + mid_block_only_cross_attention=mid_block_only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[-1], + dropout=dropout, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) + + def _check_config( + self, + down_block_types: Tuple[str], + up_block_types: Tuple[str], + only_cross_attention: Union[bool, Tuple[bool]], + block_out_channels: Tuple[int], + layers_per_block: Union[int, Tuple[int]], + cross_attention_dim: Union[int, Tuple[int]], + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]], + reverse_transformer_layers_per_block: bool, + attention_head_dim: int, + num_attention_heads: Optional[Union[int, Tuple[int]]], + ): + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + def _set_time_proj( + self, + time_embedding_type: str, + block_out_channels: int, + flip_sin_to_cos: bool, + freq_shift: float, + time_embedding_dim: int, + ) -> Tuple[int, int]: + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + return time_embed_dim, timestep_input_dim + + def _set_encoder_hid_proj( + self, + encoder_hid_dim_type: Optional[str], + cross_attention_dim: Union[int, Tuple[int]], + encoder_hid_dim: Optional[int], + ): + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + def _set_class_embedding( + self, + class_embed_type: Optional[str], + act_fn: str, + num_class_embeds: Optional[int], + projection_class_embeddings_input_dim: Optional[int], + time_embed_dim: int, + timestep_input_dim: int, + ): + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + def _set_add_embedding( + self, + addition_embed_type: str, + addition_embed_type_num_heads: int, + addition_time_embed_dim: Optional[int], + flip_sin_to_cos: bool, + freq_shift: float, + cross_attention_dim: Optional[int], + encoder_hid_dim: Optional[int], + projection_class_embeddings_input_dim: Optional[int], + time_embed_dim: int, + ): + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, (list, tuple)): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = GLIGENTextBoundingboxProjection( + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def unload_lora(self): + """Unloads LoRA weights.""" + deprecate( + "unload_lora", + "0.28.0", + "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().", + ) + for module in self.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + def get_time_embed( + self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int] + ) -> Optional[torch.Tensor]: + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + return t_emb + + def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + class_emb = None + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + return class_emb + + def get_aug_embed( + self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] + ) -> Optional[torch.Tensor]: + aug_emb = None + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb = self.add_embedding(image_embs, hint) + return aug_emb + + def process_encoder_hidden_states( + self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] + ) -> torch.Tensor: + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + return encoder_hidden_states + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + conditional_controls: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + t_emb = self.get_time_embed(sample=sample, timestep=timestep) + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) + if class_emb is not None: + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + aug_emb = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + if self.config.addition_embed_type == "image_hint": + aug_emb, hint = aug_emb + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + encoder_hidden_states = self.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated + # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. + if cross_attention_kwargs is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + lora_scale = cross_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for down_idx, downsample_block in enumerate(self.down_blocks): + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if down_idx == 0 and conditional_controls is not None: + scale = conditional_controls['scale'] + conditional_controls = conditional_controls['output'] + conditional_controls=nn.functional.adaptive_avg_pool2d(conditional_controls, sample.shape[-2:]) + conditional_controls = conditional_controls.to(sample) + mean_latents, std_latents = torch.mean(sample, dim=(1, 2, 3), keepdim=True), torch.std(sample, dim=(1, 2, 3), keepdim=True) + mean_control, std_control = torch.mean(conditional_controls, dim=(1, 2, 3), keepdim=True), torch.std(conditional_controls, dim=(1, 2, 3), keepdim=True) + conditional_controls = (conditional_controls - mean_control) * (std_latents / (std_control + 1e-12)) + mean_latents + sample = sample + conditional_controls * scale + + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/ControlNeXt-SD1.5-Training/pipeline/pipeline_controlnext.py b/ControlNeXt-SD1.5-Training/pipeline/pipeline_controlnext.py new file mode 100644 index 0000000..362e7cd --- /dev/null +++ b/ControlNeXt-SD1.5-Training/pipeline/pipeline_controlnext.py @@ -0,0 +1,1020 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, ImageProjection +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from models.controlnext import ControlNeXtModel +from models.unet import UNet2DConditionModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + +""" + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionControlNeXtPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin +): + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnext: Union[ControlNeXtModel, List[ControlNeXtModel], Tuple[ControlNeXtModel]], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnext=controlnext, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnext_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnext_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnxet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + guess_mode (`bool`, *optional*, defaults to `False`): + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + controlnext = self.controlnext + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnext_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None: + output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True + image_embeds, negative_image_embeds = self.encode_image( + ip_adapter_image, device, num_images_per_prompt, output_hidden_state + ) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + # 4. Prepare image + if isinstance(controlnext, ControlNeXtModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnext.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + else: + assert False + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_controlnext_compiled = is_compiled_module(self.controlnext) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_controlnext_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + controlnext_output = self.controlnext( + image, + t + ) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + conditional_controls=controlnext_output, #mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnext.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/ControlNeXt-SD1.5-Training/run_controlnext.py b/ControlNeXt-SD1.5-Training/run_controlnext.py new file mode 100644 index 0000000..326715d --- /dev/null +++ b/ControlNeXt-SD1.5-Training/run_controlnext.py @@ -0,0 +1,350 @@ +import os +import torch +import contextlib +import time +import cv2 +import numpy as np +from PIL import Image +import argparse +from safetensors.torch import load_file +import torch.nn as nn + +from models.unet import UNet2DConditionModel +from models.controlnext import ControlNeXtModel +from pipeline.pipeline_controlnext import StableDiffusionControlNeXtPipeline +from diffusers import UniPCMultistepScheduler, AutoencoderKL +from transformers import AutoTokenizer, PretrainedConfig + +def log_validation( + vae, + text_encoder, + tokenizer, + unet, + controlnext, + args, + device='cuda' +): + + pipeline = StableDiffusionControlNeXtPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnext=controlnext, + safety_checker=None, + revision=args.revision, + variant=args.variant, + ) + pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(device) + pipeline.set_progress_bar_config() + if args.lora_path is not None: + pipeline.load_lora_weights(args.lora_path) + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=device).manual_seed(args.seed) + + if len(args.validation_image) == len(args.validation_prompt): + validation_images = args.validation_image + validation_prompts = args.validation_prompt + elif len(args.validation_image) == 1: + validation_images = args.validation_image * len(args.validation_prompt) + validation_prompts = args.validation_prompt + elif len(args.validation_prompt) == 1: + validation_images = args.validation_image + validation_prompts = args.validation_prompt * len(args.validation_image) + else: + raise ValueError( + "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" + ) + + if args.negative_prompt is not None: + negative_prompts = args.negative_prompt + assert len(validation_prompts) == len(validation_prompts) + else: + negative_prompts = None + + image_logs = [] + inference_ctx = torch.autocast(device) + + for i, (validation_prompt, validation_image) in enumerate(zip(validation_prompts, validation_images)): + validation_image = Image.open(validation_image).convert("RGB") + + images = [] + negative_prompt = negative_prompts[i] if negative_prompts is not None else None + + for _ in range(args.num_validation_images): + with inference_ctx: + image = pipeline( + validation_prompt, validation_image, num_inference_steps=20, generator=generator, negative_prompt=negative_prompt + ).images[0] + + images.append(image) + + image_logs.append( + {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} + ) + + save_dir_path = os.path.join(args.output_dir, "eval_img") + if not os.path.exists(save_dir_path): + os.makedirs(save_dir_path) + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images = [] + formatted_images.append(np.asarray(validation_image)) + for image in images: + formatted_images.append(np.asarray(image)) + formatted_images = np.concatenate(formatted_images, 1) + + file_path = os.path.join(save_dir_path, "{}.png".format(time.time())) + formatted_images = cv2.cvtColor(formatted_images, cv2.COLOR_BGR2RGB) + print("Save images to:", file_path) + cv2.imwrite(file_path, formatted_images) + + return image_logs + + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a ControlNeXt training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="Path to pretrained controlnext model or model identifier from huggingface.co/models." + " If not specified controlnext weights are initialized from unet.", + ) + parser.add_argument( + "--unet_model_name_or_path", + type=str, + default=None, + help="Path to pretrained unet model or subset" + ) + parser.add_argument( + "--lora_path", + type=str, + default=None, + help="Path to lora" + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--controlnext_scale", + type=float, + default=1.0, + help="Control level for the controlnext", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--output_dir", + type=str, + default="controlnext-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--negative_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnext conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", + ) + parser.add_argument( + "--save_load_weights_increaments", + action="store_true", + help=( + "whether to store the weights_increaments" + ), + ) + + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." + ) + + return args + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + else: + raise ValueError(f"{model_class} is not supported.") + + +def load_safetensors(model, safetensors_path, strict=True, load_weight_increasement=False): + if not load_weight_increasement: + if safetensors_path.endswith('.safetensors'): + state_dict = load_file(safetensors_path) + else: + state_dict = torch.load(safetensors_path) + model.load_state_dict(state_dict, strict=strict) + else: + if safetensors_path.endswith('.safetensors'): + state_dict = load_file(safetensors_path) + else: + state_dict = torch.load(safetensors_path) + pretrained_state_dict = model.state_dict() + for k in state_dict.keys(): + state_dict[k] = state_dict[k] + pretrained_state_dict[k] + model.load_state_dict(state_dict, strict=False) + + +if __name__ == "__main__": + args = parse_args() + + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant + ) + + text_encoder_cls = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, + args.revision + ) + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + variant=args.variant + ) + + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + + controlnext = ControlNeXtModel(controlnext_scale=args.controlnext_scale) + if args.controlnet_model_name_or_path is not None: + load_safetensors(controlnext, args.controlnet_model_name_or_path) + else: + controlnext.scale = 0. + + + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + variant=args.variant + ) + if args.unet_model_name_or_path is not None: + load_safetensors(unet, args.unet_model_name_or_path, strict=False, load_weight_increasement=args.save_load_weights_increaments) + + + log_validation( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnext=controlnext, + args=args, + ) \ No newline at end of file diff --git a/ControlNeXt-SD1.5-Training/scripts.sh b/ControlNeXt-SD1.5-Training/scripts.sh new file mode 100644 index 0000000..9dd4edf --- /dev/null +++ b/ControlNeXt-SD1.5-Training/scripts.sh @@ -0,0 +1,39 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --main_process_port 1234 train_controlnext.py \ + --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ + --output_dir="checkpoints" \ + --dataset_name=fusing/fill50k \ + --resolution=512 \ + --learning_rate=1e-5 \ + --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \ + --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ + --checkpoints_total_limit 3 \ + --checkpointing_steps 400 \ + --validation_steps 400 \ + --num_train_epochs 4 \ + --train_batch_size=6 \ + --controlnext_scale 0.35 \ + --save_load_weights_increaments + + + + +CUDA_VISIBLE_DEVICES=4 python run_controlnext.py \ + --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ + --output_dir="test" \ + --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \ + --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ + --controlnet_model_name_or_path checkpoints/checkpoint-1400/controlnext.bin \ + --unet_model_name_or_path checkpoints/checkpoint-1200/unet.bin \ + --controlnext_scale 0.35 + + + +CUDA_VISIBLE_DEVICES=5 python run_controlnext.py \ + --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ + --output_dir="test" \ + --validation_image "examples/conditioning_image_1.png" "examples/conditioning_image_2.png" \ + --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ + --controlnet_model_name_or_path checkpoints/checkpoint-400/controlnext.bin \ + --unet_model_name_or_path checkpoints/checkpoint-400/unet_weight_increasements.bin \ + --controlnext_scale 0.35 \ + --save_load_weights_increaments \ No newline at end of file diff --git a/ControlNeXt-SD1.5-Training/train_controlnext.py b/ControlNeXt-SD1.5-Training/train_controlnext.py new file mode 100644 index 0000000..3f1a296 --- /dev/null +++ b/ControlNeXt-SD1.5-Training/train_controlnext.py @@ -0,0 +1,1164 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import contextlib +import gc +import logging +import math +import os +import random +import shutil +from pathlib import Path +import cv2 +import json +import time + +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + UniPCMultistepScheduler, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module +from models.controlnext import ControlNeXtModel +from pipeline.pipeline_controlnext import StableDiffusionControlNeXtPipeline +from models.unet import UNet2DConditionModel +from safetensors.torch import load_file, save_file +from copy import deepcopy + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.29.0.dev0") + +logger = get_logger(__name__) + + +def image_grid(imgs, rows, cols): + assert len(imgs) == rows * cols + + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + + + +def log_validation( + vae, text_encoder, tokenizer, unet, controlnext, args, accelerator, weight_dtype, step, is_final_validation=False +): + logger.info("Running validation... ") + + if not is_final_validation: + controlnext = accelerator.unwrap_model(controlnext) + unet = accelerator.unwrap_model(unet) + else: + controlnext = ControlNeXtModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype) + + pipeline = StableDiffusionControlNeXtPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnext=controlnext, + safety_checker=None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + if len(args.validation_image) == len(args.validation_prompt): + validation_images = args.validation_image + validation_prompts = args.validation_prompt + elif len(args.validation_image) == 1: + validation_images = args.validation_image * len(args.validation_prompt) + validation_prompts = args.validation_prompt + elif len(args.validation_prompt) == 1: + validation_images = args.validation_image + validation_prompts = args.validation_prompt * len(args.validation_image) + else: + raise ValueError( + "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" + ) + + image_logs = [] + inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda") + + for validation_prompt, validation_image in zip(validation_prompts, validation_images): + validation_image = Image.open(validation_image).convert("RGB") + + images = [] + + for _ in range(args.num_validation_images): + with inference_ctx: + + image = pipeline( + validation_prompt, validation_image, num_inference_steps=20, generator=generator + ).images[0] + + images.append(image) + + image_logs.append( + {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} + ) + + tracker_key = "test" if is_final_validation else "validation" + save_dir_path = os.path.join(args.output_dir, "eval_img") + if not os.path.exists(save_dir_path): + os.makedirs(save_dir_path) + for tracker in accelerator.trackers: + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images = [] + formatted_images.append(np.asarray(validation_image)) + for image in images: + formatted_images.append(np.asarray(image)) + formatted_images = np.concatenate(formatted_images, 1) + + file_path = os.path.join(save_dir_path, "{}_{}_{}.png".format(step, time.time(), validation_prompt.replace(" ", "-"))) + formatted_images = cv2.cvtColor(formatted_images, cv2.COLOR_BGR2RGB) + print("Save images to:", file_path) + cv2.imwrite(file_path, formatted_images) + + del pipeline + gc.collect() + torch.cuda.empty_cache() + + return image_logs + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): + img_str = "" + if image_logs is not None: + img_str = "You can find some example images below.\n\n" + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + model_description = f""" + +These are controlnet weights trained on {base_model} with new type of conditioning. +{img_str} +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="creativeml-openrail-m", + base_model=base_model, + model_description=model_description, + inference=True, + ) + + tags = [ + "stable-diffusion", + "stable-diffusion-diffusers", + "text-to-image", + "diffusers", + "controlnet", + "diffusers-training", + ] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a ControlNeXt training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + ) + parser.add_argument( + "--unet_model_name_or_path", + type=str, + default=None, + help="Path to pretrained unet model or model identifier from huggingface.co/models." + " If not specified unet weights are initialized from unet.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--output_dir", + type=str, + default="controlnext-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument( + "--controlnext_scale", + type=float, + default=1.0, + help="The control calse for controlnext.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=1, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--save_load_weights_increaments", + action="store_true", + help=( + "whether to store the weights_increaments" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnext conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnext conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="train_controlnext", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") + + if args.dataset_name is not None and args.train_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnext encoder." + ) + + return args + + +def make_train_dataset(args, tokenizer, accelerator): + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None and os.path.splitext(args.dataset_name)[1] != ".json": + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + elif args.dataset_name is not None and os.path.splitext(args.dataset_name)[1] == ".json": + with open(args.dataset_name) as f: + dict = json.load(f) + + train_dataset = datasets.Dataset.from_dict(dict, features=datasets.Features({"image": datasets.Image(), "pose": datasets.Image(), "caption": datasets.Value(dtype='string', id=None)})) + dataset = datasets.DatasetDict({"train": train_dataset}) + else: + if args.train_data_dir is not None: + dataset = load_dataset( + args.train_data_dir, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.caption_column is None: + caption_column = column_names[1] + logger.info(f"caption column defaulting to {caption_column}") + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.conditioning_image_column is None: + conditioning_image_column = column_names[2] + logger.info(f"conditioning image column defaulting to {conditioning_image_column}") + else: + conditioning_image_column = args.conditioning_image_column + if conditioning_image_column not in column_names: + raise ValueError( + f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if random.random() < args.proportion_empty_prompts: + captions.append("PBH") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer( + captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ) + return inputs.input_ids + + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + conditioning_image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + images = [image_transforms(image) for image in images] + + conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]] + conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images] + + examples["pixel_values"] = images + examples["conditioning_pixel_values"] = conditioning_images + examples["input_ids"] = tokenize_captions(examples) + + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + return train_dataset + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.stack([example["input_ids"] for example in examples]) + + return { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "input_ids": input_ids, + } + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + + if args.unet_model_name_or_path is not None: + unet.load_state_dict(load_file(args.unet_model_name_or_path)) + + if args.controlnet_model_name_or_path: + logger.info("Loading existing controlnext weights") + controlnext = ControlNeXtModel(controlnext_scale=args.controlnext_scale) + controlnext.load_state_dict(load_file(args.controlnet_model_name_or_path)) + else: + logger.info("Initializing controlnext weights from unet") + controlnext = ControlNeXtModel(controlnext_scale=args.controlnext_scale) + + # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files) + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + vae.requires_grad_(False) + unet.requires_grad_(False) + text_encoder.requires_grad_(False) + controlnext.train() + unet.train() + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if unwrap_model(controlnext).dtype != torch.float32: + raise ValueError( + f"Controlnext loaded as datatype {unwrap_model(controlnext).dtype}. {low_precision_error_string}" + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = [{ + "params": controlnext.parameters(), + "lr": args.learning_rate * 10 + }] + + pretrained_trainable_params = {} + + for name, para in unet.named_parameters(): + if "to_out" in name: + para.requires_grad = True + para.data = para.to(torch.float32) + params_to_optimize.append({ + "params": para + }) + pretrained_trainable_params[name] = para.detach().cpu() + else: + para.requires_grad = False + + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + train_dataset = make_train_dataset(args, tokenizer, accelerator) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + controlnext, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnext, unet, optimizer, train_dataloader, lr_scheduler + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae, unet and text_encoder to device and cast to weight_dtype + vae.to(accelerator.device, dtype=weight_dtype) + unet.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + + # tensorboard cannot handle list types for config + tracker_config.pop("validation_prompt") + tracker_config.pop("validation_image") + + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + image_logs = None + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(controlnext, unet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0] + + controlnext_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) + + controlnext_output = controlnext( controlnext_image, timesteps) + + + # Predict the noise residual + model_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states=encoder_hidden_states, + conditional_controls=controlnext_output, + return_dict=False, + )[0] + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = controlnext.parameters() + # accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + os.makedirs(save_path, exist_ok=True) + + controlnext_path = os.path.join(save_path, 'controlnext.bin') + save_controlnext = accelerator.unwrap_model(deepcopy(controlnext)) + torch.save(save_controlnext.cpu().state_dict(), controlnext_path) + del save_controlnext + + if not args.save_load_weights_increaments: + save_unet = {} + unet_state_dict = accelerator.unwrap_model(unet).state_dict() + unet_path = os.path.join(save_path, 'unet.bin') + for name, paras in pretrained_trainable_params.items(): + trained_paras = deepcopy(unet_state_dict[name]).detach().cpu() + save_unet[name] = trained_paras + torch.save(save_unet, unet_path) + del save_unet + del unet_state_dict + + if args.save_load_weights_increaments: + save_unet = {} + unet_state_dict = accelerator.unwrap_model(unet).state_dict() + unet_path = os.path.join(save_path, 'unet_weight_increasements.bin') + for name, paras in pretrained_trainable_params.items(): + trained_paras = deepcopy(unet_state_dict[name]).detach().cpu() + save_unet[name] = trained_paras - paras + torch.save(save_unet, unet_path) + del save_unet + del unet_state_dict + + + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + image_logs = log_validation( + vae, + text_encoder, + tokenizer, + unet, + controlnext, + args, + accelerator, + weight_dtype, + global_step, + ) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/README.md b/README.md index d171ddd..68428dd 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ We spent a lot of time to find these. Now share with all of you. May these will - **ControlNeXt-SD1.5** [ [Link](ControlNeXt-SD1.5) ] : Controllable image generation. Our model is built upon [Stable Diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5). Fewer trainable parameters, faster convergence, improved efficiency, and can be integrated with LoRA. -- **ControlNeXt-SD1.5-Training** : The process is quite simple, so we do not plan to invest additional effort into it. You can directly use the HuggingFace examples. Please refer to the SDXL and SVD sections for our newly updated versions! +- **ControlNeXt-SD1.5-Training** [ [Link](ControlNeXt-SD1.5-Training) ] : The training scripts for our `ControlNeXt-SD1.5` [ [Link](ControlNeXt-SD1.5) ]. - **ControlNeXt-SD3** [ [Link](ControlNeXt-SD3) ] : Stay tuned.