Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract crates_io_session crate #10267

Merged
merged 3 commits into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ aws-ip-ranges = "=0.962.0"
aws-sdk-cloudfront = "=1.57.0"
aws-sdk-sqs = "=1.51.0"
axum = { version = "=0.7.9", features = ["macros", "matched-path"] }
axum-extra = { version = "=0.9.6", features = ["cookie-signed", "erased-json", "query", "typed-header"] }
axum-extra = { version = "=0.9.6", features = ["erased-json", "query", "typed-header"] }
base64 = "=0.22.1"
bigdecimal = { version = "=0.4.7", features = ["serde"] }
bon = "=3.3.1"
Expand All @@ -63,6 +63,7 @@ crates_io_github = { path = "crates/crates_io_github" }
crates_io_index = { path = "crates/crates_io_index" }
crates_io_markdown = { path = "crates/crates_io_markdown" }
crates_io_pagerduty = { path = "crates/crates_io_pagerduty" }
crates_io_session = { path = "crates/crates_io_session" }
crates_io_tarball = { path = "crates/crates_io_tarball" }
crates_io_team_repo = { path = "crates/crates_io_team_repo" }
crates_io_worker = { path = "crates/crates_io_worker" }
Expand Down
17 changes: 17 additions & 0 deletions crates/crates_io_session/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
name = "crates_io_session"
version = "0.0.0"
license = "MIT OR Apache-2.0"
edition = "2021"

[lints]
workspace = true

[dependencies]
axum = { version = "=0.7.9", features = ["macros"] }
axum-extra = { version = "=0.9.6", features = ["cookie-signed"] }
base64 = "=0.22.1"
cookie = { version = "=0.18.1", features = ["secure"] }
parking_lot = "=0.12.3"

[dev-dependencies]
24 changes: 5 additions & 19 deletions src/middleware/session.rs → crates/crates_io_session/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
use crate::controllers::util::RequestPartsExt;
use axum::extract::{Extension, FromRequestParts, Request};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use axum_extra::extract::SignedCookieJar;
use base64::{engine::general_purpose, Engine};
use cookie::time::Duration;
use cookie::{Cookie, SameSite};
use derive_more::Deref;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;

static COOKIE_NAME: &str = "cargo_session";
static MAX_AGE_DAYS: i64 = 90;

#[derive(Clone, FromRequestParts, Deref)]
#[derive(Clone, FromRequestParts)]
#[from_request(via(Extension))]
pub struct SessionExtension(Arc<RwLock<Session>>);

Expand All @@ -24,18 +22,18 @@ impl SessionExtension {
}

pub fn get(&self, key: &str) -> Option<String> {
let session = self.read();
let session = self.0.read();
session.data.get(key).cloned()
}

pub fn insert(&self, key: String, value: String) -> Option<String> {
let mut session = self.write();
let mut session = self.0.write();
session.dirty = true;
session.data.insert(key, value)
}

pub fn remove(&self, key: &str) -> Option<String> {
let mut session = self.write();
let mut session = self.0.write();
session.dirty = true;
session.data.remove(key)
}
Expand All @@ -54,7 +52,7 @@ pub async fn attach_session(jar: SignedCookieJar, mut req: Request, next: Next)
let response = next.run(req).await;

// Check if the session data was mutated
let session = session.read();
let session = session.0.read();
if session.dirty {
// Return response with additional `Set-Cookie` header
let encoded = encode(&session.data);
Expand Down Expand Up @@ -83,18 +81,6 @@ impl Session {
}
}

pub trait RequestSession {
fn session(&self) -> &SessionExtension;
}

impl<T: RequestPartsExt> RequestSession for T {
fn session(&self) -> &SessionExtension {
self.extensions()
.get::<SessionExtension>()
.expect("missing cookie session")
}
}

pub fn decode(cookie: Cookie<'_>) -> HashMap<String, String> {
let mut ret = HashMap::new();
let bytes = general_purpose::STANDARD
Expand Down
11 changes: 6 additions & 5 deletions src/auth.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use crate::controllers;
use crate::controllers::util::RequestPartsExt;
use crate::middleware::log_request::RequestLogExt;
use crate::middleware::session::RequestSession;
use crate::models::token::{CrateScope, EndpointScope};
use crate::models::{ApiToken, User};
use crate::util::errors::{
account_locked, forbidden, internal, AppResult, InsecurelyGeneratedTokenRevoked,
};
use crate::util::token::HashedToken;
use chrono::Utc;
use crates_io_session::SessionExtension;
use diesel_async::AsyncPgConnection;
use http::header;
use http::request::Parts;
Expand Down Expand Up @@ -176,11 +176,12 @@ async fn authenticate_via_cookie(
parts: &Parts,
conn: &mut AsyncPgConnection,
) -> AppResult<Option<CookieAuthentication>> {
let user_id_from_session = parts
.session()
.get("user_id")
.and_then(|s| s.parse::<i32>().ok());
let session = parts
.extensions()
.get::<SessionExtension>()
.expect("missing cookie session");

let user_id_from_session = session.get("user_id").and_then(|s| s.parse::<i32>().ok());
let Some(id) = user_id_from_session else {
return Ok(None);
};
Expand Down
2 changes: 1 addition & 1 deletion src/controllers/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ use oauth2::{AuthorizationCode, CsrfToken, Scope, TokenResponse};
use crate::app::AppState;
use crate::email::Emails;
use crate::middleware::log_request::RequestLogExt;
use crate::middleware::session::SessionExtension;
use crate::models::{NewUser, User};
use crate::schema::users;
use crate::util::diesel::is_read_only_error;
use crate::util::errors::{bad_request, server_error, AppResult};
use crate::views::EncodableMe;
use crates_io_github::GithubUser;
use crates_io_session::SessionExtension;

/// Begin authentication flow.
///
Expand Down
6 changes: 4 additions & 2 deletions src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ pub mod log_request;
pub mod normalize_path;
pub mod real_ip;
mod require_user_agent;
pub mod session;
mod static_or_continue;
mod update_metrics;

Expand Down Expand Up @@ -59,7 +58,10 @@ pub fn apply_axum_middleware(state: AppState, router: Router<()>) -> Router {
state.config.cargo_compat_status_code_config,
cargo_compat::middleware,
))
.layer(from_fn_with_state(state.clone(), session::attach_session))
.layer(from_fn_with_state(
state.clone(),
crates_io_session::attach_session,
))
.layer(from_fn_with_state(
state.clone(),
require_user_agent::require_user_agent,
Expand Down
3 changes: 1 addition & 2 deletions src/tests/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
//! `MockCookieUser` and `MockTokenUser` provide an `as_model` function which returns a reference
//! to the underlying database model value (`User` and `ApiToken` respectively).

use crate::middleware::session;
use crate::models::{ApiToken, CreatedApiToken, User};
use crate::tests::{
CategoryListResponse, CategoryResponse, CrateList, CrateResponse, GoodCrate, OwnerResp,
Expand Down Expand Up @@ -72,7 +71,7 @@ pub fn encode_session_header(session_key: &cookie::Key, user_id: i32) -> String
map.insert("user_id".into(), user_id.to_string());

// encode the map into a cookie value string
let encoded = session::encode(&map);
let encoded = crates_io_session::encode(&map);

// put the cookie into a signed cookie jar
let cookie = Cookie::build((cookie_name, encoded));
Expand Down