Skip to content

Commit

Permalink
controllers/krate/search: Simplify auth_user_id code
Browse files Browse the repository at this point in the history
We can figure out the `auth_user_id` value at the point where we turn the `ListQueryParams` instance into `FilterParams`. This then allows us to get rid of the `OnceLock`, and remove the extra arguments from the `make_query()` fn.
  • Loading branch information
Turbo87 committed Dec 17, 2024
1 parent 396a730 commit d166871
Showing 1 changed file with 21 additions and 36 deletions.
57 changes: 21 additions & 36 deletions src/controllers/krate/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use diesel_async::{AsyncPgConnection, RunQueryDsl};
use diesel_full_text_search::*;
use http::request::Parts;
use std::ops::Deref;
use std::sync::OnceLock;
use tracing::Instrument;
use utoipa::IntoParams;

Expand Down Expand Up @@ -67,7 +66,7 @@ pub async fn list_crates(
use diesel::sql_types::Float;
use seek::*;

let filter_params: FilterParams = params.into();
let filter_params = FilterParams::from(params, &req, &mut conn).await?;
let sort = filter_params.sort.as_deref();

let selection = (
Expand All @@ -82,8 +81,7 @@ pub async fn list_crates(

let mut seek: Option<Seek> = None;
let mut query = filter_params
.make_query(&req, &mut conn)
.await?
.make_query()?
.inner_join(crate_downloads::table)
.left_join(recent_crate_downloads::table)
.left_join(default_versions::table)
Expand Down Expand Up @@ -181,7 +179,7 @@ pub async fn list_crates(
//
// If this becomes a problem in the future the crates count could be denormalized, at least
// for the filterless happy path.
let count_query = filter_params.make_query(&req, &mut conn).await?.count();
let count_query = filter_params.make_query()?.count();
let query = query.pages_pagination_with_count_query(pagination, count_query);
let span = info_span!("db.query", message = "SELECT ..., COUNT(*) FROM crates");
let data = query.load::<Record>(&mut conn).instrument(span).await?;
Expand All @@ -193,7 +191,7 @@ pub async fn list_crates(
data.into_iter().collect::<Vec<_>>(),
)
} else {
let count_query = filter_params.make_query(&req, &mut conn).await?.count();
let count_query = filter_params.make_query()?.count();
let query = query.pages_pagination_with_count_query(pagination, count_query);
let span = info_span!("db.query", message = "SELECT ..., COUNT(*) FROM crates");
let data = query.load::<Record>(&mut conn).instrument(span).await?;
Expand Down Expand Up @@ -317,15 +315,11 @@ impl ListQueryParams {
let include_yanked = self.include_yanked.as_ref();
include_yanked.map(|s| s == "yes").unwrap_or(true)
}

pub fn following(&self) -> bool {
self.following.is_some()
}
}

struct FilterParams {
search_params: ListQueryParams,
_auth_user_id: OnceLock<i32>,
auth_user_id: Option<i32>,
}

impl Deref for FilterParams {
Expand All @@ -336,34 +330,26 @@ impl Deref for FilterParams {
}
}

impl From<ListQueryParams> for FilterParams {
fn from(search_params: ListQueryParams) -> Self {
Self {
impl FilterParams {
async fn from(
search_params: ListQueryParams,
parts: &Parts,
conn: &mut AsyncPgConnection,
) -> AppResult<Self> {
let auth_user_id = match search_params.following {
Some(_) => Some(AuthCheck::default().check(parts, conn).await?.user_id()),
None => None,
};

Ok(Self {
search_params,
_auth_user_id: OnceLock::new(),
}
auth_user_id,
})
}
}

impl FilterParams {
async fn authed_user_id(&self, req: &Parts, conn: &mut AsyncPgConnection) -> AppResult<i32> {
if let Some(val) = self._auth_user_id.get() {
return Ok(*val);
}

let user_id = AuthCheck::default().check(req, conn).await?.user_id();

// This should not fail, because of the `get()` check above
let _ = self._auth_user_id.set(user_id);

Ok(user_id)
}

async fn make_query(
&self,
req: &Parts,
conn: &mut AsyncPgConnection,
) -> AppResult<crates::BoxedQuery<'_, diesel::pg::Pg>> {
fn make_query(&self) -> AppResult<crates::BoxedQuery<'_, diesel::pg::Pg>> {
let mut query = crates::table.into_boxed();

if let Some(q_string) = &self.q_string {
Expand Down Expand Up @@ -443,8 +429,7 @@ impl FilterParams {
.filter(crate_owners::owner_id.eq(team_id)),
),
);
} else if self.following() {
let user_id = self.authed_user_id(req, conn).await?;
} else if let Some(user_id) = self.auth_user_id {
query = query.filter(
crates::id.eq_any(
follows::table
Expand Down

0 comments on commit d166871

Please sign in to comment.