use semaphore for max concurrent tasks
Larkooo committed Sep 16, 2024
1 parent 4d04d47 commit 42b80ce
Showing 2 changed files with 31 additions and 18 deletions.
5 changes: 5 additions & 0 deletions bin/torii/src/main.rs
@@ -125,6 +125,10 @@ struct Args {
/// Polling interval in ms
#[arg(long, default_value = "500")]
polling_interval: u64,

/// Max concurrent tasks
#[arg(long, default_value = "100")]
max_concurrent_tasks: usize,
}

#[tokio::main]
@@ -196,6 +200,7 @@ async fn main() -> anyhow::Result<()> {
provider.clone(),
processors,
EngineConfig {
max_concurrent_tasks: args.max_concurrent_tasks,
start_block: args.start_block,
events_chunk_size: args.events_chunk_size,
index_pending: args.index_pending,
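
As an aside, the new argument is exposed through clap's derive attributes. The sketch below shows how that parsing behaves in isolation; it assumes clap 4 with the derive feature enabled, and only the field name and its default mirror the diff above.

use clap::Parser;

/// Minimal stand-in for the real Args struct; everything except
/// `max_concurrent_tasks` is omitted for brevity.
#[derive(Parser, Debug)]
struct Args {
    /// Max concurrent tasks
    #[arg(long, default_value = "100")]
    max_concurrent_tasks: usize,
}

fn main() {
    // Passing `--max-concurrent-tasks 250` overrides the default of 100.
    let args = Args::parse();
    println!("max_concurrent_tasks = {}", args.max_concurrent_tasks);
}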
44 changes: 26 additions & 18 deletions crates/torii/core/src/engine.rs
@@ -5,7 +5,6 @@ use std::time::Duration;

use anyhow::Result;
use dojo_world::contracts::world::WorldContractReader;
use futures_util::future::try_join_all;
use hashlink::LinkedHashMap;
use starknet::core::types::{
BlockId, BlockTag, EmittedEvent, Event, EventFilter, Felt, MaybePendingBlockWithReceipts,
@@ -14,10 +13,10 @@ use starknet::core::types::{
};
use starknet::providers::Provider;
use starknet_crypto::poseidon_hash_many;
use tokio::sync::broadcast::Sender;
use tokio::sync::mpsc::Sender as BoundedSender;
use tokio::sync::{broadcast::Sender, mpsc::Sender as BoundedSender, Semaphore};
use tokio::time::sleep;
use tracing::{debug, error, info, trace, warn};
use tokio::task::JoinSet;

use crate::processors::event_message::EventMessageProcessor;
use crate::processors::{BlockProcessor, EventProcessor, TransactionProcessor};
@@ -51,6 +50,7 @@ pub struct EngineConfig {
pub start_block: u64,
pub events_chunk_size: u64,
pub index_pending: bool,
pub max_concurrent_tasks: usize,
}

impl Default for EngineConfig {
@@ -60,6 +60,7 @@ impl Default for EngineConfig {
start_block: 0,
events_chunk_size: 1024,
index_pending: true,
max_concurrent_tasks: 100,
}
}
}
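
For illustration, a caller that wants a different cap can lean on the Default impl above and override only the new field. This is a hypothetical sketch; the `torii_core::engine` import path is an assumption based on the file layout, not something this diff confirms.

// Hypothetical helper; the import path is assumed from crates/torii/core/src/engine.rs.
use torii_core::engine::EngineConfig;

fn config_with_cap(max_concurrent_tasks: usize) -> EngineConfig {
    // Struct-update syntax keeps the other defaults
    // (start_block, events_chunk_size, index_pending).
    EngineConfig { max_concurrent_tasks, ..Default::default() }
}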
@@ -440,34 +441,41 @@ impl<P: Provider + Send + Sync + std::fmt::Debug + 'static> Engine<P> {
}
}

// Process queued tasks in parallel
let tasks: Vec<_> = self.tasks.drain().map(|(task_id, events)| {
// We use a semaphore to limit the number of concurrent tasks
let semaphore = Arc::new(Semaphore::new(self.config.max_concurrent_tasks));

// Run all tasks concurrently
let mut set = JoinSet::new();
for (task_id, events) in self.tasks.drain() {
let db = self.db.clone();
let world = self.world.clone();
let processors = self.processors.clone();
let block_timestamp = data.blocks[&last_block];
let semaphore = semaphore.clone();

tokio::spawn(async move {
set.spawn(async move {
let _permit = semaphore.acquire().await.unwrap();
let mut local_db = db.clone();
for (event_id, event) in events {
if let Some(processor) = processors.event.get(&event.keys[0]) {
debug!(target: LOG_TARGET, event_name = processor.event_key(), task_id = %task_id, "Processing parallelized event.");

if let Err(e) = processor
.process(&world, &mut local_db, last_block, block_timestamp, &event_id, &event)
.await
{
error!(target: LOG_TARGET, event_name = processor.event_key(), error = %e, task_id = %task_id, "Processing queued event.");
error!(target: LOG_TARGET, event_name = processor.event_key(), error = %e, task_id = %task_id, "Processing parallelized event.");
}
}
}
Ok::<_, anyhow::Error>(local_db)
})
}).collect();

// We wait for all tasks to complete processing
let results = try_join_all(tasks).await?;
for local_db in results {
// We merge the query queues of each task into the main db
self.db.merge(local_db?)?;
});
}

// Join all tasks
while let Some(result) = set.join_next().await {
let local_db = result??;
self.db.merge(local_db)?;
}

self.db.set_head(data.latest_block_number, None, None);
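
The pattern introduced above, an Arc'd Semaphore capping in-flight work while a JoinSet collects results, is shown below as a standalone, hedged sketch. The task count, workload, and error handling are illustrative only; it assumes the tokio (with the sync, macros, and rt-multi-thread features) and anyhow crates.

use std::sync::Arc;
use tokio::sync::Semaphore;
use tokio::task::JoinSet;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Cap matching the new config field: at most this many task bodies run at once.
    let max_concurrent_tasks = 100;
    let semaphore = Arc::new(Semaphore::new(max_concurrent_tasks));
    let mut set = JoinSet::new();

    for task_id in 0..1_000u64 {
        let semaphore = semaphore.clone();
        set.spawn(async move {
            // Holding the permit for the whole body is what bounds concurrency;
            // tasks beyond the cap are spawned but wait here until a permit frees up.
            let _permit = semaphore.acquire().await?;
            // ... per-task work goes here (event processing, in the engine's case) ...
            Ok::<_, anyhow::Error>(task_id)
        });
    }

    // Drain results as tasks finish: the first `?` surfaces join errors
    // (panics or cancellation), the second surfaces errors from the task body.
    while let Some(result) = set.join_next().await {
        let _task_id = result??;
    }
    Ok(())
}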
@@ -505,7 +513,7 @@ impl<P: Provider + Send + Sync + std::fmt::Debug + 'static> Engine<P> {
block_timestamp,
&event_id,
&event,
transaction_hash,
// transaction_hash,
)
.await?;
}
@@ -555,7 +563,7 @@ impl<P: Provider + Send + Sync + std::fmt::Debug + 'static> Engine<P> {
block_timestamp,
&event_id,
event,
*transaction_hash,
// *transaction_hash,
)
.await?;
}
@@ -615,7 +623,7 @@ impl<P: Provider + Send + Sync + std::fmt::Debug + 'static> Engine<P> {
block_timestamp: u64,
event_id: &str,
event: &Event,
transaction_hash: Felt,
// transaction_hash: Felt,
) -> Result<()> {
// self.db.store_event(event_id, event, transaction_hash, block_timestamp);
let event_key = event.keys[0];
