Skip to content

Commit

Permalink
add unix socket support
Browse files Browse the repository at this point in the history
  • Loading branch information
blind-oracle committed Aug 22, 2024
1 parent 9a7327d commit 12a258e
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 42 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.1.0"
edition = "2021"

[dependencies]
anyhow = "1.0"
async-trait = "0.1"
axum = "0.7"
base64 = "0.22"
Expand Down
3 changes: 2 additions & 1 deletion src/http/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{fmt, sync::Arc, time::Duration};

use anyhow::Context;
use async_trait::async_trait;
use http::header::HeaderValue;
use reqwest::dns::Resolve;
Expand Down Expand Up @@ -39,7 +40,7 @@ pub fn new(opts: Options, dns_resolver: impl Resolve + 'static) -> Result<reqwes
.redirect(reqwest::redirect::Policy::none())
.no_proxy()
.build()
.map_err(|e| Error::Generic(format!("unable to create Reqwest client: {e:#}")))?;
.context("unable to create reqwest client")?;

Ok(client)
}
Expand Down
6 changes: 3 additions & 3 deletions src/http/dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::{
sync::Arc,
};

use anyhow::Context;
use async_trait::async_trait;
use hickory_proto::rr::RecordType;
use hickory_resolver::{
Expand Down Expand Up @@ -96,14 +97,13 @@ impl Resolve for Resolver {
#[async_trait]
impl Resolves for Resolver {
async fn resolve(&self, name: &str, record: &str) -> Result<Vec<(String, String)>, Error> {
let record_type = RecordType::from_str(record)
.map_err(|e| Error::Generic(format!("unable to parse record: {e:#}")))?;
let record_type = RecordType::from_str(record).context("unable to parse record")?;

let lookup = self
.0
.lookup(name, record_type)
.await
.map_err(|e| Error::Generic(format!("lookup failed: {e:#}")))?;
.context("lookup failed")?;

let rr = lookup
.into_iter()
Expand Down
5 changes: 3 additions & 2 deletions src/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ pub const ALPN_ACME: &[u8] = b"acme-tls/1";
/// TODO improve
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("{0}")]
Generic(String),
//#[error("{0}")]
#[error(transparent)]
Generic(#[from] anyhow::Error),
}

// Calculate very approximate HTTP request/response headers size in bytes.
Expand Down
171 changes: 135 additions & 36 deletions src/http/server.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use std::{
fmt::Display,
net::SocketAddr,
io,
net::{IpAddr, Ipv4Addr, SocketAddr},
path::PathBuf,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::{Duration, Instant},
};

use anyhow::{anyhow, Context};
use axum::{extract::Request, Router};
use hyper::body::Incoming;
use hyper_util::{
Expand All @@ -21,7 +24,7 @@ use prometheus::{
use rustls::{server::ServerConnection, CipherSuite, ProtocolVersion};
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
net::{TcpListener, TcpSocket, TcpStream},
net::{TcpListener, TcpSocket, UnixListener, UnixSocket},
select,
time::sleep,
};
Expand Down Expand Up @@ -140,11 +143,11 @@ impl TryFrom<&ServerConnection> for TlsInfo {
.map(|x| String::from_utf8_lossy(x).to_string()),
protocol: c
.protocol_version()
.ok_or_else(|| Error::Generic("No TLS protocol found".into()))?,
.ok_or_else(|| anyhow!("No TLS protocol found"))?,
cipher: c
.negotiated_cipher_suite()
.map(|x| x.suite())
.ok_or_else(|| Error::Generic("No TLS ciphersuite found".into()))?,
.ok_or_else(|| anyhow!("No TLS ciphersuite found"))?,
})
}
}
Expand All @@ -153,7 +156,7 @@ impl TryFrom<&ServerConnection> for TlsInfo {
pub struct ConnInfo {
pub id: Uuid,
pub accepted_at: Instant,
pub remote_addr: SocketAddr,
pub remote_addr: RemoteAddr,
pub traffic: Arc<Stats>,
pub req_count: AtomicU64,
pub close: CancellationToken,
Expand All @@ -169,9 +172,94 @@ impl ConnInfo {
}
}

pub enum Listener {
Tcp(TcpListener),
Unix(UnixListener),
}

impl Listener {
async fn accept(&self) -> Result<(Box<dyn AsyncReadWrite>, RemoteAddr), io::Error> {
Ok(match self {
Self::Tcp(v) => {
let x = v.accept().await?;
// Disable Nagle's algo
x.0.set_nodelay(true)?;
(Box::new(x.0), RemoteAddr::Tcp(x.1))
}
Self::Unix(v) => {
let x = v.accept().await?;
(
Box::new(x.0),
RemoteAddr::Unix(
x.1.as_pathname()
.map(|x| x.to_string_lossy().to_string())
.unwrap_or_default(),
),
)
}
})
}
}

#[derive(Debug, Clone)]
pub enum LocalAddr {
Tcp(SocketAddr),
Unix(PathBuf),
}

impl Display for LocalAddr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
Self::Tcp(v) => v.to_string(),
Self::Unix(v) => v.to_string_lossy().to_string(),
}
)
}
}

#[derive(Debug, Clone)]
pub enum RemoteAddr {
Tcp(SocketAddr),
Unix(String),
}

impl RemoteAddr {
pub fn family(&self) -> &str {
match self {
Self::Tcp(v) => {
if v.is_ipv4() {
"v4"
} else {
"v6"
}
}
Self::Unix(_) => "unix",
}
}

pub fn ip(&self) -> IpAddr {
match self {
Self::Tcp(v) => v.ip(),
Self::Unix(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
}
}
}

impl Display for RemoteAddr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Tcp(v) => write!(f, "{v}"),
Self::Unix(v) => write!(f, "{v}"),
}
}
}

