Skip to content

Commit

Permalink
Merge pull request #7670 from Turbo87/axum-0.7
Browse files Browse the repository at this point in the history
Upgrade to axum 0.7
  • Loading branch information
Turbo87 authored Dec 7, 2023
2 parents e399f06 + 2da3d3e commit 9e17c20
Show file tree
Hide file tree
Showing 21 changed files with 466 additions and 236 deletions.
307 changes: 218 additions & 89 deletions Cargo.lock

Large diffs are not rendered by default.

18 changes: 10 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ async-trait = "=0.1.74"
aws-credential-types = { version = "=1.0.3", features = ["hardcoded-credentials"] }
aws-ip-ranges = "=0.22.0"
aws-sdk-cloudfront = "=1.4.0"
axum = { version = "=0.6.20", features = ["headers", "macros", "matched-path"] }
axum-extra = { version = "=0.8.0", features = ["cookie-signed"] }
axum = { version = "=0.7.2", features = ["macros", "matched-path"] }
axum-extra = { version = "=0.9.0", features = ["cookie-signed", "typed-header"] }
base64 = "=0.21.5"
bigdecimal = "=0.4.2"
cargo-manifest = "=0.12.1"
Expand All @@ -62,7 +62,7 @@ crates_io_tarball = { path = "crates_io_tarball" }
crates_io_worker = { path = "crates_io_worker" }
chrono = { version = "=0.4.31", default-features = false, features = ["serde"] }
clap = { version = "=4.4.11", features = ["derive", "env", "unicode", "wrap_help"] }
cookie = { version = "=0.17.0", features = ["secure"] }
cookie = { version = "=0.18.0", features = ["secure"] }
crossbeam-channel = "=0.5.8"
dashmap = { version = "=5.5.3", features = ["raw-api"] }
derive_builder = "=0.12.0"
Expand All @@ -77,9 +77,11 @@ futures-channel = { version = "=0.3.29", default-features = false }
futures-util = "=0.3.29"
github-meta = "=0.11.0"
hex = "=0.4.3"
http = "=0.2.11"
http-body = "=0.4.5"
hyper = { version = "=0.14.27", features = ["backports", "client", "deprecated", "http1"] }
http = "=1.0.0"
http-body = "=1.0.0"
http-body-util = "=0.1.0"
hyper = { version = "=1.0.1", features = ["client", "http1"] }
hyper-util = { version = "=0.1.1", features = ["tokio", "server-auto", "http1"] }
indexmap = { version = "=2.1.0", features = ["serde"] }
indicatif = "=0.17.7"
ipnetwork = "=0.20.0"
Expand All @@ -98,7 +100,7 @@ reqwest = { version = "=0.11.22", features = ["gzip", "json"] }
scheduled-thread-pool = "=0.2.7"
secrecy = "=0.8.0"
semver = { version = "=1.0.20", features = ["serde"] }
sentry = { version = "=0.31.8", features = ["tracing", "tower", "tower-axum-matched-path", "tower-http"] }
sentry = { version = "=0.32.0", features = ["tracing", "tower", "tower-axum-matched-path", "tower-http"] }
serde = { version = "=1.0.193", features = ["derive"] }
serde_json = "=1.0.108"
sha2 = "=0.10.8"
Expand All @@ -109,7 +111,7 @@ thiserror = "=1.0.50"
tokio = { version = "=1.34.0", features = ["net", "signal", "io-std", "io-util", "rt-multi-thread", "macros"]}
toml = "=0.8.8"
tower = "=0.4.13"
tower-http = { version = "=0.4.4", features = ["add-extension", "fs", "catch-panic", "timeout", "compression-full"] }
tower-http = { version = "=0.5.0", features = ["add-extension", "fs", "catch-panic", "timeout", "compression-full"] }
tracing = "=0.1.40"
tracing-subscriber = { version = "=0.3.18", features = ["env-filter"] }
typomania = { version = "=0.1.2", default-features = false }
Expand Down
166 changes: 139 additions & 27 deletions src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@ use crates_io::middleware::normalize_path::normalize_path;
use crates_io::{metrics::LogEncoder, util::errors::AppResult, App, Emails};
use std::{sync::Arc, time::Duration};

