From dae2a6583c8e42e230046991ace54f7506370a52 Mon Sep 17 00:00:00 2001 From: Will Crichton Date: Mon, 12 Aug 2024 15:09:45 -0700 Subject: [PATCH] Message passing solution --- crates/server/src/main.rs | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/crates/server/src/main.rs b/crates/server/src/main.rs index 9fa73ce..a94f6a8 100644 --- a/crates/server/src/main.rs +++ b/crates/server/src/main.rs @@ -1,8 +1,11 @@ -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use miniserve::{http::StatusCode, Content, Request, Response}; use serde::{Deserialize, Serialize}; -use tokio::join; +use tokio::{ + join, + sync::{mpsc, oneshot}, +}; async fn index(_req: Request) -> Response { let content = include_str!("../index.html").to_string(); @@ -14,6 +17,25 @@ struct Messages { messages: Vec, } +async fn query_chat(messages: &Arc>) -> Vec { + type Payload = (Arc>, oneshot::Sender>); + static SENDER: LazyLock> = LazyLock::new(|| { + let (tx, mut rx) = mpsc::channel::(1024); + tokio::spawn(async move { + let mut chatbot = chatbot::Chatbot::new(vec![":-)".into(), "^^".into()]); + while let Some((messages, responder)) = rx.recv().await { + let response = chatbot.query_chat(&messages).await; + responder.send(response).unwrap(); + } + }); + tx + }); + + let (tx, rx) = oneshot::channel(); + SENDER.send((Arc::clone(messages), tx)).await.unwrap(); + rx.await.unwrap() +} + async fn chat(req: Request) -> Response { let Request::Post(body) = req else { return Err(StatusCode::METHOD_NOT_ALLOWED); @@ -23,12 +45,7 @@ async fn chat(req: Request) -> Response { }; let messages = Arc::new(data.messages); - let messages_ref = Arc::clone(&messages); - let (i, responses) = join!( - chatbot::gen_random_number(), - tokio::spawn(async move { chatbot::query_chat(&messages_ref).await }) - ); - let mut responses = responses.unwrap(); + let (i, mut responses) = join!(chatbot::gen_random_number(), query_chat(&messages)); let response = responses.remove(i % responses.len()); data.messages = Arc::into_inner(messages).unwrap();