struct Conn {
addr: SocketAddr,
remote_addr: SocketAddr,
addr: LocalAddr,
remote_addr: RemoteAddr,
router: Router,
builder: Builder<TokioExecutor>,
token: CancellationToken,
Expand Down Expand Up @@ -202,7 +290,7 @@ impl Conn {
.unwrap()
.accept(stream)
.await
.map_err(|e| Error::Generic(format!("unable to accept TLS: {e:#}")))?;
.context("unable to accept TLS")?;
let duration = start.elapsed();

let conn = stream.get_ref().1;
Expand All @@ -222,7 +310,7 @@ impl Conn {
Ok((stream, tls_info))
}

async fn handle(&self, stream: TcpStream) -> Result<(), Error> {
async fn handle(&self, stream: Box<dyn AsyncReadWrite>) -> Result<(), Error> {
let accepted_at = Instant::now();

debug!("{}: got a new connection", self);
Expand All @@ -236,11 +324,7 @@ impl Conn {
} else {
"no"
}, // Is TLS
if self.remote_addr.is_ipv4() {
"v4"
} else {
"v6"
}, // IP Family
self.remote_addr.family(),
"no",
];

Expand All @@ -249,18 +333,13 @@ impl Conn {
.with_label_values(&labels[0..3])
.inc();

// Disable Nagle's algo
stream
.set_nodelay(true)
.map_err(|_| Error::Generic("unable to set TCP_NODELAY".into()))?;

// Wrap with traffic counter
let (stream, stats) = AsyncCounter::new(stream);

let conn_info = Arc::new(ConnInfo {
id: Uuid::now_v7(),
accepted_at,
remote_addr: self.remote_addr,
remote_addr: self.remote_addr.clone(),
traffic: stats.clone(),
req_count: AtomicU64::new(0),
close: self.token_close.clone(),
Expand Down Expand Up @@ -326,7 +405,7 @@ impl Conn {
stream
.shutdown()
.await
.map_err(|e| Error::Generic(format!("unable to shutdown stream: {e:#}")))?;
.context("unable to shutdown stream")?;

return Ok(());
}
Expand Down Expand Up @@ -383,7 +462,7 @@ impl Conn {

v = conn.as_mut() => {
if let Err(e) = v {
return Err(Error::Generic(format!("Unable to serve connection: {e:#}")));
return Err(anyhow!("unable to serve connection: {e:#}").into());
}
},
}
Expand All @@ -394,7 +473,7 @@ impl Conn {

// Listens for new connections on addr with an optional TLS and serves provided Router
pub struct Server {
addr: SocketAddr,
addr: LocalAddr,
router: Router,
tracker: TaskTracker,
options: Options,
Expand All @@ -404,7 +483,7 @@ pub struct Server {

impl Server {
pub fn new(
addr: SocketAddr,
addr: LocalAddr,
router: Router,
options: Options,
metrics: Metrics,
Expand All @@ -420,8 +499,17 @@ impl Server {
}
}

fn listen(&self) -> Result<Listener, Error> {
Ok(match &self.addr {
LocalAddr::Tcp(v) => Listener::Tcp(listen_tcp_backlog(*v, self.options.backlog)?),
LocalAddr::Unix(v) => {
Listener::Unix(listen_unix_backlog(v.clone(), self.options.backlog)?)
}
})
}

pub async fn serve(&self, token: CancellationToken) -> Result<(), Error> {
let listener = listen_tcp_backlog(self.addr, self.options.backlog)?;
let listener = self.listen()?;

// Prepare Hyper connection builder
// It automatically figures out whether to do HTTP1 or HTTP2
Expand Down Expand Up @@ -483,8 +571,8 @@ impl Server {
// Router & TlsAcceptor are both Arc<> inside so it's cheap to clone
// Builder is a bit more complex, but cloning is better than to create it again
let conn = Conn {
addr: self.addr,
remote_addr,
addr: self.addr.clone(),
remote_addr: remote_addr.clone(),
router: self.router.clone(),
builder: builder.clone(),
token: token.child_token(),
Expand All @@ -511,22 +599,33 @@ impl Server {
}
}

// Creates a listener with a backlog set
// Creates a TCP listener with a backlog set
pub fn listen_tcp_backlog(addr: SocketAddr, backlog: u32) -> Result<TcpListener, Error> {
let socket = match addr {
SocketAddr::V4(_) => TcpSocket::new_v4(),
SocketAddr::V6(_) => TcpSocket::new_v6(),
}
.map_err(|e| Error::Generic(format!("unable to open socket: {e:#}")))?;
.context("unable to open socket")?;

socket
.set_reuseaddr(true)
.map_err(|e| Error::Generic(format!("unable to set SO_REUSEADDR: {e:#}")))?;
socket
.bind(addr)
.map_err(|e| Error::Generic(format!("unable to bind socket: {e:#}")))?;
.context("unable to set SO_REUSEADDR")?;
socket.bind(addr).context("unable to bind socket")?;

socket
.listen(backlog)
.map_err(|e| Error::Generic(format!("unable to listen socket: {e:#}")))
let socket = socket.listen(backlog).context("unable to listen socket")?;
Ok(socket)
}

// Creates a Unix Socket listener with a backlog set
pub fn listen_unix_backlog(path: PathBuf, backlog: u32) -> Result<UnixListener, Error> {
let socket = UnixSocket::new_stream().context("unable to open UNIX socket")?;

if path.exists() {
std::fs::remove_file(&path).context("unable to remove UNIX socket")?;
}

socket.bind(path).context("unable to bind socket")?;

let socket = socket.listen(backlog).context("unable to listen socket")?;
Ok(socket)
}

0 comments on commit 12a258e

Please sign in to comment.