Skip to content

Commit

Permalink
Batch loading for tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
andriidemus committed Jan 7, 2025
1 parent aa80b66 commit 69f318c
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 9 deletions.
18 changes: 10 additions & 8 deletions src/domain/task-system/services/src/task_agent_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,21 @@ impl TaskAgentImpl {
})
.try_collect()
.await?;
let batch_size = running_task_ids.len();

for running_task_id in &running_task_ids {
// TODO: batch loading of tasks
let mut task = Task::load(*running_task_id, task_event_store.as_ref())
.await
.int_err()?;
let tasks = Task::load_multi(running_task_ids, task_event_store.as_ref())
.await
.int_err()?;

for task in tasks {
let mut t = task.int_err()?;

// Requeue
task.requeue(self.time_source.now()).int_err()?;
task.save(task_event_store.as_ref()).await.int_err()?;
t.requeue(self.time_source.now()).int_err()?;
t.save(task_event_store.as_ref()).await.int_err()?;
}

processed_running_tasks += running_task_ids.len();
processed_running_tasks += batch_size;
}

Ok(())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ impl EventStore<FlowState> for PostgresFlowEventStore {
let event = serde_json::from_value::<FlowEvent>(event_row.event_payload)
.map_err(|e| sqlx::Error::Decode(Box::new(e)))?;

Ok((FlowID::try_from(event_row.flow_id).unwrap(), // todo: handle error
Ok((FlowID::try_from(event_row.flow_id).unwrap(), // ids are always > 0
EventID::new(event_row.event_id),
event))
})
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 34 additions & 0 deletions src/infra/task-system/postgres/src/postgres_task_event_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,40 @@ impl EventStore<TaskState> for PostgresTaskEventStore {
})
}

/// Streams all persisted events for the given batch of tasks in a single
/// database round-trip, yielding `(task_id, event_id, event)` tuples.
///
/// Rows are ordered globally by `event_id ASC` (per the query below), so
/// events from different tasks arrive interleaved but in insertion order.
/// Errors (connection, decode) are yielded through the stream as
/// `GetEventsError::Internal`, ending it at the first failure.
fn get_events_multi(&self, queries: Vec<TaskID>) -> MultiEventStream<TaskID, TaskEvent> {
    // Convert domain ids to the i64 column type eagerly, before the stream is
    // polled. NOTE(review): the `unwrap` assumes every TaskID fits in i64 —
    // mirrors the "ids are always > 0" assumption on the decode side below;
    // confirm TaskID's representation guarantees this.
    let task_ids: Vec<i64> = queries.iter().map(|id| (*id).try_into().unwrap()).collect();

    Box::pin(async_stream::stream! {
        // The shared transaction is locked here and the guard lives for the
        // whole body of the stream, i.e. until the stream is fully consumed
        // or dropped.
        let mut tr = self.transaction.lock().await;
        let connection_mut = tr
            .connection_mut()
            .await?;

        // One query fetches the events of every requested task via
        // `task_id = ANY($1)`; this replaces per-task loading with a batch.
        let mut query_stream = sqlx::query!(
            r#"
SELECT task_id, event_id, event_payload
FROM task_events
WHERE task_id = ANY($1)
ORDER BY event_id ASC
"#,
            &task_ids,
        ).try_map(|event_row| {
            // Deserialize the JSON payload into a typed TaskEvent; a failure
            // is surfaced as an sqlx decode error so it propagates through
            // the stream like any other row error.
            let event = serde_json::from_value::<TaskEvent>(event_row.event_payload)
                .map_err(|e| sqlx::Error::Decode(Box::new(e)))?;

            Ok((TaskID::try_from(event_row.task_id).unwrap(), // ids are always > 0
                EventID::new(event_row.event_id),
                event))
        })
        .fetch(connection_mut)
        // Wrap low-level sqlx errors into the store's public error type.
        .map_err(|e| GetEventsError::Internal(e.int_err()));

        // Re-yield each decoded tuple; `?` on `try_next` forwards the first
        // error to the consumer and terminates the stream.
        while let Some((task_id, event_id, event)) = query_stream.try_next().await? {
            yield Ok((task_id, event_id, event));
        }
    })
}

async fn save_events(
&self,
task_id: &TaskID,
Expand Down

0 comments on commit 69f318c

Please sign in to comment.