Skip to content

Commit

Permalink
feat: Implement concurrency limits
Browse files Browse the repository at this point in the history
Also rename runner builder methods.
  • Loading branch information
Flix committed Dec 28, 2023
1 parent 1b9a85e commit f5d2bde
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 19 deletions.
2 changes: 1 addition & 1 deletion examples/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async fn main() -> Result<()> {

// Start the job runner to execute jobs from the messages in the queue in the
// database.
let job_runner = JobRunner::new(db.clone()).set_context("cats").run::<JobRegistry>();
let job_runner = JobRunner::new(db.clone()).with_context("cats").run::<JobRegistry>();

// Spawn new jobs via a message on the database queue.
let job_id = JobRegistry::Greet.builder().spawn(&db).await?;
Expand Down
2 changes: 1 addition & 1 deletion examples/error_handling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async fn main() -> Result<()> {
let error_received = Arc::new(AtomicBool::new(false));
let err_received = error_received.clone();
let job_runner = JobRunner::new(db.clone())
.set_error_handler(move |_err| {
.with_error_handler(move |_err| {
err_received.store(true, Ordering::SeqCst);
})
.run::<JobRegistry>();
Expand Down
13 changes: 11 additions & 2 deletions src/job.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
//! Provider for job handlers.
use std::sync::Arc;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};

use serde::{de::DeserializeOwned, Serialize};
use tokio::task::JoinHandle;
Expand Down Expand Up @@ -113,17 +116,23 @@ impl CurrentJob {
}

/// Runs the job function in a spawned task and also handles retries, etc.
pub(crate) fn run(mut self, mut function: JobFunctionType) -> JoinHandle<Result<(), Error>> {
pub(crate) fn run(
mut self,
mut function: JobFunctionType,
currently_running: Arc<AtomicUsize>,
) -> JoinHandle<Result<(), Error>> {
self.keep_alive = Some(Self::keep_alive(self.db.clone(), self.id).into());

let span = tracing::debug_span!("job-run");
currently_running.fetch_add(1, Ordering::Relaxed);
tokio::task::spawn(
async move {
let id = self.id;
let db = self.db.clone();

tracing::trace!("Starting job with ID {id}.");
let res = function(self).await;
currently_running.fetch_sub(1, Ordering::Relaxed);

// Handle the job's error
if let Err(err) = res {
Expand Down
79 changes: 64 additions & 15 deletions src/runner.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
//! Connector to the database which runs code based on the messages and their
//! type.
use std::{fmt::Debug, sync::Arc, time::Duration};
use std::{
fmt::Debug,
ops::Range,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
thread::available_parallelism,
time::Duration,
};

use bonsaidb::core::{
async_trait::async_trait,
Expand Down Expand Up @@ -34,6 +43,9 @@ pub struct JobRunner<DB> {
error_handler: Option<ErrorHandler>,
/// Outside context type-map to provide resources to the jobs.
context: Context,
/// Concurrency limits, a range from minimum to maximum concurrent jobs to
/// be targeted in the execution queue.
concurrency: Range<u32>,
}

impl<DB> JobRunner<DB>
Expand All @@ -42,12 +54,15 @@ where
{
/// Create a new job runner on this database.
///
/// The default concurrency limits are derived from the available
/// parallelism: half the number of CPUs (but at least one) as the minimum and
/// twice the number of CPUs as the maximum. If the parallelism cannot be
/// determined, `3..8` is used.
pub fn new(db: DB) -> Self {
    let concurrency = available_parallelism().map_or(3_u32..8_u32, |num_cpus| {
        // Truncation is not a practical concern for CPU counts.
        let cpus = usize::from(num_cpus) as u32;
        // The job queue only fetches new messages while fewer than `start`
        // jobs are running, so a minimum of zero (1 CPU / 2) would mean that
        // no job is ever executed.
        (cpus / 2).max(1)..cpus.saturating_mul(2)
    });
    Self { db, error_handler: None, context: Context::new(), concurrency }
}

/// Set the error handler callback to be called when jobs return an error.
#[must_use]
pub fn set_error_handler<F>(mut self, handler: F) -> Self
pub fn with_error_handler<F>(mut self, handler: F) -> Self
where
F: Fn(Box<dyn std::error::Error + Send + Sync>) + Send + Sync + 'static,
{
Expand All @@ -57,11 +72,18 @@ where

/// Add context to the runner. Only one instance per type can be inserted!
///
/// Consumes the runner, inserts `context` into its context type-map and
/// returns the runner again so that builder calls can be chained.
// NOTE(review): whether a second value of the same type replaces the first or
// is rejected is decided by `Context::insert`, which is not visible here —
// confirm before relying on either behaviour.
#[must_use]
pub fn set_context<C: Clone + Send + Sync + 'static>(mut self, context: C) -> Self {
pub fn with_context<C: Clone + Send + Sync + 'static>(mut self, context: C) -> Self {
self.context.insert(context);
self
}

/// Set the concurrency limits.
///
/// The job queue only fetches new due messages while fewer than
/// `min_concurrent` jobs are running, and then requests at most
/// `max_concurrent` minus the number of currently running jobs.
///
/// The values are normalized so that the runner cannot stall:
/// `max_concurrent` is raised to at least 1 and `min_concurrent` is clamped
/// into `1..=max_concurrent`.
#[must_use]
pub fn with_concurrency_limits(mut self, min_concurrent: u32, max_concurrent: u32) -> Self {
    // A maximum of zero would never allow any job to be fetched.
    let max_concurrent = max_concurrent.max(1);
    // A minimum of zero would make the "fewer than minimum running" check
    // always fail; a minimum above the maximum only causes empty queries.
    let min_concurrent = min_concurrent.clamp(1, max_concurrent);
    self.concurrency = min_concurrent..max_concurrent;
    self
}

/// Spawn and run the daemon for processing messages/jobs in the background.
/// Keep this handle as long as you want jobs to be executed in the
/// background! You can also use and await the handle like normal
Expand All @@ -75,6 +97,7 @@ where
db: Arc::new(self.db),
error_handler: self.error_handler,
context: Arc::new(self.context),
concurrency: self.concurrency,
};
tokio::task::spawn(internal_runner.job_queue::<REG>()).into()
}
Expand All @@ -86,6 +109,7 @@ impl<DB: Debug> Debug for JobRunner<DB> {
.field("db", &self.db)
.field("error_handler", &"<err handler fn>")
.field("context", &self.context)
.field("concurrency", &self.concurrency)
.finish()
}
}
Expand All @@ -98,6 +122,9 @@ struct InternalJobRunner<DB> {
error_handler: Option<ErrorHandler>,
/// Outside context type-map to provide resources to the jobs.
context: Arc<Context>,
/// Concurrency limits, a range from minimum to maximum concurrent jobs to
/// be targeted in the execution queue.
concurrency: Range<u32>,
}

impl<DB> Clone for InternalJobRunner<DB> {
Expand All @@ -106,6 +133,7 @@ impl<DB> Clone for InternalJobRunner<DB> {
db: self.db.clone(),
error_handler: self.error_handler.clone(),
context: self.context.clone(),
concurrency: self.concurrency.clone(),
}
}
}
Expand All @@ -118,8 +146,14 @@ where
/// Queries the `DueMessages` view for entries whose key lies before `due_at`.
///
/// At most `limit` entries are requested; each one is returned together with
/// its `Message` collection document.
async fn due_messages(
&self,
due_at: Timestamp,
limit: u32,
) -> Result<MappedDocuments<CollectionDocument<Message>, DueMessages>, BonsaiError> {
self.db.view::<DueMessages>().with_key_range(..due_at).query_with_collection_docs().await
// Cap the query at `limit` results so that no more messages are loaded than
// the caller asked for.
self.db
.view::<DueMessages>()
.with_key_range(..due_at)
.limit(limit)
.query_with_collection_docs()
.await
}

/// Get the duration until the next message is due.
Expand Down Expand Up @@ -158,14 +192,25 @@ where
let subscriber = self.db.create_subscriber().await?;
subscriber.subscribe_to(&MQ_NOTIFY).await?;

let currently_running = Arc::new(AtomicUsize::new(0));
loop {
// Retrieve due messages
let now = OffsetDateTime::now_utc().unix_timestamp_nanos();
let messages = self.due_messages(now).await?;
tracing::trace!("Found {} due messages.", messages.len());

// Retrieve due messages if there is not enough running already
let running = currently_running.load(Ordering::Relaxed) as u32;
#[allow(clippy::if_then_some_else_none)] // It is async.
let messages = if running < self.concurrency.start {
Some(self.due_messages(now, self.concurrency.end.saturating_sub(running)).await?)
} else {
None
};
tracing::trace!(
"Handling {} due messages.",
messages.as_ref().map_or(0, MappedDocuments::len)
);

// Execute jobs for the messages
for msg in &messages {
for msg in messages.iter().flatten() {
if let Some(job) = REG::from_name(&msg.document.contents.name) {
// Filter out messages with active dependencies
if let Some(dependency) = msg.document.contents.execute_after {
Expand All @@ -187,8 +232,8 @@ where
keep_alive: None,
};

// Dropping the handle to the running job.. Panics will not cause
let _jh = current_job.run(job.function());
// Dropping the handle detaches the running job; a panic inside it will not propagate to this loop.
let _jh = current_job.run(job.function(), currently_running.clone());
}
} else {
tracing::trace!(
Expand All @@ -200,10 +245,13 @@ where

// Sleep until the next message is due or a notification comes in.
let next_due_in = self.next_message_due_in(now).await?;
tokio::time::timeout(next_due_in, subscriber.receiver().receive_async())
.await
.ok() // Timeout is not a failure
.transpose()?;
tokio::time::timeout(
next_due_in.max(Duration::from_millis(100)), // Wait at least 100 ms.
subscriber.receiver().receive_async(),
)
.await
.ok() // Timeout is not a failure
.transpose()?;
}
}
}
Expand All @@ -214,6 +262,7 @@ impl<DB: Debug> Debug for InternalJobRunner<DB> {
.field("db", &self.db)
.field("error_handler", &"<err handler fn>")
.field("context", &self.context)
.field("concurrency", &self.concurrency)
.finish()
}
}
Expand Down

0 comments on commit f5d2bde

Please sign in to comment.