Skip to content

Commit

Permalink
if any STT service errors, clear all queued workers
Browse files Browse the repository at this point in the history
  • Loading branch information
tazz4843 committed Dec 11, 2023
1 parent e5aa3ff commit 953669a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
17 changes: 14 additions & 3 deletions scripty_stt/src/load_balancer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ impl LoadBalancer {
let t3 = this.clone();
tokio::spawn(async move {
loop {
if let Ok(_) = purge_rx.recv_async().await {
if purge_rx.recv_async().await.is_ok() {
t3.queued_workers.lock().clear();
// request the queue be refilled
if let Err(_) = t3.new_worker_tx.send_async(()).await {
if t3.new_worker_tx.send_async(()).await.is_err() {
break error!(
"error sending new worker request: all client queues dropped"
);
Expand Down Expand Up @@ -251,6 +251,8 @@ pub struct LoadBalancedStream {
msg_rx_transmit_handle: Sender<ServerToClientMessage>,
// keep this field that way there's always one receiver
_msg_rx: Receiver<ServerToClientMessage>,

purge_tx: flume::Sender<()>,
}

impl LoadBalancedStream {
Expand All @@ -273,6 +275,7 @@ impl LoadBalancedStream {
self.peer_address,
self.msg_tx.clone(),
self.msg_rx_transmit_handle.subscribe(),
self.purge_tx.clone(),
)
.await
}
Expand Down Expand Up @@ -475,6 +478,7 @@ impl LoadBalancedStream {

let waiting_for_new_stream = Arc::new(AtomicBool::new(false));
let wfns2 = Arc::clone(&waiting_for_new_stream);
let purge_tx2 = purge_tx.clone();
// error handling task
tokio::spawn(async move {
loop {
Expand All @@ -483,7 +487,7 @@ impl LoadBalancedStream {
wfns2.store(true, Ordering::Relaxed);

// immediately purge all queued workers as we have bad state
if let Err(_) = purge_tx.send_async(()).await {
if purge_tx2.send_async(()).await.is_err() {
error!("error sending purge request: all client queues dropped");
break;
}
Expand Down Expand Up @@ -516,6 +520,12 @@ impl LoadBalancedStream {
let _ = new_read_stream_tx.send(stream_read).await;
let _ = new_write_stream_tx.send(stream_write).await;
wfns2.store(false, Ordering::Relaxed);

// purge again to ensure we clear out any queued workers that were on bad streams
if purge_tx2.send_async(()).await.is_err() {
error!("error sending purge request: all client queues dropped");
break;
}
}
});

Expand Down Expand Up @@ -546,6 +556,7 @@ impl LoadBalancedStream {
msg_tx: client_to_server_tx,
msg_rx_transmit_handle: server_to_client_tx,
_msg_rx: server_to_client_rx,
purge_tx,
})
}
}
Expand Down
16 changes: 11 additions & 5 deletions scripty_stt/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@ pub struct Stream {
rx: Receiver<ServerToClientMessage>,
peer_address: SocketAddr,
session_id: Uuid,

purge_tx: flume::Sender<()>,
}

impl Stream {
pub(crate) async fn new(
peer_address: SocketAddr,
tx: Sender<ClientToServerMessage>,
mut rx: Receiver<ServerToClientMessage>,
purge_tx: flume::Sender<()>,
) -> Result<Self, ModelError> {
let session_id = Uuid::new_v4();
debug!(%session_id, %peer_address, "initializing stts stream to peer");
Expand Down Expand Up @@ -61,6 +64,7 @@ impl Stream {
rx,
peer_address,
session_id,
purge_tx,
})
}
Ok(false) => {
Expand All @@ -75,7 +79,7 @@ impl Stream {
}

pub fn feed_audio(&self, data: Vec<i16>) -> Result<(), ModelError> {
debug!(%self.session_id, "feeding audio to stts");
debug!(%self.session_id, %self.peer_address, "feeding audio to stts");
self.tx
.send(ClientToServerMessage::AudioData(AudioData {
data,
Expand All @@ -90,7 +94,7 @@ impl Stream {
verbose: bool,
translate: bool,
) -> Result<String, ModelError> {
debug!(%self.session_id, "getting result from stts");
debug!(%self.session_id, %self.peer_address, "getting result from stts");
// send the finalize message
self.tx
.send(ClientToServerMessage::FinalizeStreaming(
Expand All @@ -106,12 +110,13 @@ impl Stream {
while let Ok(next) = self.rx.recv().await {
if let ServerToClientMessage::SttResult(SttSuccess { id, result }) = next {
if id == self.session_id {
debug!(%self.session_id, "got result from stts");
debug!(%self.session_id, %self.peer_address, "got result from stts");
return Ok(result);
}
} else if let ServerToClientMessage::SttError(SttError { id, error }) = next {
if id == self.session_id {
debug!(%self.session_id, "got error from stts");
debug!(%self.session_id, %self.peer_address, "got error from stts");
self.purge_tx.send_async(()).await.ok();
return Err(ModelError::SttsServer(error));
}
}
Expand All @@ -122,7 +127,8 @@ impl Stream {
Ok(Ok(res)) => Ok(res),
Ok(Err(e)) => Err(e),
Err(_) => {
warn!(%self.session_id, "timed out waiting for result");
warn!(%self.session_id, %self.peer_address, "timed out waiting for result");
self.purge_tx.send_async(()).await.ok();
Err(ModelError::TimedOutWaitingForResult)
}
}
Expand Down

0 comments on commit 953669a

Please sign in to comment.