From c5811635125d50147ce353897e6e9225a1745df1 Mon Sep 17 00:00:00 2001 From: David Pacheco Date: Wed, 6 Dec 2023 15:59:41 -0800 Subject: [PATCH] add a helper for querying the database in batches (#4632) --- nexus/db-queries/src/db/datastore/dns.rs | 29 ++-- nexus/db-queries/src/db/mod.rs | 3 +- nexus/db-queries/src/db/pagination.rs | 184 +++++++++++++++++++++++ 3 files changed, 196 insertions(+), 20 deletions(-) diff --git a/nexus/db-queries/src/db/datastore/dns.rs b/nexus/db-queries/src/db/datastore/dns.rs index cfd25d6a4f..552ad31487 100644 --- a/nexus/db-queries/src/db/datastore/dns.rs +++ b/nexus/db-queries/src/db/datastore/dns.rs @@ -15,6 +15,7 @@ use crate::db::model::DnsZone; use crate::db::model::Generation; use crate::db::model::InitialDnsGroup; use crate::db::pagination::paginated; +use crate::db::pagination::Paginator; use crate::db::pool::DbConnection; use crate::db::TransactionError; use async_bb8_diesel::AsyncConnection; @@ -242,9 +243,8 @@ impl DataStore { let mut zones = Vec::with_capacity(dns_zones.len()); for zone in dns_zones { let mut zone_records = Vec::new(); - let mut marker = None; - - loop { + let mut paginator = Paginator::new(batch_size); + while let Some(p) = paginator.next() { debug!(log, "listing DNS names for zone"; "dns_zone_id" => zone.id.to_string(), "dns_zone_name" => &zone.zone_name, @@ -252,25 +252,16 @@ impl DataStore { "found_so_far" => zone_records.len(), "batch_size" => batch_size.get(), ); - let pagparams = DataPageParams { - marker: marker.as_ref(), - direction: dropshot::PaginationOrder::Ascending, - limit: batch_size, - }; let names_batch = self - .dns_names_list(opctx, zone.id, version.version, &pagparams) + .dns_names_list( + opctx, + zone.id, + version.version, + &p.current_pagparams(), + ) .await?; - let done = names_batch.len() - < usize::try_from(batch_size.get()).unwrap(); - if let Some((last_name, _)) = names_batch.last() { - marker = Some(last_name.clone()); - } else { - assert!(done); - } + paginator = p.found_batch(&names_batch, &|(n, _)| n.clone()); zone_records.extend(names_batch.into_iter()); - if done { - break; - } } debug!(log, "found all DNS names for zone"; diff --git a/nexus/db-queries/src/db/mod.rs b/nexus/db-queries/src/db/mod.rs index e6b8743e94..924eab363f 100644 --- a/nexus/db-queries/src/db/mod.rs +++ b/nexus/db-queries/src/db/mod.rs @@ -21,7 +21,8 @@ pub(crate) mod error; mod explain; pub mod fixed_data; pub mod lookup; -mod pagination; +// Public for doctests. +pub mod pagination; mod pool; // This is marked public because the error types are used elsewhere, e.g., in // sagas. diff --git a/nexus/db-queries/src/db/pagination.rs b/nexus/db-queries/src/db/pagination.rs index dd7daab14f..4fc1cf5966 100644 --- a/nexus/db-queries/src/db/pagination.rs +++ b/nexus/db-queries/src/db/pagination.rs @@ -16,6 +16,7 @@ use diesel::AppearsOnTable; use diesel::Column; use diesel::{ExpressionMethods, QueryDsl}; use omicron_common::api::external::DataPageParams; +use std::num::NonZeroU32; // Shorthand alias for "the SQL type of the whole table". type TableSqlType = ::SqlType; @@ -169,6 +170,145 @@ where } } +/// Helper for querying a large number of records from the database in batches +/// +/// Without this helper: a typical way to perform paginated queries would be to +/// invoke some existing "list" function in the datastore that itself is +/// paginated. Such functions accept a `pagparams: &DataPageParams` argument +/// that uses a marker to identify where the next page of results starts. For +/// the first call, the marker inside `pagparams` is `None`. For subsequent +/// calls, it's typically some field from the last item returned in the previous +/// page. You're finished when you get a result set smaller than the batch +/// size. +/// +/// This helper takes care of most of the logic for you. To use this, you first +/// create a `Paginator` with a specific batch_size. Then you call `next()` in +/// a loop. Each iteration will provide you with a `DataPageParams` to use to +/// call your list function. When you've fetched the next page, you have to +/// let the helper look at it to determine if there's another page to fetch and +/// what marker to use. +/// +/// ## Example +/// +/// ``` +/// use nexus_db_queries::db::pagination::Paginator; +/// use omicron_common::api::external::DataPageParams; +/// +/// let batch_size = std::num::NonZeroU32::new(3).unwrap(); +/// +/// // Assume you've got an existing paginated "list items" function. +/// // This simple implementation returns a few full batches, then a partial +/// // batch. +/// type Marker = u32; +/// type Item = u32; +/// let do_query = |pagparams: &DataPageParams<'_, Marker> | { +/// match pagparams.marker { +/// None => (0..batch_size.get()).collect(), +/// Some(x) if *x < 2 * batch_size.get() => (x+1..x+1+batch_size.get()).collect(), +/// Some(x) => vec![*x + 1], +/// } +/// }; +/// +/// // This closure translates from one of the returned item to the field in +/// // that item that servers as the marker. This example is contrived. +/// let item2marker: &dyn Fn(&Item) -> Marker = &|u: &u32| *u; +/// +/// let mut all_records = Vec::new(); +/// let mut paginator = Paginator::new(batch_size); +/// while let Some(p) = paginator.next() { +/// let records_batch = do_query(&p.current_pagparams()); +/// paginator = p.found_batch(&records_batch, item2marker); +/// all_records.extend(records_batch.into_iter()); +/// } +/// +/// // Results are in `all_records`. +/// assert_eq!(all_records, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); +/// ``` +/// +/// ## Design notes +/// +/// The separation of `Paginator` and `PaginatorHelper` is aimed at making it +/// harder to misuse this interface. We could skip the helper altogether and +/// just have `Paginator::next()` return the DatePageParams directly. But you'd +/// still need a `Paginator::found_batch()`. And it would be easy to forget to +/// call this, leading to an infinite loop at runtime. To avoid this mistake, +/// `Paginator::next()` consumes `self`. You can't get another `Paginator` back +/// until you use `PaginatorHelper::found_batch()`. That also consumes `self` +/// so that you can't keep using the old `DataPageParams`. +pub struct Paginator { + batch_size: NonZeroU32, + state: PaginatorState, +} + +impl Paginator { + pub fn new(batch_size: NonZeroU32) -> Paginator { + Paginator { batch_size, state: PaginatorState::Initial } + } + + pub fn next(self) -> Option> { + match self.state { + PaginatorState::Initial => Some(PaginatorHelper { + batch_size: self.batch_size, + marker: None, + }), + PaginatorState::Middle { marker } => Some(PaginatorHelper { + batch_size: self.batch_size, + marker: Some(marker), + }), + PaginatorState::Done => None, + } + } +} + +enum PaginatorState { + Initial, + Middle { marker: N }, + Done, +} + +pub struct PaginatorHelper { + batch_size: NonZeroU32, + marker: Option, +} + +impl PaginatorHelper { + /// Returns the `DatePageParams` to use to fetch the next page of results + pub fn current_pagparams(&self) -> DataPageParams<'_, N> { + DataPageParams { + marker: self.marker.as_ref(), + direction: dropshot::PaginationOrder::Ascending, + limit: self.batch_size, + } + } + + /// Report a page of results + /// + /// This function looks at the returned results to determine whether we've + /// finished iteration or whether we need to fetch another page (and if so, + /// this determines the marker for the next fetch operation). + /// + /// This function returns a `Paginator` used to make the next request. See + /// the example on `Paginator` for usage. + pub fn found_batch( + self, + batch: &[T], + item2marker: &dyn Fn(&T) -> N, + ) -> Paginator { + let state = + if batch.len() < usize::try_from(self.batch_size.get()).unwrap() { + PaginatorState::Done + } else { + // self.batch_size is non-zero, so if we got at least that many + // items, then there's at least one. + let last = batch.iter().last().unwrap(); + let marker = item2marker(last); + PaginatorState::Middle { marker } + }; + + Paginator { batch_size: self.batch_size, state } + } +} + #[cfg(test)] mod test { use super::*; @@ -433,4 +573,48 @@ mod test { let _ = db.cleanup().await; logctx.cleanup_successful(); } + + #[test] + fn test_paginator() { + // The doctest exercises a basic case for Paginator. Here we test some + // edge cases. + let batch_size = std::num::NonZeroU32::new(3).unwrap(); + + type Marker = u32; + #[derive(Debug, PartialEq, Eq)] + struct Item { + value: String, + marker: Marker, + } + + let do_list = + |query: &dyn Fn(&DataPageParams<'_, Marker>) -> Vec| { + let mut all_records = Vec::new(); + let mut paginator = Paginator::new(batch_size); + while let Some(p) = paginator.next() { + let records_batch = query(&p.current_pagparams()); + paginator = + p.found_batch(&records_batch, &|i: &Item| i.marker); + all_records.extend(records_batch.into_iter()); + } + all_records + }; + + fn mkitem(v: u32) -> Item { + Item { value: v.to_string(), marker: v } + } + + // Trivial case: first page is empty + assert_eq!(Vec::::new(), do_list(&|_| Vec::new())); + + // Exactly one batch-size worth of items + // (exercises the cases where the last non-empty batch is full, and + // where any batch is empty) + let my_query = + |pagparams: &DataPageParams<'_, Marker>| match &pagparams.marker { + None => (0..batch_size.get()).map(mkitem).collect(), + Some(_) => Vec::new(), + }; + assert_eq!(vec![mkitem(0), mkitem(1), mkitem(2)], do_list(&my_query)); + } }