Skip to content

Commit

Permalink
Make the rate limiter available to the GraphQL API handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
sandhose committed Jan 7, 2025
1 parent b9e0811 commit 2ebb59b
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 10 deletions.
1 change: 1 addition & 0 deletions crates/cli/src/commands/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ impl Options {
site_config.clone(),
password_manager.clone(),
url_builder.clone(),
limiter.clone(),
);

let state = {
Expand Down
2 changes: 1 addition & 1 deletion crates/cli/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
Expand Down
22 changes: 18 additions & 4 deletions crates/handlers/src/graphql/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
Expand Down Expand Up @@ -53,7 +53,10 @@ use self::{
mutations::Mutation,
query::Query,
};
use crate::{impl_from_error_for_route, passwords::PasswordManager, BoundActivityTracker};
use crate::{
impl_from_error_for_route, passwords::PasswordManager, BoundActivityTracker, Limiter,
RequesterFingerprint,
};

#[cfg(test)]
mod tests;
Expand All @@ -72,6 +75,7 @@ struct GraphQLState {
site_config: SiteConfig,
password_manager: PasswordManager,
url_builder: UrlBuilder,
limiter: Limiter,
}

#[async_trait]
Expand Down Expand Up @@ -104,6 +108,10 @@ impl state::State for GraphQLState {
&self.url_builder
}

fn limiter(&self) -> &Limiter {
&self.limiter
}

fn clock(&self) -> BoxClock {
let clock = SystemClock::default();
Box::new(clock)
Expand All @@ -126,6 +134,7 @@ pub fn schema(
site_config: SiteConfig,
password_manager: PasswordManager,
url_builder: UrlBuilder,
limiter: Limiter,
) -> Schema {
let state = GraphQLState {
pool: pool.clone(),
Expand All @@ -134,6 +143,7 @@ pub fn schema(
site_config,
password_manager,
url_builder,
limiter,
};
let state: BoxState = Box::new(state);

Expand Down Expand Up @@ -303,6 +313,7 @@ pub async fn post(
cookie_jar: CookieJar,
content_type: Option<TypedHeader<ContentType>>,
authorization: Option<TypedHeader<Authorization<Bearer>>>,
requester_fingerprint: RequesterFingerprint,
body: Body,
) -> Result<impl IntoResponse, RouteError> {
let body = body.into_data_stream();
Expand All @@ -329,6 +340,7 @@ pub async fn post(
MultipartOptions::default(),
)
.await?
.data(requester_fingerprint)
.data(requester); // XXX: this should probably return another error response?

let span = span_for_graphql_request(&request);
Expand All @@ -355,6 +367,7 @@ pub async fn get(
activity_tracker: BoundActivityTracker,
cookie_jar: CookieJar,
authorization: Option<TypedHeader<Authorization<Bearer>>>,
requester_fingerprint: RequesterFingerprint,
RawQuery(query): RawQuery,
) -> Result<impl IntoResponse, FancyError> {
let token = authorization
Expand All @@ -371,8 +384,9 @@ pub async fn get(
)
.await?;

let request =
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);
let request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?
.data(requester)
.data(requester_fingerprint);

let span = span_for_graphql_request(&request);
let response = schema.execute(request).instrument(span).await;
Expand Down
11 changes: 9 additions & 2 deletions crates/handlers/src/graphql/state.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
Expand All @@ -10,7 +10,7 @@ use mas_policy::Policy;
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError};

use crate::{graphql::Requester, passwords::PasswordManager};
use crate::{graphql::Requester, passwords::PasswordManager, Limiter, RequesterFingerprint};

#[async_trait::async_trait]
pub trait State {
Expand All @@ -22,6 +22,7 @@ pub trait State {
fn rng(&self) -> BoxRng;
fn site_config(&self) -> &SiteConfig;
fn url_builder(&self) -> &UrlBuilder;
fn limiter(&self) -> &Limiter;
}

pub type BoxState = Box<dyn State + Send + Sync + 'static>;
Expand All @@ -30,6 +31,8 @@ pub trait ContextExt {
fn state(&self) -> &BoxState;

fn requester(&self) -> &Requester;

fn requester_fingerprint(&self) -> RequesterFingerprint;
}

impl ContextExt for async_graphql::Context<'_> {
Expand All @@ -40,4 +43,8 @@ impl ContextExt for async_graphql::Context<'_> {
fn requester(&self) -> &Requester {
self.data_unchecked()
}

fn requester_fingerprint(&self) -> RequesterFingerprint {
*self.data_unchecked()
}
}
12 changes: 9 additions & 3 deletions crates/handlers/src/test_utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
Expand Down Expand Up @@ -204,6 +204,8 @@ impl TestState {
let clock = Arc::new(MockClock::default());
let rng = Arc::new(Mutex::new(ChaChaRng::seed_from_u64(42)));

let limiter = Limiter::new(&RateLimitingConfig::default()).unwrap();

let graphql_state = TestGraphQLState {
pool: pool.clone(),
policy_factory: Arc::clone(&policy_factory),
Expand All @@ -213,6 +215,7 @@ impl TestState {
clock: Arc::clone(&clock),
password_manager: password_manager.clone(),
url_builder: url_builder.clone(),
limiter: limiter.clone(),
};
let state: crate::graphql::BoxState = Box::new(graphql_state);

Expand All @@ -225,8 +228,6 @@ impl TestState {
shutdown_token.child_token(),
);

let limiter = Limiter::new(&RateLimitingConfig::default()).unwrap();

Ok(Self {
pool,
templates,
Expand Down Expand Up @@ -379,6 +380,7 @@ struct TestGraphQLState {
rng: Arc<Mutex<ChaChaRng>>,
password_manager: PasswordManager,
url_builder: UrlBuilder,
limiter: Limiter,
}

#[async_trait]
Expand Down Expand Up @@ -415,6 +417,10 @@ impl graphql::State for TestGraphQLState {
&self.site_config
}

fn limiter(&self) -> &Limiter {
&self.limiter
}

fn rng(&self) -> BoxRng {
let mut parent_rng = self.rng.lock().expect("Failed to lock RNG");
let rng = ChaChaRng::from_rng(&mut *parent_rng).expect("Failed to seed RNG");
Expand Down

0 comments on commit 2ebb59b

Please sign in to comment.