diff --git a/crates/server/src/main.rs b/crates/server/src/main.rs index a94f6a8..836b05e 100644 --- a/crates/server/src/main.rs +++ b/crates/server/src/main.rs @@ -1,10 +1,14 @@ -use std::sync::{Arc, LazyLock}; +use std::{ + path::PathBuf, + sync::{Arc, LazyLock}, +}; use miniserve::{http::StatusCode, Content, Request, Response}; use serde::{Deserialize, Serialize}; use tokio::{ - join, + fs, join, sync::{mpsc, oneshot}, + task::JoinSet, }; async fn index(_req: Request) -> Response { @@ -17,19 +21,36 @@ struct Messages { messages: Vec, } -async fn query_chat(messages: &Arc>) -> Vec { - type Payload = (Arc>, oneshot::Sender>); - static SENDER: LazyLock> = LazyLock::new(|| { - let (tx, mut rx) = mpsc::channel::(1024); - tokio::spawn(async move { - let mut chatbot = chatbot::Chatbot::new(vec![":-)".into(), "^^".into()]); - while let Some((messages, responder)) = rx.recv().await { - let response = chatbot.query_chat(&messages).await; - responder.send(response).unwrap(); - } - }); - tx +async fn load_docs(paths: Vec) -> Vec { + let mut doc_futs = paths + .into_iter() + .map(fs::read_to_string) + .collect::>(); + let mut docs = Vec::new(); + while let Some(result) = doc_futs.join_next().await { + docs.push(result.unwrap().unwrap()); + } + docs +} + +type Payload = (Arc>, oneshot::Sender>); + +fn chatbot_thread() -> mpsc::Sender { + let (tx, mut rx) = mpsc::channel::(1024); + tokio::spawn(async move { + let mut chatbot = chatbot::Chatbot::new(vec![":-)".into(), "^^".into()]); + while let Some((messages, responder)) = rx.recv().await { + let doc_paths = chatbot.retrieval_documents(&messages); + let docs = load_docs(doc_paths).await; + let response = chatbot.query_chat(&messages, &docs).await; + responder.send(response).unwrap(); + } }); + tx +} + +async fn query_chat(messages: &Arc>) -> Vec { + static SENDER: LazyLock> = LazyLock::new(chatbot_thread); let (tx, rx) = oneshot::channel(); SENDER.send((Arc::clone(messages), tx)).await.unwrap();