Skip to content

Commit

Permalink
update axum version in proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
paulgb committed Oct 7, 2024
1 parent d3e5cbb commit 27aaf91
Show file tree
Hide file tree
Showing 12 changed files with 201 additions and 410 deletions.
485 changes: 139 additions & 346 deletions Cargo.lock

Large diffs are not rendered by default.

17 changes: 9 additions & 8 deletions plane/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ homepage = "https://plane.dev"
readme = "README.md"

[dependencies]
acme2-eab = "0.5.4"
acme2-eab = "0.5.7"
anyhow = "1.0.75"
async-stream = "0.3.5"
async-trait = "0.1.74"
axum = { version = "0.6.20", features = ["ws"] }
axum = { version = "0.7.7", features = ["ws"] }
bollard = "0.17.0"
bytes = "1.7.2"
chrono = { version = "0.4.31", features = ["serde"] }
Expand All @@ -25,12 +25,13 @@ data-encoding = "2.4.0"
dynamic-proxy = { path="../dynamic-proxy" }
futures-util = "0.3.29"
http-body = "0.4.6"
hyper = { version = "0.14.27", features = ["server"] }
hyper = { version = "1.4.1", features = ["server"] }
hyper-util = { version = "0.1.9", features = ["client", "client-legacy", "http1", "http2"] }
lru = "0.12.1"
openssl = "0.10.66"
pem = "3.0.2"
rand = "0.8.5"
reqwest = { version = "0.11.22", features = ["json", "rustls-tls"], default-features = false }
reqwest = { version = "0.12.8", features = ["json", "rustls-tls"], default-features = false }
rusqlite = { version = "0.31.0", features = ["bundled", "serde_json"] }
rustls-pemfile = "2.0.0"
rustls-pki-types = "1.0.0"
Expand All @@ -42,13 +43,13 @@ thiserror = "1.0.50"
time = "0.3.30"
tokio = { version = "1.33.0", features = ["macros", "rt-multi-thread", "signal"] }
tokio-stream = { version="0.1.14", features=["sync"] }
tokio-tungstenite = { version = "0.20.1", features = ["rustls-tls-webpki-roots"] }
tower = "0.4.13"
tower-http = { version = "0.4.4", features = ["trace", "cors"] }
tokio-tungstenite = { version = "0.24.0", features = ["rustls-tls-webpki-roots"] }
tower = "0.5.1"
tower-http = { version = "0.6.1", features = ["trace", "cors"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json", "valuable"] }
trust-dns-server = "0.23.2"
tungstenite = "0.20.1"
tungstenite = "0.24.0"
url = { version="2.4.1", features=["serde"] }
valuable = { version = "0.1.0", features = ["derive"] }
x509-parser = "0.15.1"
6 changes: 3 additions & 3 deletions plane/plane-tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ dynamic-proxy = { path = "../../dynamic-proxy" }
futures-util = "0.3.29"
http = "1.1.0"
http-body-util = "0.1.2"
hyper = { version = "0.14.27", features = ["server"] }
hyper = { version = "1.4.1", features = ["server"] }
plane = { path = "../plane-dynamic", package = "plane-dynamic" }
plane-test-macro = { path = "plane-test-macro" }
reqwest = { version = "0.11.22", features = ["json", "rustls-tls"], default-features = false }
reqwest = { version = "0.12.8", features = ["json", "rustls-tls"], default-features = false }
serde = "1.0.210"
serde_json = "1.0.107"
thiserror = "1.0.50"
tokio = { version = "1.33.0", features = ["macros", "rt-multi-thread", "signal"] }
tokio = { version = "1.33.0", features = ["macros", "net", "rt-multi-thread", "signal"] }
tokio-tungstenite = "0.24.0"
tracing = "0.1.40"
tracing-appender = "0.2.2"
Expand Down
3 changes: 1 addition & 2 deletions plane/plane-tests/tests/common/localhost_resolver.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use hyper::client::connect::dns::Name;
use reqwest::dns::{Resolve, Resolving};
use reqwest::dns::{Name, Resolve, Resolving};
use std::{future::ready, net::SocketAddr, sync::Arc};

/// A reqwest-compatible DNS resolver that resolves all requests to localhost.
Expand Down
4 changes: 2 additions & 2 deletions plane/plane-tests/tests/common/test_env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl TestEnvironment {

pub async fn controller(&mut self) -> ControllerServer {
let db = self.db().await;
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let url: Url = format!("http://{}", listener.local_addr().unwrap())
.parse()
.unwrap();
Expand Down Expand Up @@ -163,7 +163,7 @@ impl TestEnvironment {

pub async fn controller_with_forward_auth(&mut self, forward_auth: &Url) -> ControllerServer {
let db = self.db().await;
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let url: Url = format!("http://{}", listener.local_addr().unwrap())
.parse()
.unwrap();
Expand Down
2 changes: 1 addition & 1 deletion plane/plane-tests/tests/forward_auth.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use common::{auth_mock::MockAuthServer, test_env::TestEnvironment};
use hyper::StatusCode;
use plane::client::{PlaneClient, PlaneClientError};
use plane_test_macro::plane_test;
use reqwest::StatusCode;

mod common;

Expand Down
20 changes: 9 additions & 11 deletions plane/src/client/sse.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use super::PlaneClientError;
use crate::util::ExponentialBackoff;
use reqwest::{
header::{ACCEPT, CONNECTION},
header::{HeaderValue, ACCEPT, CONNECTION},
Client, Response,
};
use serde::de::DeserializeOwned;
use std::marker::PhantomData;
use tungstenite::http::HeaderValue;
use url::Url;

struct RawSseStream {
Expand Down Expand Up @@ -177,15 +176,15 @@ mod tests {
use async_stream::stream;
use axum::{
extract::State,
http::HeaderMap,
response::sse::{Event, KeepAlive, Sse},
routing::get,
Router,
};
use futures_util::stream::Stream;
use reqwest::header::HeaderMap;
use serde::{Deserialize, Serialize};
use std::{convert::Infallible, time::Duration};
use tokio::{sync::broadcast, task::JoinHandle, time::timeout};
use tokio::{net::TcpListener, sync::broadcast, task::JoinHandle, time::timeout};

#[derive(Serialize, Deserialize, Debug)]
struct Count {
Expand Down Expand Up @@ -237,18 +236,17 @@ mod tests {
}

impl DemoSseServer {
fn new() -> Self {
async fn new() -> Self {
let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 0));
let listener = std::net::TcpListener::bind(addr).unwrap();
let listener = TcpListener::bind(addr).await.unwrap();
let port = listener.local_addr().unwrap().port();
let (disconnect_sender, _) = broadcast::channel::<()>(1);

let app = Router::new()
.route("/counter", get(handle_sse))
.with_state(disconnect_sender.clone());
let server = axum::Server::from_tcp(listener)
.unwrap()
.serve(app.into_make_service());

let server = axum::serve(listener, app.into_make_service());
let handle = tokio::spawn(async move { server.await.map_err(anyhow::Error::new) });

Self {
Expand All @@ -270,7 +268,7 @@ mod tests {

#[tokio::test]
async fn test_simple_sse() {
let server = DemoSseServer::new();
let server = DemoSseServer::new().await;

let client = reqwest::Client::new();
let mut stream = super::sse_request::<Count>(server.url(), client)
Expand All @@ -285,7 +283,7 @@ mod tests {

#[tokio::test]
async fn test_sse_reconnect() {
let server = DemoSseServer::new();
let server = DemoSseServer::new().await;

let client = reqwest::Client::new();
let mut stream = super::sse_request::<Count>(server.url(), client)
Expand Down
3 changes: 1 addition & 2 deletions plane/src/controller/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ use super::Controller;
use crate::controller::error::IntoApiError;
use crate::database::connect::ConnectError;
use crate::types::{ConnectRequest, ConnectResponse, RevokeRequest};
use axum::{extract::State, response::Response, Json};
use reqwest::StatusCode;
use axum::{extract::State, http::StatusCode, response::Response, Json};

fn connect_error_to_response(connect_error: &ConnectError) -> Response {
match connect_error {
Expand Down
2 changes: 1 addition & 1 deletion plane/src/controller/error.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::util::random_string;
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use std::{
error::Error,
Expand Down
33 changes: 14 additions & 19 deletions plane/src/controller/forward_auth.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
use axum::{
body::{Body, BoxBody, Bytes},
extract::State,
http::{request, HeaderValue, Request},
body::{Body, Bytes},
extract::{Request, State},
http::Uri,
http::{request, HeaderValue, StatusCode},
middleware::Next,
response::Response,
};
use hyper::{Client, StatusCode, Uri};
use hyper_util::{
client::legacy::{connect::HttpConnector, Client},
rt::TokioExecutor,
};
use url::Url;

pub fn clone_request_with_empty_body(parts: &request::Parts) -> request::Request<Body> {
pub fn clone_request_with_empty_body(parts: &request::Parts) -> hyper::http::Request<Body> {
// Copy method and URL.
let mut builder = request::Builder::new()
let mut builder = hyper::http::request::Builder::new()
.method(parts.method.clone())
.uri(parts.uri.clone());

Expand All @@ -32,11 +36,7 @@ pub fn clone_request_with_empty_body(parts: &request::Parts) -> request::Request
.expect("Request is always valid.")
}

pub async fn forward_layer<B>(
State(forward_url): State<Url>,
req: Request<B>,
next: Next<B>,
) -> Response<BoxBody> {
pub async fn forward_layer(State(forward_url): State<Url>, req: Request, next: Next) -> Response {
let (parts, body) = req.into_parts();
let mut forward_req = clone_request_with_empty_body(&parts);
let req = Request::from_parts(parts, body);
Expand All @@ -48,7 +48,7 @@ pub async fn forward_layer<B>(
*forward_req.uri_mut() = uri;

// Create a client
let client = Client::new();
let client = Client::builder(TokioExecutor::new()).build(HttpConnector::new());

// Forward the request
let forwarded_resp = client.request(forward_req).await;
Expand All @@ -68,13 +68,8 @@ pub async fn forward_layer<B>(
}
}

fn response_helper(status: StatusCode, body: &'static [u8]) -> Response<BoxBody> {
// This is a bit ugly. There seems to be no way to construct an http_body with an axum::Error error type (?),
// but we can use map_err from http_body::Body to convert the hyper::error::Error to an axum::Error.
// Then, we need to box it up for Axum.
let body = http_body::Full::new(Bytes::from_static(body));
let body = http_body::Body::map_err(body, axum::Error::new);
let body: BoxBody = BoxBody::new(body);
fn response_helper(status: StatusCode, body: &'static [u8]) -> Response {
let body = Body::from(Bytes::from_static(body));

Response::builder()
.status(status.as_u16())
Expand Down
33 changes: 20 additions & 13 deletions plane/src/controller/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ use anyhow::{Context, Result};
use axum::{
extract::State,
http::{header, Method},
middleware::from_fn_with_state,
response::Response,
routing::{get, post},
Json, Router, Server,
Json, Router,
};
use forward_auth::forward_layer;
use futures_util::never::Never;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::net::{SocketAddr, TcpListener};
use std::net::SocketAddr;
use tokio::{
net::TcpListener,
sync::oneshot::{self},
task::JoinHandle,
};
Expand Down Expand Up @@ -146,13 +146,13 @@ pub struct ControllerServer {
heartbeat_handle: HeartbeatSender,
// server_handle is wrapped in an Option<> because we need to take ownership of it to join it
// when gracefully terminating.
server_handle: Option<JoinHandle<hyper::Result<()>>>,
server_handle: Option<JoinHandle<Result<(), std::io::Error>>>,
_cleanup_handle: GuardHandle,
}

impl ControllerServer {
pub async fn run(config: ControllerConfig) -> Result<Self> {
let listener = TcpListener::bind(config.bind_addr)?;
let listener = TcpListener::bind(config.bind_addr).await?;

tracing::info!("Attempting to connect to database...");

Expand Down Expand Up @@ -239,7 +239,11 @@ impl ControllerServer {
if let Some(forward_auth_url) = forward_auth {
tracing::info!(?forward_auth_url, "Forward auth enabled");
let forward_url = forward_auth_url.clone();
control_routes = control_routes.layer(from_fn_with_state(forward_url, forward_layer));

control_routes = control_routes.layer(axum::middleware::from_fn_with_state(
forward_url.clone(),
forward_layer,
));
}

let cors_public = CorsLayer::new()
Expand All @@ -265,13 +269,16 @@ impl ControllerServer {
.layer(trace_layer)
.with_state(controller);

let server_handle = tokio::spawn(
Server::from_tcp(listener)?
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.with_graceful_shutdown(async {
graceful_terminate_receiver.await.ok();
}),
);
let server_handle = tokio::spawn(async {
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(async {
graceful_terminate_receiver.await.ok();
})
.await
});

Ok(Self {
graceful_terminate_sender: Some(graceful_terminate_sender),
Expand Down
3 changes: 1 addition & 2 deletions plane/src/proxy/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,9 @@ pub fn get_and_maybe_remove_bearer_token(parts: &mut uri::Parts) -> Option<Beare

#[cfg(test)]
mod tests {
use uri::Uri;

use super::*;
use std::str::FromStr;
use uri::Uri;

#[test]
fn no_subdomains() {
Expand Down

0 comments on commit 27aaf91

Please sign in to comment.