use axum::extract::Request;
use axum::ServiceExt;
use crates_io::github::RealGitHubClient;
use futures_util::future::FutureExt;
use hyper::body::Incoming;
use hyper_util::rt::TokioIo;
use prometheus::Encoder;
use reqwest::Client;
use std::io::{self, Write};
use std::io::Write;
use std::net::SocketAddr;
use tokio::net::TcpListener;
use tokio::signal::unix::{signal, SignalKind};
use tower::Layer;
use tokio::sync::watch;
use tower::{Layer, Service};

const CORE_THREADS: usize = 4;

Expand Down Expand Up @@ -56,37 +60,124 @@ fn main() -> anyhow::Result<()> {

let rt = builder.build().unwrap();

let make_service = axum_router.into_make_service_with_connect_info::<SocketAddr>();
let mut make_service = axum_router.into_make_service_with_connect_info::<SocketAddr>();

let (addr, server) = rt.block_on(async {
let socket_addr = (app.config.ip, app.config.port).into();
let server = hyper::Server::bind(&socket_addr).serve(make_service);
// to understand the following implementation,
// see https://github.com/tokio-rs/axum/blob/axum-v0.7.2/examples/graceful-shutdown/src/main.rs
// and https://github.com/tokio-rs/axum/blob/axum-v0.7.2/examples/serve-with-hyper/src/main.rs

// When the user configures PORT=0 the operating system will allocate a random unused port.
// This fetches that random port and uses it to display the correct url later.
let addr = server.local_addr();

let mut sig_int = signal(SignalKind::interrupt())?;
let mut sig_term = signal(SignalKind::terminate())?;
let server = server.with_graceful_shutdown(async move {
// Wait for either signal
tokio::select! {
_ = sig_int.recv().fuse() => {},
_ = sig_term.recv().fuse() => {},
// Block the main thread until the server has shutdown
rt.block_on(async {
// Create a `TcpListener` using tokio.
let listener = TcpListener::bind((app.config.ip, app.config.port)).await?;

let addr = listener.local_addr()?;

// Do not change this line! Removing the line or changing its contents in any way will break
// the test suite :)
info!("Listening at http://{addr}");

// Create a watch channel to track tasks that are handling connections and wait for them to
// complete.
let (close_tx, close_rx) = watch::channel(());

// Continuously accept new connections.
loop {
let (socket, remote_addr) = tokio::select! {
// Either accept a new connection...
result = listener.accept() => {
result.unwrap()
}
// ...or wait to receive a shutdown signal and stop the accept loop.
_ = shutdown_signal() => {
debug!("shutdown signal received, not accepting new connections");
break;
}
};

info!("Starting graceful shutdown");
});
debug!("connection {remote_addr} accepted");

// We don't need to call `poll_ready` because `IntoMakeServiceWithConnectInfo` is always
// ready.
let tower_service = make_service.call(remote_addr).await.unwrap();

// Clone the watch receiver and move it into the task.
let close_rx = close_rx.clone();

// Spawn a task to handle the connection. That way we can serve multiple connections
// concurrently.
tokio::spawn(async move {
// Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio.
// `TokioIo` converts between them.
let socket = TokioIo::new(socket);

// Hyper also has its own `Service` trait and doesn't use tower. We can use
// `hyper::service::service_fn` to create a hyper `Service` that calls our app through
// `tower::Service::call`.
let hyper_service =
hyper::service::service_fn(move |request: Request<Incoming>| {
// We have to clone `tower_service` because hyper's `Service` uses `&self` whereas
// tower's `Service` requires `&mut self`.
//
// We don't need to call `poll_ready` since `Router` is always ready.
tower_service.clone().call(request.map(axum::body::Body::new))
});

// `hyper_util::server::conn::auto::Builder` supports both http1 and http2 but doesn't
// support graceful so we have to use hyper directly and unfortunately pick between
// http1 and http2.
let conn = hyper::server::conn::http1::Builder::new()
.serve_connection(socket, hyper_service)
// `with_upgrades` is required for websockets.
.with_upgrades();

// `graceful_shutdown` requires a pinned connection.
let mut conn = std::pin::pin!(conn);

loop {
tokio::select! {
// Poll the connection. This completes when the client has closed the
// connection, graceful shutdown has completed, or we encounter a TCP error.
result = conn.as_mut() => {
if let Err(err) = result {
debug!("failed to serve connection: {err:#}");
}
break;
}
// Start graceful shutdown when we receive a shutdown signal.
//
// We use a loop to continue polling the connection to allow requests to finish
// after starting graceful shutdown. Our `Router` has `TimeoutLayer` so
// requests will finish after at most 30 seconds.
_ = shutdown_signal() => {
debug!("shutdown signal received, starting graceful connection shutdown");
conn.as_mut().graceful_shutdown();
}
}
}

debug!("connection {remote_addr} closed");

// Drop the watch receiver to signal to `main` that this task is done.
drop(close_rx);
});
}

Ok::<_, io::Error>((addr, server))
})?;
info!("Starting graceful shutdown");

// Do not change this line! Removing the line or changing its contents in any way will break
// the test suite :)
info!("Listening at http://{addr}");
// We only care about the watch receivers that were moved into the tasks so close the residual
// receiver.
drop(close_rx);

// Block the main thread until the server has shutdown
rt.block_on(server)?;
// Close the listener to stop accepting new connections.
drop(listener);

// Wait for all tasks to complete.
debug!("waiting for {} tasks to finish", close_tx.receiver_count());
close_tx.closed().await;

Ok::<(), anyhow::Error>(())
})?;

info!("Persisting remaining downloads counters");
match app.downloads_counter.persist_all_shards(&app) {
Expand All @@ -98,6 +189,27 @@ fn main() -> anyhow::Result<()> {
Ok(())
}

async fn shutdown_signal() {
let interrupt = async {
signal(SignalKind::interrupt())
.expect("failed to install signal handler")
.recv()
.await;
};

let terminate = async {
signal(SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};

tokio::select! {
_ = interrupt => {},
_ = terminate => {},
}
}

fn downloads_counter_thread(app: Arc<App>) {
let interval =
app.config.downloads_persist_interval / app.downloads_counter.shards_count() as u32;
Expand Down
11 changes: 5 additions & 6 deletions src/controllers/git.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
use axum::body::HttpBody;
use axum::extract::{ConnectInfo, Path};
use axum::extract::{ConnectInfo, Path, Request};
use axum::response::{IntoResponse, Response};
use http::header::HeaderName;
use http::request::Parts;
use http::{header, HeaderMap, HeaderValue, Request, StatusCode};
use http::{header, HeaderMap, HeaderValue, StatusCode};
use hyper::body::Buf;
use std::io::{BufRead, Read};
use std::net::SocketAddr;
use std::process::{Command, Stdio};

pub async fn http_backend<B: HttpBody>(
pub async fn http_backend(
Path(path): Path<String>,
ConnectInfo(remote_addr): ConnectInfo<SocketAddr>,
req: Request<B>,
req: Request,
) -> Result<Response, StatusCode> {
let path = format!("/{path}");

let (req, body) = req.into_parts();
let body = hyper::body::to_bytes(body)
let body = axum::body::to_bytes(body, usize::MAX)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

Expand Down
2 changes: 1 addition & 1 deletion src/headers.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use axum::headers::{Error, Header};
use axum_extra::headers::{Error, Header};
use http::header::{HeaderName, HeaderValue};

static X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
Expand Down
5 changes: 2 additions & 3 deletions src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,17 @@ use axum::middleware::{from_fn, from_fn_with_state};
use axum::Router;
use axum_extra::either::Either;
use axum_extra::middleware::option_layer;
use hyper::Body;
use std::time::Duration;
use tower::layer::util::Identity;
use tower_http::add_extension::AddExtensionLayer;
use tower_http::catch_panic::CatchPanicLayer;
use tower_http::compression::{CompressionLayer, CompressionLevel};
use tower_http::timeout::{RequestBodyTimeoutLayer, TimeoutBody, TimeoutLayer};
use tower_http::timeout::{RequestBodyTimeoutLayer, TimeoutLayer};

use crate::app::AppState;
use crate::Env;

pub fn apply_axum_middleware(state: AppState, router: Router<(), TimeoutBody<Body>>) -> Router {
pub fn apply_axum_middleware(state: AppState, router: Router<()>) -> Router {
let config = &state.config;
let env = config.env();

Expand Down
15 changes: 6 additions & 9 deletions src/middleware/balance_capacity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,19 @@ use std::sync::atomic::{AtomicUsize, Ordering};

use crate::app::AppState;
use crate::middleware::log_request::RequestLogExt;
use axum::extract::Request;
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use http::{Request, StatusCode};
use http::StatusCode;

/// Handle a request when load exceeds a threshold
///
/// In report-only mode, log metadata is added but the request is still served. Otherwise,
/// the request is rejected with a service unavailable response.
async fn handle_high_load<B>(
async fn handle_high_load(
app_state: &AppState,
request: Request<B>,
next: Next<B>,
request: Request,
next: Next,
note: &str,
) -> Response {
let config = &app_state.config.balance_capacity;
Expand All @@ -40,11 +41,7 @@ async fn handle_high_load<B>(
}
}

pub async fn balance_capacity<B>(
app_state: AppState,
request: Request<B>,
next: Next<B>,
) -> Response {
pub async fn balance_capacity(app_state: AppState, request: Request, next: Next) -> Response {
let config = &app_state.config.balance_capacity;
let db_capacity = app_state.config.db.primary.pool_size;
let state = &app_state.balance_capacity;
Expand Down
10 changes: 5 additions & 5 deletions src/middleware/block_traffic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@ use crate::app::AppState;
use crate::middleware::log_request::RequestLogExt;
use crate::middleware::real_ip::RealIp;
use crate::util::errors::RouteBlocked;
use axum::extract::{Extension, MatchedPath};
use axum::extract::{Extension, MatchedPath, Request};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use http::{HeaderMap, StatusCode};

pub async fn middleware<B>(
pub async fn middleware(
Extension(real_ip): Extension<RealIp>,
matched_path: Option<MatchedPath>,
state: AppState,
req: http::Request<B>,
next: Next<B>,
req: Request,
next: Next,
) -> Result<impl IntoResponse, Response> {
block_by_ip(&real_ip, &state, req.headers())?;
block_by_header(&state, &req)?;
Expand All @@ -29,7 +29,7 @@ pub async fn middleware<B>(
/// to `User-Agent=BLOCKED_UAS` and `BLOCKED_UAS` to `curl/7.54.0,cargo 1.36.0 (c4fcfb725 2019-05-15)`
/// to block requests from the versions of curl or Cargo specified (values are nonsensical examples).
/// Values of the headers must match exactly.
pub fn block_by_header<B>(state: &AppState, req: &http::Request<B>) -> Result<(), Response> {
pub fn block_by_header(state: &AppState, req: &Request) -> Result<(), Response> {
let blocked_traffic = &state.config.blocked_traffic;

for (header_name, blocked_values) in blocked_traffic {
Expand Down
Loading

0 comments on commit 9e17c20

Please sign in to comment.