From aeea5f00784f318e437579fda7e8f4ef1a83a8f9 Mon Sep 17 00:00:00 2001 From: DoraDong-2023 Date: Tue, 21 May 2024 21:02:46 -0400 Subject: [PATCH] Add functionalities to pipeline Add auto-correction for execution errors. Include classification for complex and simple tasks. Provide an exit option in the API selection step. Add GPT as a side option to support inconvenient code generation, supporting generating codes from GPT in each single turn --- .../components/Chat/LibCardSelect.tsx | 2 + chatbot_ui_biomania/public/.DS_Store | Bin 6148 -> 6148 bytes chatbot_ui_biomania/public/apps/GPT.webp | Bin 0 -> 14634 bytes src/deploy/model.py | 210 ++++++++++++------ src/inference/utils.py | 10 + src/models/__init__.py | 1 + src/models/dialog_classifier.py | 127 +++++++++++ 7 files changed, 286 insertions(+), 64 deletions(-) create mode 100644 chatbot_ui_biomania/public/apps/GPT.webp create mode 100644 src/models/dialog_classifier.py diff --git a/chatbot_ui_biomania/components/Chat/LibCardSelect.tsx b/chatbot_ui_biomania/components/Chat/LibCardSelect.tsx index 985c385..de28cc2 100644 --- a/chatbot_ui_biomania/components/Chat/LibCardSelect.tsx +++ b/chatbot_ui_biomania/components/Chat/LibCardSelect.tsx @@ -20,6 +20,7 @@ export const libImages: { [key: string]: string } = { 'snapatac2': '/apps/snapatac2.webp', 'anndata': '/apps/anndata.webp', //'custom': '/apps/customize.webp', + 'GPT': '/apps/GPT.webp', }; export const LibCardSelect = () => { @@ -57,6 +58,7 @@ export const LibCardSelect = () => { { id: 'snapatac2',name: 'snapatac2' }, { id: 'anndata', name: 'anndata' }, //{ id: 'custom', name: 'custom' }, + { id: 'GPT', name: 'GPT' }, ]; const existingLibIds = methods.map(lib => lib.id); diff --git a/chatbot_ui_biomania/public/.DS_Store b/chatbot_ui_biomania/public/.DS_Store index 144ac804483aab3608169f029a3a9421806e6509..ab4150e409c116f82545fb56ebfb68de310f3bb3 100644 GIT binary patch delta 21 ccmZoMXffC@gNeh?++0V&($H-40;UvE07o7M6aWAK delta 21 
ccmZoMXffC@gNeh`$UsNI$iQ;*0;UvE07je!2LJ#7 diff --git a/chatbot_ui_biomania/public/apps/GPT.webp b/chatbot_ui_biomania/public/apps/GPT.webp new file mode 100644 index 0000000000000000000000000000000000000000..8113bb1e4a2ea34f346605f1ca41831fef22f1cb GIT binary patch literal 14634 zcmeHuRa;y^mu@%i?%KG!1q~ep61P`u(0FAo_5AN>n9)e5I5E#BW=LgKq z+|7QTeYLN4tyQ&ND>XS88CqcgKwJ8Qs+Q_I9W(#{fb;fdApqV30boT{S#Ci90Qky8 zFMIjwWd|ud-@yiyTBtm|ir(Ldolpt*fVc}yLS%=Il1ON{SxT<)7-@cx5>XbIFaPL~ z{Lq$|M6rPhuVr7EBN8|qWAICV=@w_|CGzMMAqf5GQE@GuDX{UpI}u%o7Y&z>_x$_n z?b4EZcw@DBh^hELy?%z=+Jc9_zG>~7{c7wHM)3nor^8n-gCkY!19oDf0cJp#$p>#I zYiH7fU95sHcd z+$2w%GTE|Ag&opA5U@dAf(aLVEJ~sFUw(a2)VPRS9f0u5gBd=R{3U|l!8M93?_4fl zj%Y)M?!-ZAOew#5R+ipLEyOsZJYc5gQX1Gxo$-}|e_<$GVRW45Q3)Xt4qfut@)eoZ zp&v3Z=kd|e(HSANuheEB)>sTbjvnhey44ue-wa9MV`RxcI?lZQ_@AHJ5hDPQ7EyBSIEKF zEZO-=_1aLkQ@ymQoFYC6$Gfy%jR0wHDs9|})BcS$TV0NjM9$;65EN;(r_j3bNCBDs zlLPDK?J9$-+b$6d_uW{MZRh1&8~tlhL;?00{I<^ORU3U-?}k$fLdxSA>di)OkW%i#ZbDb8swku1Q=ygrQDBw~+lp~F<12~$sfSBo6D=n_ zz=iGpi9D?L0un^~i2`zKWcRz6$Vp-O-ym#IzUv({BCKG2P6E8sCmGd=+FCl+YAMu^ zU~lq17|&P#Lu@s+$9rF5>D)_o@IL4KNg3<8_oG@{Tn+fU@x3?}Vye`P)i!IBCBY)jLNOAaT3ySk<_n*1SZr zoA?y%-XEinDafW5R9l5t-J-V}BTq16UwTmBT#2)SLX3dl)=EHY zw@TpAOX_S5ukJ0S(^8mS4xXl>Ow-WM(1ygUvP;L9PdfetVnl_Q8G1qbj@H~*jyjC zIZJ=t(uv%?Kpcj?ATzm$Qz6>^$-XG;R)6`dH`N|UVbcXi>c80)uRpg{8^rt5LfyDt zH0`y*H?h4D;~yuD_F9Ols!g^6!MolwR#QEsStG$KUk}Rq4SiyQPp0Gp|5(p8tCncS z;)8w5=tW(`m+I+}haEnmEwy@T|k2?&=D)dS} zIn#f@o$4s@uz~o#>b=*_hrvKDqjzlD5n3VoKO~d`KEG zZBz3i;$B`N*#9aLhuu|_Beq~45-=NnrA5OkHuBddC}-IFyB(ev%I_sNUf?q`WAv^r zeV52IfW_`5R=SCUl?L69fmuirKX?p!7EvLNc*#^@#u;TnndInz^!iUuQ8b#k)CJ$~ z?FU-52);=S(#na0myNI!(^YSExg@UG94x{TtA{@nDo2GQ-`Ghvuu>HA53I|~l>8h8 ziX!@#gh{lP23YkbcRCOI=AmCR;+1sR2k~>rVKo4R*C@cFsS8CU^ha^J@XlECyHlGg zcX8Yb_a4G}YmsTG;Su{w(IAdV3^grT585gL9$C044W5UPYefas^iOa(`_i=cFV!;3 z9P}Ntm{gfY3MOH5@L>5amFq+a_rfSjMPhzk_d`c$?rWM~NOQu59&uj-uH4d>5dA8M zgC9{Le?Ax8R4B)a08xXS^Jax=35$u!sn|umtA-l~Uc&54Cy-TS6S!|*kGvuhSUo)6 
z!pOKJWkXc~t+DFOnZNz`a?wH%TS?$4`XMnj0rvt`EltS<^sC+GrP4T<#J-&OhO2a~ zAW!9JQQm3MiKblzw6o%*n()?tvcLZ+W4Uzkc5lxcPQMEy-O3BNM-oe9b4&+h{lxi^ zaA|C#Ivu_WvdTqg-kX&9%6{MPS5j+Jy4RS-d5Z7}4h(jcoNUX2stpla zEO*QBe4H1+yKVm?x6Z=i@@#;(M@`T*i>|@15uFrs#(dEcm7_0#r+!pct}IAeU733` zSlbCzI{qlr@ovAdjBRq2q5Dc@Gqd13Pj#;oLDvMiK0EI1$kv-t$oDVPXCwLg1!RSRy_qRe?qmxVrQ_ODGK~MHp}msoidZirv5>z zk$RREU;eh|R1Ux@)%8Pog?o9+gB5W2f8d~=$rbraj9Y|w-MjN`Y?%@<+n`)Ga#Uf8 zjRmGmAX|y+uCY!C{9xmzmKOC5VlyU7jZRr30{}FhM@A5MiYQk;@Kk7M(1cD6wskl6 zH+&F`^H7=XrJObPB?h%iE|+Yh@vnrYESX|>Iqhxe<7jI#ka1bogea~hoNT6bi|%4` zdNDkpt`el&BNKy2r5zhv00+t0vKd+1@Ypu}>$@RgQNi?F(ZV$_=ivt9eAMN8C9{w% z31L1?*%ac+4qFeXXRoAhC#tvX(Q}>{0gtf{5Wk0U5dY|Sly zhxob?CotIl$V(H%<6jL<5i(N(HElpSkMy;4HCT8^V3`ubvo9jUYcEvj&kv}&#+(yI zpQ0{xCt|@pty4jp^7N~TwcI>7U$RWkxxwQ~m$-WzrC&_jCx0+UB7Hm2H-kQmC)eo| zp5aRg7&s2vKple@#h5mvMz<>9Y>wc_@%v-^!-A9)g$LS0Dds)}tslg@bm!)NJji+A74W_;251{xYbD9THv#8zl0^@Q`>s4sXDJ z#aH=aCUBPy>kIc;<+br8V40GEzmXAb0e!Cg$6bWrT~TOH6*>SL)h03i%VgL!vpo93 zG%k~pUju^!!%K=g=!o4gsaYyrBHRtaD)eiew5DGx*avM(&53499+=?+ENXwtX-C0j zE?jmmK$ayx@i5F^dUDm2Wb&Xl84g9L{A2RgH;X~Snltx$KYW7) z%}g&=v@tzB(40>@4*NHJDND`*(4*Q9d9bqpWYAm2d{F%LTPsxWeI6TgL553R;JWk` zaVHr?d1z#dg;2uvHFh20Egm=XBV+XB3U^uk?lJ9|=|fi(WLaDN7J4c(tsF-Eec08z zh2U-Xk0RbkT!#*r258Az(Q{Sik>rkQQm+E%>E?J^Pj*XD(blKV z&E=X_f-|E!zgE^@^OyrlJEPI_rN&X*Q!K<-@B8uGW$E+qqT?K?QSU$k&sG zs=Ce+7aP*mi#2_O3&b{i)Rq{y%CkuIO|9xo8-%#my~dFg2!~9d!GF443p_+($z?jxRjEgO$s!U?PPR_k`s+d_8LR< zg`J0G^g0T6$>wQZZ1TQ!sbqcw1b_upQ zUBUHKToq}>vRMey&tFlH?T)@@cu6<1jb;dI(Km3g1F_lbhGl~8q!|hU7oINkB$Z#U zmCN}9rF#rnPg&I)`IsM9@#;x0bf97A;wiH{ zG!ozVtzz@LIccpl&-fvAskf5eFVf7W7dHrzOd|5YX>;r!W3ch9;T&}>W6lLc@74W0 zxVg+RjT1Q58+}cS{(^lgsd*j}o2^jVc=}%K#yJDUro86AsGF+h@h!dxk>qxAK$$-l zT3alANpkK;^Z)n=wUUd#jI2D(1!8Bs(7fhjwq5s>fKX}DENhC?=)+tow~@juTvLzG zNs6r$#5{{GXB^|5uacyIXNgHGel=1q(5)}piM*g*l8h%b@zKD7R3}otgEH?+I>q%4 zg{DUBoIOJD7M=E#mLGW7zMzvYA* zm$iR0mDg$=8H@?hQRjM%wkgm1^Nz1pV4x5WbnXjA3o13!c&A3_TKWRlsFJt33%3o> 
zl0a5nSR~emr=Z_LDg1|nfSv!O!G%09KNIKf!Ci}$X&@UE$2BGkq)I-=Kj?x*R%tjg zS2Ji9@4%UB?f$_p|0`c==d?(>p-g6_3Ed=xEKTSM(-!A_UW5Stl2xGMTyz|;yv+?l z{tP`iR2dOX@d0lF91}ufKy%ikqGW2|3E8&sO1Y>_Q^@52WLb`Fs9&rmea=8~bg~2m zXX-8i4#b#@5dTmCqrfBfV#dOPy1ODe(Pj+_YTbV}Y;oe~8y#Nh6d)DYA> zv%^QeM(^cPvcNCKBqnZIcN-d`5Y%|UYf#im4Z@1UBTSo?i_Y+Ps^wxq#7lo(R}N}= z%;OERC{Kp_%d$;a?wCUSu~C+?Nqp|F5JnKP??P-qpaUnhqUANPeq4v(096{BvTkz{ zu;qq%IJ$sp(}mc31$n2(v^L)^6ub3h_or-zytgl&y(fi+QKLoN?i)KEugJr z`}+U2*#Zg}!1Kif&ZObIP1*$e`8!|62FqqS)z|C~h#a(uHAi>Xl zF6_c_fE(rA5-D@)RO6gFCW6EMJ@+zqqhH&o2=?YrS0cD%!d;vj7BXclV=ZS6m5d6N z!+9-b1*Ry`bYIP$m9Q3wwJ2*?Cvn|(NW%CB^euv=NWMy5_s07NwLMytxATrNKX;Y^ zr3F|2O_;YS79@J0CEOKeRwW9X5=n>{_sjzVaO$&Qur_<4$ZO2oGn(Gn*{UkIc?oq& z3gyhE*N(_#TJ2AMZHly7(uptl66*Kl>i-!Shs5(cn0C+M6N= zi1zQ0jf{Kbqr+adO3CE49>|>RJ+py#^uF+ALSFz1H|dc34=J@3f|2*S<3P4Cb#fa| zz!f{~X;J)AKZSpxt{ju9EN-DW$6Ro`Y z)d45JYe{G5>0ed>O^2!cLL9Re4T#4}q~i&LE-{Q;k!W|gE4>Z4oJeoO9sDV>6;)y! z^xVYlXH?##!6$Hux8U`SQwymxhwb-E;gtg2(;VE_e=14F#4FVc2v5RpSw+9Y8n)mc z?I7Kmx5Yl{@Kt5SvTh}4PIrIP$DzUI^s`jgz{OYx z`Yu_I3M6j+3x+X%b6qfjG;xZ8n?AXd{J6ff`!v~Z6H+@!hJ-WC=H0>=S*b!Qb9^J_ zC8uu~8dKtF!AL3GcuDd;qC%=Vw$36atXDoMrD;}QXs#z7^6*{o1@Fg4C{$$GnoiPJ zjYW%FN(YP>wOM#v6Kt!oHNr>KBJ@MONR=P58qh7y91CYwM24pLE3q-bTx>P~4Wpox z$i}QV1CC3;eHt7il7)i9H{|V9e~KS=!{C!%9)ZNF%8_oF)q_>6W+EQKz~2RYjiO&O zmxrsiI9JC^+GmC;VGT;>vr`*DVtNj%$tC|YB1?YqhGcr@U-#~RQ(OSUGn8r#Cuw6o z_!g;7;1DM$B@r@$Fa6Oz!NkgX?2Svn+W)2eE z$^?^1zs$;s$S#ZY=*E0OW&{fY$+af3Pn^RoU*L>Velcv5%QZ^INki(ydMP0%Zb0q+ z@2HKg2fK`=Ml1evB$!g8Q&q~46WoK#7er}@ZJ~t)*Uu!wj}A<@7z2g z_>lKy%9Y2_8>1)ptDh_tDB^-5> ze4-_hra0wtWYe(ok~ZE$mUU=pwd^-i5VD_2@RV$lt;s@&btu6}q=@U)2)&4%^{z1Z zqYgXKO=bmP``(!fxO?qH1uP4xEn`mVM`8GJyw?}|7F;DVo@uhlppdi~u zirYdS2DuVQuaq4C8|KjNf?e>oJ>BH*rYk0S*fikL&k{3c&Oe zH*mMdg+fdjy^WfA1X&|nKsECio8>N~-4S1jP@S0N6;T&n)??Ip&G-&+0!pfkyRZY# z6y@?PFe9ZJV1wi&kF)!`Iz+@4WxUQ4<KSx9YG0rlY@Rnt>NIV-SoNM22H^! zL;&&z!mjE;6oUZy9NY-6Uz$JO7Uv+Pb`B}D+n#4xA?~xrg)Izy;V-WrI8;<6DM|%-LJ&_%4jVrZVK>$>e8}KsQUp? 
zj8grqm=?oMk(4@`ruy{dLzq~&2gr7&4W+%SirsH48MvFnSUrx-UXG&2I3!FPQSub0 zhA3MrRUgiO!~~L-I^yb~s**Nd|ZR`-qhIHl^Nxwon@eVto7JUIjsha}owVs8mBU6G4`ds#_ z?OT2Sc#kVHpK*0U1=geHE?Bp*QAghzM}lJgCj!g6Ve4jDruP8 z*lz1G)YTB)MSTaC|CC&i4-c;Qm*o3l-PbTbRFnzkCoZf!r++daK5+fA^79hz$omer z@Mp${9~e%FC~K^o^K$zbi96COgEQ>xt0#J7r}Da=@nU0=Zs6Xe9i?v3Uhbcqe3y}9 zKKsnS2|_fB98F_nhjBU(auF}1#x?BJKi!RhA3u8s)q_gbV%zs zbGp&krFX2KQ5P;m$iO%1GZKF5S3lF#+<*rCJ7q5G09sE8(^K4=_bfiDD^d#{ImiB# z9grKVnPQ1UW+;roz7$=D1&b@}2mF!q@5Ehgp5!8wf?f&vijeticm4!Ri<#p0hmlg^ zQ%f=nV$NQ9c<(QjhW)=}&{$RzIK|o4&6<{-fGw_qxVZEKBK7(?N`1_HmXI7b>F^k#e8V%)aKNTwkJO-yDrX=vjV}wQk)0|;| zEByuvD3fYs{c((3VjoHjCC51C5nb#$lS2~~FrW2BL}eC*f3l$G z4cnRnnEyll7rklP-+EnDXrZ4X1>ptO%MeW#J5P)$*O@S`CaM0@T7v#6-iN8tV@jKD zdPy$y|B<=FEriJ5&O{Q$8Edkom9jj4%}F6zs>*VQuI6*#q}wEEkb~0mBNGK*Iom87vrB?yy*2d|x z*(^q9CXk zr2Voe`|6fC1^L6^{VVZv@*U^5KTy(@>M9mQlF%Al(77G!`I{sN)s5-zSAd@zaY}Zv z1U>pBFhR@PWJOgWnn{%LE~%8OeS%Ms2ce6tJ-mPi|CEnGlNVz#WtXA&?oFEGPTsWN z&6hv%;D)*VV-TTTAcqJ@laLbJO@C~`hQ4n7-eZz^tAHgbfR$c!faZ|#fnb73g8Uon zNn8f!b+<^H0#Keu5cDX39aWp8m}TQKL2Jy1h~?$RH~ACb4PI@TM7#QL#gdXyvVFdW zc_uxvs(dtm`3cXxxX&n+7x;S*T&zvcYE!TR{{dfCOQNk2tnF0DSx)o!U`H)T`0lMD zO-R^tPNHh7x6{zC8=q+R5GV(k>FX0Xo4;sO(EA|T5J%aln4(=Yx{Sn7VS zXQ53{!<{NmZQ_2ubp9gpu6)a9x*36IMf|llf;0y~3;GXFSMfaaKw?pGLzw>ZyE;Rq zFl32zTRw&iFJb2wXazBu-3a##zGqL)sTPt;N5p<&;ymNgit?ruaM#m=f7Q#Y1u@)l z0-a^rnXC(VqA8ivhb)j-cO7sX@>1!A1@oLqI{- zLo@re|1}UK*nRMXVWS=Np`bKi;X?`}QroN$;N?Nt#c{j+dTf%O;g)QEXhy-U+!psZ z%w3V<_f9`LQ!wAxuy?bh6&Do6azKOQ+WFb0MWeU^(xGYErl_H#hFoWAua6_NlnMvC zkTxXt>65ah8D>1hPUxUSSL7cND8vy~hC^KEr0KJze?j`9UD89EgfW<(XIc@qYl?HzzGaGTt0pc*$0CChT9$ z7CSqQwP028$MRo4x^IJLGl)S`R4|0ok!n+-1$TTf2Bslqq6uU4gnXP!)Q6iW+6B0^ zsM<~>k4REvz?+=M4Hlb$q6PiBcUhi8h%OJk_TcfveBAb@@+u3XPi>}~I)Fu1qFCmL zCipPoJef~#^M)Gn{RBb8BzVSeAE{?^00USBnFYb>riNc$1ve)UuWF9ff^c0!kDqYh z`T0o)IFS*3a<+q;Y^Zo4R-h^01I^%)5FOGjI&Xip44E(LD`5s7<<}}Wr-1p*8hhzV zDTig9{5-c*fWhsPH}gaWh5-;bPudSWMTQgDc9xnpCY0n8#qQq_SWbOLY#j{YbVJeN`X=&oIYHM`1Iiqy&2OPp%u9LaJ60(8aSWPmowOm=_08MSJh 
z*vNTgBtfP(@F{>?7qFcgsaB5jw_*3}Je+5yAQiFM`=VgSGKyETgGMTwS2%4eteGv? zx|T*D3m4A7)9d>nv=L>-NaRm0C(b|?6%V)`vUi{(Pt^hoYU2Ly2-Jcyf27TQg)lTi zkXd{Ln5iU(bj|ds|+Z|IP!tVJdlbS9uK4zG zdt`&I#BuO_lvBttYvf1-ZKm0UQ_uQ{guXmy6J4IWGIykQPTfb<>MmAe96HF*pf#Un zK_SJy7`{em6KSWL_&KO=O~O3&dic`ZfU1Kpp{;7{h8T*++!B&~)N-7Io2*3;^pKcMS&BL7i8 zn$9GV)D|QiSqiC8UzML%l&%d%A!n=3%$U!A!-zW}Ed{s+CCRehZEVx!lP-E1DIdh{ zFF!#3p8xL4L?G&^cnOMMJfOKJH-XXi{}ab>#*L{OzK-UC@3M1?)DHKaMo<}%arr2f zZUr}+Bare3Z^3jx6EB$q4mM`PdGwIauN12(`%QQfzCjR(n9EH7XOiA3{&2jQ`kqnA zH*ht7|GiN5kAkoFx$x37IsavWMM3PNx&VtI=g+=LIV(+P<2e17Z0;RLm=) z$;@LP9ZXCq!f`GmJH`;M0z%T?y%|UJZuIFumw~Wfkv^CydTi=wkCYu?gUHE(C@w92 zs$%=K`=(EcXUd_7&x3Dix2cyN+Va?~14vx6ltIp*>`7%lqt1sf@C!t=Nd5zWFN~|- zKsm8Q=5H~yQ3uSRyxqluZN>MElE@WtgTb>@8R_*ty0g7SV~giw9~R?ex_{V{VME*~ zlqojL*I%z{R>WU^oeaC)U^)#-4H$4+n6-{B`mC+~K?$nbJB;diAHCHf-gG-j$CAlSF(z3_Wq7Nx)=3WH7VMIN`&wx%kUwhlkr)(xhN2 zq8enSFNoj&=iOM}!{9Gao~y3Kb~*1D7QsJ4X0HZyE(MIGfjhX%+tm`mkH}8`i}?QV zqnu`RZPv^VT^ZVPfKdkF3{cMT^Q({khf#Qrpqf$3f}>tm){s2GY~R^ zb?}M$rt-XuL2V|=Csv2p3{nVy>-1O1!>80&v1W7@M=ZWzw&^t%d#ch_?O=&wpl|jm z?~f2mZ?sQl9>JSZt7D8|w*^mnHZlFrGDyYgjxIgVD*t-lLVG4};-+y`btF4Mwz>>- zTKE(GUEOx?dNszl?{bR!G_^i#4x%o(lj9k6j$CS=pHbij0IH*BfGiK_4ydO9t`@0X zUwhvpSuAx_hJ82O?j>kKx~mZ(bZ=#Ek>p5h2_H3j7v8ORxWz|FYlxfYx+(EP8kt%=_TBZBO2pM%(t!W1KFynB9D_F+Y7!Ms7jY+c^6CDDW zhf5i^+?fN?^Z0{$?o$uHTOuK1HnHvbC%DyW3OCyZEAV63+=LL`GvWqRa=&9dve>a`y}NRP7Mf3gQ|baTE`%W>+PPi45g8v zu!d6ITl}OzO(J&=iykIJ(7&~6`@m0_;@B#kLY>c9FG%Z~xQ}FHGL* zWjQM4B3N+D|3Y=_uni;xU+Sq@tyBbl3ICW!)N;Yk;w6i8O7E zoP5HAB&+LBik4Qj??b;~zKAS&W`QP5gjw$zgqB4=>N^J@^zlpJGSK>ts$Y^Dc^8ai zr+%isJQN~fjy%|1f$`rVE|_juX<=t`!XG`BX}c_rXjxM0AgpeZWTvqkD{v}dGQ%fL z;5$1m=5*@@EaUgMi<;VOVD0f!fZwibsa9BMN^yn;0hi{oOJSoR@|mDw)0ijsDgf~s z5V_1YumNRMwf-rNcV-K0Z@ed)B}$v`wJ!g5TwfT|lYB8U$S-X%lDp?V zK}M0a5p76lE?VjEiUn$61d|vDEdK9DrFMTY*fgS#XlJ@30xb6K-aTe#QUK8`5Vl_j zp@PJWh@tsx?Sa7yu`Gm=G8+^0ZD)B=${K;CY}_gslj#T`CrgtVEiGi0*z{@|Z};fV zUd0U|+JA1Bs!G82BbO_QIb*Yc`8H{x&l*-NtsF1C{{$<|ob 
zvu>%wdh)uL6|j)#dVt0!tuJ#Ml9Mu`i4sd6cyv@%*AdM#@X(%`#fY1q8AB@tE*h_a z2QKiJCp@nNBW2RrV&H$cW?IF#8=vsE`h$i}4iR>P?S0UCx%PoaZqGEcgL;yqwKbH( zP6DcNIIg1D!Z|GP)uuWTs9P_teJqFPU0t0^Kca#1hxc32bV+xIXf2HTVVB%_}%@GoETX<0G3o9Fc-j4BxU4_0%IIJYrK?2*h(0GA9ys z7s^@5DS!Oc)OH)NWw5P{)~Gc&hv3HqsQE2Np9>533mN7%NmqV!zCT*7(EHnC#yBM} z|HNo>)u6(JO!yEjXC0`reUbG;r1?0pVBB5;TVII0)MA&d+sB*2Eq9c>!MXO^Pg;Z!wUIdY_fZl zMqI!I0r`pDzY<k0s%c*5qqz@LJe<*m@}p+)T)v(Ocj6BivBk>G!Cd?aqB?J%Zw} zlf1{S;e|&9hV5_TXSM`l5`JW;h6Le5a004VOmA`JWxN(ct2+bk7NH(2la|3skz8!o zM7v{Ve5Y8+MJv%>vw@Q2%y(aQmcnPh0<{zDcE%;Q|B946V|VoYLnVsp6K1t*fM2OD z$NHXr534U&`zGAqDRbUAK35k{FXE&C=I8 zs_vhdyarojR39a(5Gqi(R}{b-Qy}fXyGh;u*ofk`=ytC5)fG}*P@FtT7v^2QYc=0NF_eM+7jOh^#%PUe6yx-!>dDyb=aLUfek-U5M#!Aw+ebL6x^RRACJQM9E zMK?ivWMsc`>P#+mZde#!)d`ODw<{t&J|_9ja*V1`tPyDOHg zrB9;u8*>QsN|2s!cAE~rQRBh6mLb~K(-IpBg0EB8i6^&tNV0m2D$~Chy%e7`@Y##N zVhakabcaYT6R-W)h1uaDbUl>g?I9@6D(K$%SfDjcP4MQlJ$R^x1LGN6HScE0)`fp9 z$UeRS_^kxC-MropI;ADd|JS^D`ZewiZ*44%gIv9;LB7JCJf+{1!y`Y+JUXF~xuUJ; z{Xz;|XM$Ls*9^tpNh%n*{JL>#wu6iMe3>km!jOWrk3Zmvq zxcc@P^z(uNYBh0gu6n|V4QY7!>12%wrgBG~5I>ci)WsJAKb0{*X*#imk}b7}aXb9N z&h1!i^jL_dp?RUQux&gNc;3j3*G+6Sv$09&>X6bXZ_Xms5FU_=S|bmE>-(H{wGyN>NA8AY9c8Q+76X$;9D`u<)EAaGZU{`2=m4iZPT z_4bw}Z1~Wd*^j=Z46Iszd_MRMoY*)EeZc*(MkeXjHd+}K3>&@HR`nxjsRJK49P2^f zQapF`(MG_}4#v?P6~%i{*yJ%W)X&meNGK=)^3#WRmmM}aye3NmA8$ZME2pGF2jpKb zBC#`NW^Rze!mEaB_A={sE;?leqWKr$pw5l84goWsC^bx}3B*2+ z!)abh8Fg$Uv~h6gHh6zsK)_(!vpzdZ;$K61o3%Y@2Pa<#l>xnfLU+zmR`64x~93hpcN`kpw7W-lq{sO;-^pxs0k z%n?@1zdXp^=U73z_-luklKoj-D)u{c<-I|LuEQQx)C48V>@j13RS|)?!fj(e2tSNt zqm*wpY<-GpE~fjHE0*Uq5H3U_Udch%7tPDb4|a|rj2{J+x$+Y99d*g$3(u&s7g&U= zBbVfc?UMPLmpf{56Kv`HtKF@1Bv<`j@G-=|f+ZH#GE{7pVv%)-I`IBUsk!{OlRq<6 z{UjD*QS5EdE^Ta{Hi7y(CWPqZKVLfvk~9}!Eh7o>yaYQsS8xO$e^Yx(0mDWbVJ-3( zCVv^mGN4$|jeh49_BQqve|@Fu(l^D^2uGcGB`~5#Fi8KHiBz;Vn%PO#WgZeyS5BFj zh>O+yz$O^7@jhP|_<4@I9OozSja*9Ddh zbOi46`P)r0&^R0bFs%m22C{xfkOWFN2!4|*q@&FPp#Xv~qijC!c$$z)qMpKU*?*7d zToK0huXmyPYrZdjH!|#dmYi_E#zb?`q%8G+pZ{-J{+HTeM-uSzdh&MH@qhOG7cJjy A=>Px# 
literal 0 HcmV?d00001 diff --git a/src/deploy/model.py b/src/deploy/model.py index 6dfbbf6..35a675e 100644 --- a/src/deploy/model.py +++ b/src/deploy/model.py @@ -15,10 +15,22 @@ from ..prompt.summary import prepare_summary_prompt, prepare_summary_prompt_full from ..configs.Lib_cheatsheet import CHEATSHEET as LIB_CHEATSHEET from ..deploy.utils import basic_types, generate_api_calling, download_file_from_google_drive, download_data, save_decoded_file, correct_bool_values, convert_bool_values, infer, dataframe_to_markdown, convert_image_to_base64, change_format, parse_json_safely, post_process_parsed_params, special_types, io_types, io_param_names +from ..models.dialog_classifier import Dialog_Gaussian_classificaiton +def make_execution_correction_prompt(user_input, history_record, error_code, error_message, variables, LIB): + prompt = f"Your task is to correct a Python code snippet based on the provided information. The user's inquiry is represented by '{user_input}'. The history of successful executions is provided in '{history_record}', and variables in the namespace are supplied in a dictionary '{variables}'. Execute the error code snippet '{error_code}' and capture the error message '{error_message}'. Analyze the error to determine its root cause. Then, using the entire API name instead of abbreviation in the format '{LIB}.xx.yy'. Ensure any new variables created are named with the prefix 'result_' followed by digits, without reusing existing variable names. If you feel necessary to perform attribute operations similar to 'result_1.var_names.intersection(result_2.var_names)' or subset an AnnData object by columns like 'adata_ref = adata_ref[:, var_names]', go ahead. If you feel importing some libraries are necessary, go ahead. Maintain consistency with the style of previously executed code. Ensure that the corrected code, given the variables in the namespace, can be executed directly without errors. 
Return the corrected code snippet in the format: '\"\"\"your_corrected_code\"\"\"'. Do not include additional descriptions." + # please return minimum line of codes that you think is necessary to execute for the task related inquiry + return prompt + +def make_GPT_prompt(user_input, history_record, variables, LIB): + prompt = f"Your task is to generate a Python code snippet based on the provided information. The user's inquiry is represented by '{user_input}'. The history of successful executions is provided in '{history_record}', and variables in the namespace are supplied in a dictionary '{variables}'. Analyze the user intent about the task to generate code. Then, using the entire API name instead of abbreviation in the format '{LIB}.xx.yy'. Ensure any new variables created are named with the prefix 'result_' followed by digits, without reusing existing variable names. If you feel necessary to perform attribute operations similar to 'result_1.var_names.intersection(result_2.var_names)' or subset an AnnData object by columns like 'adata_ref = adata_ref[:, var_names]', go ahead. If you feel importing some libraries are necessary, go ahead. Maintain consistency with the style of previously executed code. Ensure that the generated code, given the variables in the namespace, can be executed directly without errors. Return the corrected code snippet in the format: '\"\"\"your_corrected_code\"\"\"'. Do not include additional descriptions." 
+ # please return minimum line of codes that you think is necessary to execute for the task related inquiry + return prompt class Model: def __init__(self, logger, device, model_llm_type="gpt-3.5-turbo-0125"): # llama3 print('start initialization!') + self.retry_execution_limit = 3 + self.retry_execution_count = 0 self.path_info_list = ['path','Path','PathLike'] self.model_llm_type = model_llm_type self.logger = logger @@ -33,8 +45,8 @@ def __init__(self, logger, device, model_llm_type="gpt-3.5-turbo-0125"): # llama self.LIB = "scanpy" self.args_retrieval_model_path = f'./hugging_models/retriever_model_finetuned/{self.LIB}/assigned' self.args_top_k = 3 - self.param_gpt_retry = 1 - self.predict_api_gpt_retry = 3 + self.param_llm_retry = 1 + self.predict_api_llm_retry = 3 self.session_id = "" #load_dotenv() OPENAI_API_KEY = os.getenv('OPENAI_API_KEY', 'sk-test') @@ -62,7 +74,7 @@ def __init__(self, logger, device, model_llm_type="gpt-3.5-turbo-0125"): # llama # self.vectorizer = pickle.load(f) with open(f'./data/standard_process/{self.LIB}/centroids.pkl', 'rb') as f: self.centroids = pickle.load(f) - self.retrieve_query_mode = "similar" + self.retrieve_query_mode = "random" self.all_apis, self.all_apis_json = get_all_api_json(f"./data/standard_process/{self.LIB}/API_init.json", mode='single') print("Server ready") def load_multiple_corpus_in_namespace(self, ): @@ -77,6 +89,16 @@ def load_multiple_corpus_in_namespace(self, ): self.executor.execute_api_call(f"warnings.filterwarnings('ignore')", "import") def load_bert_model(self, load_mode='unfinetuned_bert'): self.bert_model = SentenceTransformer('all-MiniLM-L6-v2', device=self.device) if load_mode=='unfinetuned_bert' else SentenceTransformer(f"./hugging_models/retriever_model_finetuned/{self.LIB}/assigned", device=self.device) + def compute_dialog_metrics(self,): + annotate_path = f'data/standard_process/{self.LIB}/API_inquiry_annotate.json' + annotate_data = load_json(annotate_path) + info_json = 
get_all_variable_from_cheatsheet(self.LIB) + LIB_ALIAS = info_json['LIB_ALIAS'] + self.dialog_p_threshold = 0.05 + data_source = "single_query_train" + self.dialog_classifer = Dialog_Gaussian_classificaiton(threshold=self.dialog_p_threshold) + scores_train, outliers = self.dialog_classifer.compute_accuracy_filter_compositeAPI(self.LIB, self.retriever, annotate_data, self.args_top_k, name=data_source, LIB_ALIAS=LIB_ALIAS) + self.dialog_mean, self.dialog_std = self.dialog_classifer.fit_gaussian(scores_train['rank_1']) def reset_lib(self, lib_name): #lib_name = lib_name.strip() self.logger.debug("================") @@ -115,6 +137,8 @@ def reset_lib(self, lib_name): self.logger.info("loading model cost: {} s", str(time.time()-t1)) reset_result = "Success" self.LIB = lib_name + # compute the dialog metrics + self.compute_dialog_metrics() except Exception as e: self.logger.error("at least one data or model is not ready, please install lib first!") self.logger.error("Error: {}", e) @@ -161,7 +185,7 @@ def install_lib_simple(self,lib_name, lib_alias, api_html=None, github_url=None, self.callback_func('installation', "Preparing instruction generation API_inquiry.json ...", "52") command = [ "python", "-m", "src.dataloader.preprocess_retriever_data", - "--LIB", self.LIB, "--GPT_model", "got3.5" + "--LIB", self.LIB, "--GPT_model", "gpt3.5" ] subprocess.Popen(command) ########### @@ -291,7 +315,9 @@ def install_lib_full(self,lib_name, lib_alias, api_html=None, github_url=None, d #cheatsheet_data.update(new_lib_details) # save_json(cheatsheet_path, cheatsheet_data) # TODO: need to save tutorial_github and tutorial_html_path to cheatsheet - + def save_state_enviro(self): + self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") + self.save_state() def update_image_file_list(self): return [f for f in os.listdir(self.image_folder) if f.endswith(".webp")] def load_composite_code(self, lib_name): @@ -306,10 +332,6 @@ def load_composite_code(self, 
lib_name): function_name = node.name function_body = ast.unparse(node) self.functions_json[function_name] = function_body - def retrieve_names(self,query): - retrieved_names = self.retriever.retrieving(query, top_k=self.args_top_k) - self.logger.info("retrieved_names: {}", retrieved_names) - return retrieved_names def initialize_executor(self): self.executor = CodeExecutor(self.logger) self.executor.callbacks = self.callbacks @@ -406,7 +428,7 @@ def run_pipeline(self, user_input, lib, top_k=3, files=[],conversation_started=T self.initialize_executor() pass # only reset lib when changing lib - if lib!=self.LIB: + if lib!=self.LIB and lib!='GPT': reset_result = self.reset_lib(lib) if reset_result=='Fail': self.logger.error('Reset lib fail! Exit the dialog!') @@ -414,6 +436,8 @@ def run_pipeline(self, user_input, lib, top_k=3, files=[],conversation_started=T return self.args_retrieval_model_path = f'./hugging_models/retriever_model_finetuned/{lib}/assigned' self.LIB = lib + elif lib=='GPT': + self.update_user_state("run_pipeline_asking_GPT") # only clear namespace when starting new conversations if conversation_started in ["True", True]: self.logger.info('==>new conversation_started!') @@ -434,6 +458,7 @@ def run_pipeline(self, user_input, lib, top_k=3, files=[],conversation_started=T self.loading_data(files) self.query_id += 1 self.user_query = user_input + # chitchat prediction predicted_source = infer(self.user_query, self.bert_model, self.centroids, ['chitchat-data', 'topical-chat', 'api-query']) self.logger.info('----query inferred as {}----', predicted_source) if predicted_source!='api-query': @@ -443,7 +468,16 @@ def run_pipeline(self, user_input, lib, top_k=3, files=[],conversation_started=T return else: pass - retrieved_names = self.retrieve_names(user_input) + # dialog prediction + self.logger.info('start predicting whether inquiry is a complex task description or simple one!') + pred_class = self.dialog_classifer.single_prediction(user_input, 
self.retriever, self.args_top_k) + self.logger.info('----query inferred as {}----', pred_class) + if pred_class not in ['single']: + # TODO: retrieve multiple APIs + + pass + # start retrieving names + retrieved_names = self.retriever.retrieving(user_input, top_k=self.args_top_k) # produce prompt if self.retrieve_query_mode=='similar': instruction_shot_example = self.retriever.retrieve_similar_queries(user_input, shot_k=5) @@ -473,27 +507,27 @@ def run_pipeline(self, user_input, lib, top_k=3, files=[],conversation_started=T retrieved_apis_prepare+=f"{idx}:" + api+", description: "+self.all_apis_json[api].replace('\n',' ')+"\n" api_predict_prompt = api_predict_init_prompt.format(query=user_input, retrieved_apis=retrieved_apis_prepare, similar_queries=instruction_shot_example) success = False - for _ in range(self.predict_api_gpt_retry): + for _ in range(self.predict_api_llm_retry): try: response, _ = LLM_response(api_predict_prompt, self.model_llm_type, history=[], kwargs={}) # llm - self.logger.info('==>Ask GPT: {}\n==>GPT response: {}', api_predict_prompt, response) - # hack for if GPT answers this or that + self.logger.info('==>Ask LLM: {}\n==>LLM response: {}', api_predict_prompt, response) + # hack for if LLM answers this or that """response = response.split(',')[0].split("(")[0].split(' or ')[0] response = response.replace('{','').replace('}','').replace('"','').replace("'",'') - response = response.split(':')[0]# for robustness, sometimes gpt will return api:description""" + response = response.split(':')[0]# for robustness, sometimes llm will return api:description""" response = correct_pred(response, self.LIB) response = response.strip() #self.logger.info('self.all_apis_json keys: {}', self.all_apis_json.keys()) self.logger.info('response in self.all_apis_json: {}', response in self.all_apis_json) self.all_apis_json[response] - self.predicted_api_name = response + self.predicted_api_name = response success = True break except Exception as e: 
self.logger.error('error during api prediction: {}', e) if not success: self.initialize_tool() - self.callback_func('log', "GPT can not return valid API name prediction, please redesign your prompt.", "GPT predict Error") + self.callback_func('log', "LLM can not return valid API name prediction, please redesign your prompt.", "LLM predict Error") return self.logger.info('length of ambiguous api list: {}',len(self.ambiguous_api)) # if the predicted API is in ambiguous API list, then show those API and select one from them @@ -511,21 +545,50 @@ def run_pipeline(self, user_input, lib, top_k=3, files=[],conversation_started=T next_str+='\n'+description_1 self.update_user_state("run_pipeline_after_ambiguous") idx_api+=1 + next_str+="\n"+f"Candidate [-1]: No appropriate candidate, restart another inquiry by input -1" + self.update_user_state("run_pipeline_after_ambiguous") self.callback_func('log', next_str, f"Can you confirm which of the following {len(self.filtered_api)} candidates") - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.save_state_enviro() else: self.update_user_state("run_pipeline_after_fixing_API_selection") self.run_pipeline_after_fixing_API_selection(user_input) elif self.user_states == "run_pipeline_after_ambiguous": ans = self.run_pipeline_after_ambiguous(user_input) if ans in ['break']: + self.logger.info('break the loop! 
Restart the inquiry') return self.run_pipeline_after_fixing_API_selection(user_input) - elif self.user_states in ["run_pipeline_after_doublechecking_execution_code", "run_pipeline_after_entering_params", "run_select_basic_params", "run_pipeline_after_select_special_params", "run_select_special_params", "run_pipeline_after_doublechecking_API_selection"]: + elif self.user_states in ["run_pipeline_after_doublechecking_execution_code", "run_pipeline_after_entering_params", "run_select_basic_params", "run_pipeline_after_select_special_params", "run_select_special_params", "run_pipeline_after_doublechecking_API_selection", "run_pipeline_asking_GPT"]: self.handle_state_transition(user_input) else: + self.logger.error('Unknown user state: {}', self.user_states) raise ValueError + def run_pipeline_asking_GPT(self,user_input): + self.logger.info('==>run_pipeline_asking_GPT') + self.retry_execution_count +=1 + if self.retry_execution_count>self.retry_execution_limit: + self.logger.error('retry_execution_count exceed the limit! Exit the dialog!') + self.initialize_tool() + self.callback_func('log', 'code generation using GPT has exceed the limit! Please choose other lib and re-enter the inquiry! 
You can use GPT again once you have executed code successfully through our tool!', 'Error') + self.update_user_state('run_pipeline') + self.save_state_enviro() + return + self.initialize_tool() + self.logger.info('==>start asking GPT for code generation!') + prompt = make_GPT_prompt(self.user_query, str(self.executor.execute_code), str(self.executor.variables), self.LIB) + response, _ = LLM_response(prompt, self.model_llm_type, history=[], kwargs={}) + newer_code = response.replace('\"\"\"', '') + self.execution_code = newer_code + self.callback_func('code', self.execution_code, "Executed code") + # LLM response + summary_prompt = prepare_summary_prompt_full(self.user_query, self.predicted_api_name, self.API_composite[self.predicted_api_name]['description'], self.API_composite[self.predicted_api_name]['Parameters'],self.API_composite[self.predicted_api_name]['Returns'], self.execution_code) + response, _ = LLM_response(summary_prompt, self.model_llm_type, history=[], kwargs={}) + self.callback_func('log', response, "Task summary before execution") + self.callback_func('log', "Could you confirm whether this task is what you aimed for, and the code should be executed? Please enter y/n.\nIf you press n, then we will re-direct to the parameter input step", "Double Check") + self.update_user_state("run_pipeline_after_doublechecking_execution_code") + self.save_state_enviro() + return + def handle_unknown_state(self, user_input): self.logger.info("Unknown state: {}", self.user_states) @@ -544,15 +607,21 @@ def run_pipeline_after_ambiguous(self,user_input): self.update_user_state("run_pipeline_after_ambiguous") return 'break' try: - self.filtered_api[int(user_input)-1] + if int(user_input)==-1: + self.update_user_state("run_pipeline") + self.callback_func('log', "We will start another round. 
Could you re-enter your inquiry?", "Start another round") + self.logger.info("user state updated to run_pipeline") + self.save_state_enviro() + return 'break' + else: + self.filtered_api[int(user_input)-1] except: self.callback_func('log', "Error: the input index exceed the maximum length of ambiguous API list\nPlease re-enter the index", "Index Error") self.update_user_state("run_pipeline_after_ambiguous") return 'break' self.update_user_state("run_pipeline_after_fixing_API_selection") self.predicted_api_name = self.filtered_api[int(user_input)-1] - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.save_state_enviro() def process_api_info(self, api_info, single_api_name): relevant_apis = api_info.get(single_api_name, {}).get("relevant APIs") if not relevant_apis: @@ -605,8 +674,7 @@ def run_pipeline_after_fixing_API_selection(self,user_input): self.callback_func('log', response, f"Predicted API: {self.predicted_api_name}") self.callback_func('log', "Could you confirm whether this API should be called? Please enter y/n.", "Double Check") self.update_user_state("run_pipeline_after_doublechecking_API_selection") - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.save_state_enviro() def run_pipeline_after_doublechecking_API_selection(self, user_input): self.logger.info('==>run_pipeline_after_doublechecking_API_selection') @@ -619,8 +687,7 @@ def run_pipeline_after_doublechecking_API_selection(self, user_input): self.initialize_tool() self.logger.info("user tool initialized") self.callback_func('log', "We will start another round. 
Could you re-enter your inquiry?", "Start another round") - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.save_state_enviro() return else: self.logger.info("user input is y") @@ -629,8 +696,7 @@ def run_pipeline_after_doublechecking_API_selection(self, user_input): self.logger.info('input is not y or n') self.initialize_tool() self.callback_func('log', "The input was not y or n, please enter the correct value.", "Index Error") - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.save_state_enviro() # user_states didn't change return self.logger.info("==>Need to collect all parameters for a composite API") @@ -670,7 +736,7 @@ def run_pipeline_after_doublechecking_API_selection(self, user_input): } for param in api_parameters_information ] - #filter out special type parameters, do not infer them using gpt + #filter out special type parameters, do not infer them using LLM api_parameters_information = [param for param in api_parameters_information if any(basic_type in param['type'] for basic_type in basic_types)] parameters_name_list = [param_info['name'] for param_info in api_parameters_information] apis_description = "" @@ -692,15 +758,15 @@ def run_pipeline_after_doublechecking_API_selection(self, user_input): except Exception as e: self.logger.error('error for parameters: {}', e) if len(parameters_name_list)==0: - # if there is no required parameters, skip using gpt + # if there is no required parameters, skip using LLM response = "[]" predicted_parameters = {} else: success = False - for _ in range(self.param_gpt_retry): + for _ in range(self.param_llm_retry): try: response, _ = LLM_response(parameters_prompt, self.model_llm_type, history=[], kwargs={}) - self.logger.info('==>Asking GPT: {}, ==>GPT response: {}', parameters_prompt, response) + self.logger.info('==>Asking LLM: {}, ==>LLM response: {}', parameters_prompt, response) 
returned_content_str_new = response.replace('null', 'None').replace('None', '"None"') # 240519 fix pred_params, success = parse_json_safely(returned_content_str_new) @@ -710,7 +776,7 @@ def run_pipeline_after_doublechecking_API_selection(self, user_input): pass self.logger.info('success or not: {}', success) if not success: - self.callback_func('log', "GPT can not return valid parameters prediction, please redesign prompt in backend if you want to predict parameters. We will skip parameters prediction currently", "GPT predict Error") + self.callback_func('log', "LLM can not return valid parameters prediction, please redesign prompt in backend if you want to predict parameters. We will skip parameters prediction currently", "LLM predict Error") response = "{}" predicted_parameters = {} self.logger.info('predicted_parameters: {}', predicted_parameters) @@ -770,8 +836,7 @@ def run_pipeline_after_doublechecking_API_selection(self, user_input): self.initialize_tool() self.callback_func('log', "However, there are still some parameters with special type undefined. Please start from uploading data, or check your parameter type in json files.", "Missing Parameters: special type") self.update_user_state("run_pipeline") - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.save_state_enviro() return # $ param if multiple choice multiple_dollar_value_params = [param_name for param_name, param_info in self.selected_params.items() if ('list' in str(type(param_info["value"]))) and (len(param_info["value"])>1)] @@ -788,8 +853,7 @@ def run_pipeline_after_doublechecking_API_selection(self, user_input): self.callback_func('log', f"The predicted API takes {tmp_input_para} as input. 
However, there are still some parameters undefined in the query.", "Enter Parameters: special type", "red") self.update_user_state("run_select_special_params") self.run_select_special_params(user_input) - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.save_state_enviro() return self.run_pipeline_after_select_special_params(user_input) @@ -814,8 +878,7 @@ def run_select_special_params(self, user_input): self.update_user_state("run_select_special_params") del self.filtered_params[self.last_param_name] #print('self.filtered_params: {}', json.dumps(self.filtered_params)) - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.save_state_enviro() return elif len(self.filtered_params)==1: self.last_param_name = list(self.filtered_params.keys())[0] @@ -827,12 +890,10 @@ def run_select_special_params(self, user_input): self.update_user_state("run_pipeline_after_select_special_params") del self.filtered_params[self.last_param_name] #print('self.filtered_params: {}', json.dumps(self.filtered_params)) - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.save_state_enviro() else: self.callback_func('log', "The parameters candidate list is empty", "Error Enter Parameters: basic type", "red") - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.save_state_enviro() raise ValueError def run_pipeline_after_select_special_params(self,user_input): @@ -857,8 +918,7 @@ def run_pipeline_after_select_special_params(self,user_input): self.callback_func('log', f"The predicted API takes {tmp_input_para} as input. 
However, there are still some parameters undefined in the query.", "Enter Parameters: basic type", "red") self.user_states = "run_select_basic_params" self.run_select_basic_params(user_input) - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.save_state_enviro() return self.run_pipeline_after_entering_params(user_input) @@ -873,21 +933,18 @@ def run_select_basic_params(self, user_input): self.callback_func('log', "Which value do you think is appropriate for the parameters '" + self.last_param_name + "'?", "Enter Parameters: basic type","red") self.update_user_state("run_select_basic_params") del self.filtered_params[self.last_param_name] - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.save_state_enviro() return elif len(self.filtered_params)==1: self.last_param_name = list(self.filtered_params.keys())[0] self.callback_func('log', "Which value do you think is appropriate for the parameters '" + self.last_param_name + "'?", "Enter Parameters: basic type", "red") self.update_user_state("run_pipeline_after_entering_params") del self.filtered_params[self.last_param_name] - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.save_state_enviro() else: # break out the pipeline self.callback_func('log', "The parameters candidate list is empty", "Error Enter Parameters: basic type","red") - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.save_state_enviro() raise ValueError def split_params(self, selected_params, parameters_list): extracted_params = [] @@ -1026,8 +1083,7 @@ def run_pipeline_after_entering_params(self, user_input): self.callback_func('log', response, "Task summary before execution") self.callback_func('log', "Could you confirm whether this task is what you aimed for, and the code should be 
executed? Please enter y/n.\nIf you press n, then we will re-direct to the parameter input step", "Double Check") self.update_user_state("run_pipeline_after_doublechecking_execution_code") - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.save_state_enviro() def run_pipeline_after_doublechecking_execution_code(self, user_input): self.initialize_tool() @@ -1038,8 +1094,7 @@ def run_pipeline_after_doublechecking_execution_code(self, user_input): #self.user_states = "run_pipeline" self.update_user_state("run_pipeline_after_doublechecking_API_selection")#TODO: check if exist issue self.callback_func('log', "We will redirect to the parameters input", "Re-enter the parameters") - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.save_state_enviro() self.run_pipeline_after_doublechecking_API_selection('y') return else: @@ -1047,8 +1102,7 @@ def run_pipeline_after_doublechecking_execution_code(self, user_input): else: self.logger.info('input not y or n') self.callback_func('log', "The input was not y or n, please enter the correct value.", "Index Error") - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.save_state_enviro() # user_states didn't change return # else, continue @@ -1087,7 +1141,12 @@ def run_pipeline_after_doublechecking_execution_code(self, user_input): try: content = '\n'.join(output_list) except: - content = "" + try: + content = self.last_execute_code['error'] + except: + content = "" + self.logger.info('content: {}', content) + self.logger.info('self.last_execute_code: {}',str(self.last_execute_code)) # show the new variable if self.last_execute_code['success']=='True': # if execute, visualize value @@ -1149,11 +1208,35 @@ def run_pipeline_after_doublechecking_execution_code(self, user_input): self.image_file_list = new_img_list if 
tips_for_execution_success: # if no output, no new variable, present the log self.callback_func('log', str(content), "Executed results [Success]") + self.retry_execution_count = 0 else: self.logger.info('Execution Error: {}', content) - self.callback_func('log', "\n".join(list(set(output_list))), "Executed results [Fail]") - self.executor.save_environment(f"./tmp/sessions/{str(self.session_id)}_environment.pkl") - self.save_state() + self.callback_func('log', content, "Executed results [Fail]") + if self.retry_execution_count= len(api_data)] + api_names = [item['api_calling'][0].split('(')[0] for item in diff_data] + api_counts = Counter(api_names) + print(f"Total differences found: {len(diff_data)}") + print(f"Unique API functions: {len(api_counts)}") + assert max(api_counts.values()) - min(api_counts.values()) <= 0, "API function distribution is not even" + return diff_data def json_to_docstring(api_name, description, parameters): params_list = ', '.join([ diff --git a/src/models/__init__.py b/src/models/__init__.py index 25efac6..5cd1490 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -5,3 +5,4 @@ #from . import train_classification from . import train_retriever #from . import lit_llama +from . 
class Dialog_Gaussian_classificaiton:
    """Gaussian-based classifier for deciding whether a user query is a
    'multiple'-API (tutorial/dialog-style) task or a 'single'-API task,
    using the rank-1 retrieval similarity score.

    Workflow: fit a Gaussian to rank-1 scores from training queries
    (`fit_gaussian`), then classify new scores by their left-tail p-value
    under that Gaussian (`classify` / `single_prediction`).

    NOTE(review): the class name intentionally keeps the original
    (misspelled) identifier because other modules in this commit import it
    by this exact name; renaming it here would break those imports.
    """

    def __init__(self, threshold=0.05):
        # p-value cutoff: scores whose CDF under the fitted Gaussian falls
        # below this threshold are labeled 1 ('multiple').
        self.threshold = threshold

    def fit_gaussian(self, data):
        """Fit a Gaussian to the score sample.

        Stores mean/std on the instance (used later by `classify`) and
        returns them as a (mean, std) tuple.
        """
        self.mean = np.mean(data)
        self.std = np.std(data)
        return self.mean, self.std

    def calculate_p_values(self, scores, mean, std):
        """Left-tail p-value of each score under N(mean, std)."""
        return [norm.cdf(score, mean, std) for score in scores]

    def classify_based_on_p(self, p_values, threshold=0.05):
        """Label 1 ('multiple') when p < threshold, else 0 ('single')."""
        return [1 if p < threshold else 0 for p in p_values]

    def classify(self, rank_1_scores):
        """Classify rank-1 scores using the Gaussian fitted by `fit_gaussian`.

        Raises AttributeError if `fit_gaussian` has not been called first.
        Returns a list of 0/1 labels, one per input score.
        """
        p_values_val = self.calculate_p_values(rank_1_scores, self.mean, self.std)
        return self.classify_based_on_p(p_values_val, threshold=self.threshold)

    def compute_acc(self, labels, predictions):
        """Plain accuracy of `predictions` against ground-truth `labels`."""
        return accuracy_score(labels, predictions)

    def plot_boxplot(self, data, title, LIB):
        """Save a boxplot of the per-rank score lists to ./plot/{LIB}/.

        Bug fix vs. original: create the output directory first so
        `savefig` does not fail on a fresh checkout.
        """
        os.makedirs(f'./plot/{LIB}', exist_ok=True)
        plt.figure(figsize=(10, 6))
        sns.boxplot(data=data)
        plt.title(title)
        plt.xticks(ticks=range(5), labels=[f'Rank {i+1}' for i in range(5)])
        plt.ylabel('Score')
        plt.savefig(f'./plot/{LIB}/avg_retriever_{title}.pdf')

    def compute_accuracy_filter_compositeAPI(self, LIB, retriever, data, retrieved_api_nums, name='train', LIB_ALIAS='scanpy', verbose=False, filter_composite=True):
        """Collect rank-1..5 retrieval scores over `data` and flag rank-1 outliers.

        Class-type and unknown-type (composite) APIs are removed from the
        retrieved candidates when `filter_composite` is True.

        Returns:
            (scores, outliers): `scores` maps 'rank_1'..'rank_5' to score
            lists; `outliers` is a list of IQR-outlier records for the
            rank-1 scores (index, score, query, retrieved APIs, gold API,
            all scores).
        """
        API_composite = load_json(os.path.join(f"data/standard_process/{LIB}", "API_composite.json"))
        # one bucket per retrieval rank, filled in lock-step below
        rank_buckets = [[], [], [], [], []]
        outliers = []
        query_to_api = {}
        query_to_retrieved_api = {}
        query_to_all_scores = {}
        for query_data in tqdm(data):
            # over-fetch so enough candidates survive the composite filter
            retrieved_apis = retriever.retrieving(query_data['query'], top_k=retrieved_api_nums + 20)
            if filter_composite:
                retrieved_apis = [i for i in retrieved_apis
                                  if i.startswith(LIB_ALIAS)
                                  and API_composite[i]['api_type'] != 'class'
                                  and API_composite[i]['api_type'] != 'unknown']
            retrieved_apis = retrieved_apis[:retrieved_api_nums]
            assert len(retrieved_apis) == retrieved_api_nums
            query_to_retrieved_api[query_data['query']] = retrieved_apis
            try:
                # best-effort: some entries may lack 'api_calling'
                query_to_api[query_data['query']] = query_data['api_calling'][0].split('(')[0]
            except (KeyError, IndexError, TypeError):
                pass
            query_embedding = retriever.embedder.encode(query_data['query'], convert_to_tensor=True)
            hits = util.semantic_search(query_embedding, retriever.corpus_embeddings, top_k=5, score_function=util.cos_sim)
            for rank, bucket in enumerate(rank_buckets):
                if len(hits[0]) > rank:
                    bucket.append(hits[0][rank]['score'])
            query_to_all_scores[query_data['query']] = [hit['score'] for hit in hits[0]] if hits[0] else []
        scores_rank_1 = rank_buckets[0]
        scores = {f"rank_{i+1}": bucket for i, bucket in enumerate(rank_buckets)}
        # IQR (1.5x) outlier detection on the rank-1 scores
        q1, q3 = np.percentile(scores_rank_1, [25, 75])
        iqr = q3 - q1
        lower_bound = q1 - 1.5 * iqr
        upper_bound = q3 + 1.5 * iqr
        for i, score in enumerate(scores_rank_1):
            if score < lower_bound or score > upper_bound:
                try:
                    outliers.append({'index': i, 'score': score, 'query': data[i]['query'], 'retrieved_apis': query_to_retrieved_api[data[i]['query']], 'query_api': query_to_api[data[i]['query']], 'all_scores': query_to_all_scores[data[i]['query']]})
                    if verbose:
                        print(f"{name} Outlier detected: Score = {score}, Query = {data[i]['query']}, retrieved_apis = {query_to_retrieved_api[data[i]['query']]}, query_api = {query_to_api[data[i]['query']]}, score = {query_to_all_scores[data[i]['query']]}")
                except (KeyError, IndexError):
                    # gold API may be absent for this query; skip the record
                    pass
        return scores, outliers

    def single_prediction(self, query, retriever, top_k):
        """Classify one query as 'single' or 'multiple' from its rank-1 score.

        Bug fixes vs. original:
        - `classify` returns a *list*; the original compared that list to the
          int 1, which is always False, so every query came back 'single'.
          We now inspect the first element.
        - With no retrieval hits, the original raised UnboundLocalError;
          we fall back to 'single' (the simple-task path) instead.
        """
        query_embedding = retriever.embedder.encode(query, convert_to_tensor=True)
        hits = util.semantic_search(query_embedding, retriever.corpus_embeddings, top_k=top_k, score_function=util.cos_sim)
        if not hits[0]:
            return 'single'
        score_rank1 = hits[0][0]['score']
        # TODO: load the fitted threshold/mean/std before calling classify
        pred_label = self.classify([score_rank1])[0]
        return 'multiple' if pred_label == 1 else 'single'