From 3f0de937d0c97d31f5efbb12649ae666d94a2eb6 Mon Sep 17 00:00:00 2001 From: Tomasz Kramkowski Date: Mon, 3 Jun 2024 19:24:25 +0100 Subject: [PATCH 1/7] use a unique ID per client rather than port number The current design has some issues as source port numbers: * might be non-unique even on the same machine * can end up re-used by disparate clients connecting consecutively * don't exist for unix domain sockets Since the port numbers seem to be getting used to just uniquely identify a client, it makes more sense to allocate a unique ID per connection and use that instead. --- src/client.rs | 24 ++++++++++++------------ src/ext.rs | 2 +- src/instance.rs | 32 ++++++++++++++++---------------- src/lsp/ext.rs | 22 +++++++++++----------- src/server.rs | 11 +++++++---- 5 files changed, 47 insertions(+), 44 deletions(-) diff --git a/src/client.rs b/src/client.rs index 3b5d1e2..5d1a7f8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -25,7 +25,7 @@ use crate::lsp::InitializeParams; /// Read first client message and dispatch lsp mux commands pub async fn process( socket: TcpStream, - port: u16, + client_id: usize, instance_map: Arc>, ) -> Result<()> { let (socket_read, socket_write) = socket.into_split(); @@ -70,7 +70,7 @@ pub async fn process( cwd, } => { connect( - port, + client_id, instance_map, (server, args, env, cwd), req, @@ -87,18 +87,18 @@ pub async fn process( #[derive(Clone)] pub struct Client { - port: u16, + id: usize, sender: mpsc::Sender, } impl Client { - fn new(port: u16) -> (Client, mpsc::Receiver) { + fn new(id: usize) -> (Client, mpsc::Receiver) { let (sender, receiver) = mpsc::channel(16); - (Client { port, sender }, receiver) + (Client { id, sender }, receiver) } - pub fn port(&self) -> u16 { - self.port + pub fn id(&self) -> usize { + self.id } /// Send a message to the client channel @@ -168,7 +168,7 @@ async fn reload( /// Find or spawn a language server instance and connect the client to it async fn connect( - port: u16, + client_id: usize, instance_map: Arc>, (server, args, env, cwd): ( String, @@ -224,7 +224,7 @@ async fn connect( } info!("initialized client"); - let (client, client_rx) = Client::new(port); + let (client, client_rx) = Client::new(client_id); task::spawn(input_task(client_rx, writer).in_current_span()); instance.add_client(client.clone()).await; @@ -346,7 +346,7 @@ async fn output_task( } Message::Request(mut req) => { - req.id = req.id.tag(Tag::Port(client.port)); + req.id = req.id.tag(Tag::ClientId(client.id)); if instance.send_message(req.into()).await.is_err() { break; } @@ -372,13 +372,13 @@ async fn output_task( } Message::Notification(notif) if notif.method == "textDocument/didOpen" => { - if let Err(err) = instance.open_file(client.port, notif.params).await { + if let Err(err) = instance.open_file(client.id, notif.params).await { warn!(?err, "error opening file"); } } Message::Notification(notif) if notif.method == "textDocument/didClose" => { - if let Err(err) = instance.close_file(client.port, notif.params).await { + if let Err(err) = instance.close_file(client.id, notif.params).await { warn!(?err, "error closing file"); } } diff --git a/src/ext.rs b/src/ext.rs index 330de3f..203fc8e 100644 --- a/src/ext.rs +++ b/src/ext.rs @@ -102,7 +102,7 @@ pub async fn status(config: &Config, json: bool) -> Result<()> { println!(" clients:"); for client in instance.clients { println!(" - Client"); - println!(" port: {}", client.port); + println!(" id: {}", client.id); println!(" files:"); for file in client.files { println!(" - {}", file); diff --git a/src/instance.rs b/src/instance.rs index 24288dd..e586fbd 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -49,7 +49,7 @@ pub struct Instance { server: mpsc::Sender, /// Data of associated clients - clients: Mutex>, + clients: Mutex>, /// Dynamic capabilities registered by the server dynamic_capabilities: Mutex>, @@ -83,7 +83,7 @@ struct ClientData { impl ClientData { fn get_status(&self) -> ext::Client { ext::Client { - port: self.client.port(), + id: self.client.id(), files: self.files.iter().cloned().collect(), } } @@ -146,8 +146,8 @@ impl Instance { client, files: HashSet::new(), }; - if clients.insert(client.port(), client).is_some() { - unreachable!("BUG: added two clients with the same port"); + if clients.insert(client.id(), client).is_some() { + unreachable!("BUG: added two clients with the same ID"); } } @@ -157,7 +157,7 @@ impl Instance { let mut clients = self.clients.lock().await; - let Some(client) = clients.remove(&client.port()) else { + let Some(client) = clients.remove(&client.id()) else { // TODO This happens for example when the language server died while // client was still connected, and the client cleanup is attempted // with the instance being gone already. We should try notifying @@ -205,7 +205,7 @@ impl Instance { } /// Handle `textDocument/didOpen` client notification - pub async fn open_file(&self, port: u16, params: Value) -> Result<()> { + pub async fn open_file(&self, client_id: usize, params: Value) -> Result<()> { let params = serde_json::from_value::(params) .context("parsing params")?; let uri = ¶ms.text_document.uri; @@ -222,7 +222,7 @@ impl Instance { } clients - .get_mut(&port) + .get_mut(&client_id) .expect("no matching client") .files .insert(uri.clone()); @@ -241,14 +241,14 @@ impl Instance { } /// Handle `textDocument/didClose` client notification - pub async fn close_file(&self, port: u16, params: Value) -> Result<()> { + pub async fn close_file(&self, client_id: usize, params: Value) -> Result<()> { let params = serde_json::from_value::(params) .context("parsing params")?; let mut clients = self.clients.lock().await; clients - .get_mut(&port) + .get_mut(&client_id) .context("no matching client")? .files .remove(¶ms.text_document.uri); @@ -261,7 +261,7 @@ impl Instance { /// definitely closed files async fn close_all_files( &self, - clients: &HashMap, + clients: &HashMap, files: Vec, ) -> Result<()> { for uri in files { @@ -643,12 +643,12 @@ async fn stdout_task(instance: Arc, mut reader: LspReader { + (Some(Tag::ClientId(client_id)), id) => { res.id = id; - if let Some(client) = clients.get(&port) { + if let Some(client) = clients.get(&client_id) { let _ = client.send_message(res.into()).await; } else { - debug!(?port, "no matching client"); + debug!(?client_id, "no matching client"); } } (Some(Tag::Drop), _) => { @@ -664,13 +664,13 @@ async fn stdout_task(instance: Arc, mut reader: LspReader { + (Some(Tag::ClientId(client_id)), id) => { warn!(?res, "server responded with error"); res.id = id; - if let Some(client) = clients.get(&port) { + if let Some(client) = clients.get(&client_id) { let _ = client.send_message(res.into()).await; } else { - debug!(?port, "no matching client"); + debug!(?client_id, "no matching client"); } } (Some(Tag::Drop), _) => { diff --git a/src/lsp/ext.rs b/src/lsp/ext.rs index fd05bc6..c896af3 100644 --- a/src/lsp/ext.rs +++ b/src/lsp/ext.rs @@ -11,8 +11,8 @@ use super::jsonrpc::RequestId; /// Additional metadata inserted into LSP RequestId pub enum Tag { - /// Request is coming from a client connected on this port - Port(u16), + /// Request is coming from a client connected with this ID + ClientId(usize), /// Response to this request should be ignored Drop, /// Response to this request should be forwarded @@ -23,7 +23,7 @@ impl RequestId { /// Serializes the ID to a string and prepends Tag pub fn tag(&self, tag: Tag) -> RequestId { let tag = match tag { - Tag::Port(port) => format!("port:{port}"), + Tag::ClientId(client_id) => format!("client_id:{client_id}"), Tag::Drop => "drop".into(), Tag::Forward => "forward".into(), }; @@ -45,10 +45,10 @@ impl RequestId { }) } - fn parse_port(input: &str) -> Result<(u16, &str)> { - let (port, rest) = input.split_once(':').context("missing`:`")?; - let port = u16::from_str(port).context("invalid port number")?; - Ok((port, rest)) + fn parse_client_id(input: &str) -> Result<(usize, &str)> { + let (client_id, rest) = input.split_once(':').context("missing`:`")?; + let client_id = usize::from_str(client_id).context("invalid client ID")?; + Ok((client_id, rest)) } fn parse_tag(input: &RequestId) -> Result<(Tag, RequestId)> { @@ -56,10 +56,10 @@ impl RequestId { bail!("tagged id must be a String found `{input:?}`"); }; - if let Some(rest) = input.strip_prefix("port:") { - let (port, rest) = parse_port(rest)?; + if let Some(rest) = input.strip_prefix("client_id:") { + let (client_id, rest) = parse_client_id(rest)?; let inner_id = parse_inner_id(rest).context("failed to parse inner ID")?; - return Ok((Tag::Port(port), inner_id)); + return Ok((Tag::ClientId(client_id), inner_id)); } if let Some(rest) = input.strip_prefix("drop:") { @@ -175,7 +175,7 @@ pub struct Instance { #[derive(Serialize, Deserialize, Clone, Debug)] #[serde(rename_all = "camelCase")] pub struct Client { - pub port: u16, + pub id: usize, pub files: Vec, } diff --git a/src/server.rs b/src/server.rs index 4a037b3..e69e34b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,3 +1,5 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; + use anyhow::{Context, Result}; use tokio::net::TcpListener; use tokio::task; @@ -9,23 +11,24 @@ use crate::instance::InstanceMap; pub async fn run(config: &Config) -> Result<()> { let instance_map = InstanceMap::new(config).await; + let next_client_id = AtomicUsize::new(0); let listener = TcpListener::bind(config.listen).await.context("listen")?; loop { match listener.accept().await { - Ok((socket, addr)) => { - let port = addr.port(); + Ok((socket, _addr)) => { + let client_id = next_client_id.fetch_add(1, Ordering::Relaxed); let instance_map = instance_map.clone(); task::spawn( async move { info!("client connected"); - match client::process(socket, port, instance_map).await { + match client::process(socket, client_id, instance_map).await { Ok(_) => {} Err(err) => error!("client error: {err:?}"), } } - .instrument(info_span!("client", %port)), + .instrument(info_span!("client", %client_id)), ); } Err(err) => match err.kind() { From b3d0b3cd763306601f95316fdf8a4f745ef33376 Mon Sep 17 00:00:00 2001 From: Tomasz Kramkowski Date: Mon, 3 Jun 2024 19:31:18 +0100 Subject: [PATCH 2/7] implement wrappers around tcp/unix,listen/stream This allows the wrappers to be used in all circumstances where TcpListen, TcpStream, tcp::OwnedReadHalf, and tcp::OwnedWriteHalf are currently used. --- Cargo.lock | 1 + Cargo.toml | 1 + src/client.rs | 5 +- src/ext.rs | 4 +- src/lib.rs | 1 + src/proxy.rs | 4 +- src/server.rs | 6 +- src/socketwrapper.rs | 263 +++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 276 insertions(+), 9 deletions(-) create mode 100644 src/socketwrapper.rs diff --git a/Cargo.lock b/Cargo.lock index c028bb4..ad70898 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -416,6 +416,7 @@ dependencies = [ "clap", "directories", "percent-encoding", + "pin-project-lite", "serde", "serde_derive", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index ee3a792..1f03557 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ anyhow = "1.0.53" clap = { version = "4.3.0", features = ["derive", "env"] } directories = "4.0.1" percent-encoding = "2.3.1" +pin-project-lite = "0.2.13" serde = { version = "1.0.186" } serde_derive = { version = "1.0.186" } serde_json = "1.0.78" diff --git a/src/client.rs b/src/client.rs index 5d1a7f8..fbe4be9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -6,8 +6,6 @@ use anyhow::{bail, ensure, Context, Result}; use percent_encoding::percent_decode_str; use serde_json::Value; use tokio::io::BufReader; -use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; -use tokio::net::TcpStream; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Mutex}; use tokio::task; @@ -21,10 +19,11 @@ use crate::lsp::jsonrpc::{ }; use crate::lsp::transport::{LspReader, LspWriter}; use crate::lsp::InitializeParams; +use crate::socketwrapper::{OwnedReadHalf, OwnedWriteHalf, Stream}; /// Read first client message and dispatch lsp mux commands pub async fn process( - socket: TcpStream, + socket: Stream, client_id: usize, instance_map: Arc>, ) -> Result<()> { diff --git a/src/ext.rs b/src/ext.rs index 203fc8e..0089123 100644 --- a/src/ext.rs +++ b/src/ext.rs @@ -3,19 +3,19 @@ use std::env; use anyhow::{bail, Context, Result}; use serde::de::{DeserializeOwned, IgnoredAny}; use tokio::io::BufReader; -use tokio::net::TcpStream; use crate::config::Config; use crate::lsp::ext::{self, LspMuxOptions, StatusResponse}; use crate::lsp::jsonrpc::{Message, Request, RequestId, Version}; use crate::lsp::transport::{LspReader, LspWriter}; use crate::lsp::{InitializationOptions, InitializeParams}; +use crate::socketwrapper::Stream; pub async fn ext_request(config: &Config, method: ext::Request) -> Result where T: DeserializeOwned, { - let (reader, writer) = TcpStream::connect(config.connect) + let (reader, writer) = Stream::connect_tcp(config.connect) .await .context("connect")? .into_split(); diff --git a/src/lib.rs b/src/lib.rs index c82e045..2596910 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ mod client; mod instance; mod lsp; +mod socketwrapper; pub mod config; pub mod ext; diff --git a/src/proxy.rs b/src/proxy.rs index 3649d00..13dd35a 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -3,13 +3,13 @@ use std::env; use anyhow::{bail, Context as _, Result}; use tokio::io::{self, BufStream}; -use tokio::net::TcpStream; use crate::config::Config; use crate::lsp::ext::{LspMuxOptions, Request}; use crate::lsp::jsonrpc::Message; use crate::lsp::transport::{LspReader, LspWriter}; use crate::lsp::{InitializationOptions, InitializeParams}; +use crate::socketwrapper::Stream; pub async fn run(config: &Config, server: String, args: Vec) -> Result<()> { let cwd = env::current_dir() @@ -23,7 +23,7 @@ pub async fn run(config: &Config, server: String, args: Vec) -> Result<( } } - let mut stream = TcpStream::connect(config.connect) + let mut stream = Stream::connect_tcp(config.connect) .await .context("connect")?; let mut stdio = BufStream::new(io::join(io::stdin(), io::stdout())); diff --git a/src/server.rs b/src/server.rs index e69e34b..c5a9ae0 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,19 +1,21 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use anyhow::{Context, Result}; -use tokio::net::TcpListener; use tokio::task; use tracing::{error, info, info_span, warn, Instrument}; use crate::client; use crate::config::Config; use crate::instance::InstanceMap; +use crate::socketwrapper::Listener; pub async fn run(config: &Config) -> Result<()> { let instance_map = InstanceMap::new(config).await; let next_client_id = AtomicUsize::new(0); - let listener = TcpListener::bind(config.listen).await.context("listen")?; + let listener = Listener::bind_tcp(config.listen.into()) + .await + .context("listen")?; loop { match listener.accept().await { Ok((socket, _addr)) => { diff --git a/src/socketwrapper.rs b/src/socketwrapper.rs new file mode 100644 index 0000000..71cea80 --- /dev/null +++ b/src/socketwrapper.rs @@ -0,0 +1,263 @@ +#[cfg(target_family = "unix")] +use std::fs; +#[cfg(target_family = "unix")] +use std::path::Path; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{io, net}; + +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::net::{tcp, TcpListener, TcpStream, ToSocketAddrs}; +#[cfg(target_family = "unix")] +use tokio::net::{unix, UnixListener, UnixStream}; + +pub enum SocketAddr { + Ip(net::SocketAddr), + #[cfg(target_family = "unix")] + Unix(tokio::net::unix::SocketAddr), +} + +impl From for SocketAddr { + fn from(val: net::SocketAddr) -> Self { + SocketAddr::Ip(val) + } +} + +#[cfg(target_family = "unix")] +impl From for SocketAddr { + fn from(val: tokio::net::unix::SocketAddr) -> Self { + SocketAddr::Unix(val) + } +} + +#[cfg(target_family = "unix")] +pin_project! { + #[project = OwnedReadHalfProj] + pub enum OwnedReadHalf { + Tcp{#[pin] tcp: tcp::OwnedReadHalf}, + Unix{#[pin] unix: unix::OwnedReadHalf}, + } +} +#[cfg(not(target_family = "unix"))] +pin_project! { + #[project = OwnedReadHalfProj] + pub enum OwnedReadHalf { + Tcp{#[pin] tcp: tcp::OwnedReadHalf}, + } +} + +impl AsyncRead for OwnedReadHalf { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.project() { + OwnedReadHalfProj::Tcp { tcp } => tcp.poll_read(cx, buf), + #[cfg(target_family = "unix")] + OwnedReadHalfProj::Unix { unix } => unix.poll_read(cx, buf), + } + } +} + +#[cfg(target_family = "unix")] +pin_project! { + #[project = OwnedWriteHalfProj] + pub enum OwnedWriteHalf { + Tcp{#[pin] tcp: tcp::OwnedWriteHalf}, + Unix{#[pin] unix: unix::OwnedWriteHalf}, + } +} +#[cfg(not(target_family = "unix"))] +pin_project! { + #[project = OwnedWriteHalfProj] + pub enum OwnedWriteHalf { + Tcp{#[pin] tcp: tcp::OwnedWriteHalf}, + } +} + +impl AsyncWrite for OwnedWriteHalf { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.project() { + OwnedWriteHalfProj::Tcp { tcp } => tcp.poll_write(cx, buf), + #[cfg(target_family = "unix")] + OwnedWriteHalfProj::Unix { unix } => unix.poll_write(cx, buf), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + match self.project() { + OwnedWriteHalfProj::Tcp { tcp } => tcp.poll_write_vectored(cx, bufs), + #[cfg(target_family = "unix")] + OwnedWriteHalfProj::Unix { unix } => unix.poll_write_vectored(cx, bufs), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + OwnedWriteHalfProj::Tcp { tcp } => tcp.poll_flush(cx), + #[cfg(target_family = "unix")] + OwnedWriteHalfProj::Unix { unix } => unix.poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + OwnedWriteHalfProj::Tcp { tcp } => tcp.poll_shutdown(cx), + #[cfg(target_family = "unix")] + OwnedWriteHalfProj::Unix { unix } => unix.poll_shutdown(cx), + } + } +} + +#[cfg(target_family = "unix")] +pin_project! { + #[project = StreamProj] + pub enum Stream { + Tcp{#[pin] tcp: TcpStream}, + Unix{#[pin] unix: UnixStream}, + } +} +#[cfg(not(target_family = "unix"))] +pin_project! { + #[project = StreamProj] + pub enum Stream { + Tcp{#[pin] tcp: TcpStream}, + } +} + +impl Stream { + pub async fn connect_tcp(addr: A) -> io::Result { + Ok(Stream::Tcp { + tcp: TcpStream::connect(addr).await?, + }) + } + + #[cfg(target_family = "unix")] + pub async fn connect_unix>(addr: P) -> io::Result { + Ok(Stream::Unix { + unix: UnixStream::connect(addr).await?, + }) + } + + pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { + match self { + Stream::Tcp { tcp } => { + let (read, write) = tcp.into_split(); + ( + OwnedReadHalf::Tcp { tcp: read }, + OwnedWriteHalf::Tcp { tcp: write }, + ) + } + #[cfg(target_family = "unix")] + Stream::Unix { unix } => { + let (read, write) = unix.into_split(); + ( + OwnedReadHalf::Unix { unix: read }, + OwnedWriteHalf::Unix { unix: write }, + ) + } + } + } +} + +impl AsyncRead for Stream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.project() { + StreamProj::Tcp { tcp } => tcp.poll_read(cx, buf), + #[cfg(target_family = "unix")] + StreamProj::Unix { unix } => unix.poll_read(cx, buf), + } + } +} + +impl AsyncWrite for Stream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.project() { + StreamProj::Tcp { tcp } => tcp.poll_write(cx, buf), + #[cfg(target_family = "unix")] + StreamProj::Unix { unix } => unix.poll_write(cx, buf), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + match self.project() { + StreamProj::Tcp { tcp } => tcp.poll_write_vectored(cx, bufs), + #[cfg(target_family = "unix")] + StreamProj::Unix { unix } => unix.poll_write_vectored(cx, bufs), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + StreamProj::Tcp { tcp } => tcp.poll_flush(cx), + #[cfg(target_family = "unix")] + StreamProj::Unix { unix } => unix.poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + StreamProj::Tcp { tcp } => tcp.poll_shutdown(cx), + #[cfg(target_family = "unix")] + StreamProj::Unix { unix } => unix.poll_shutdown(cx), + } + } +} + +pub enum Listener { + Tcp(TcpListener), + #[cfg(target_family = "unix")] + Unix(UnixListener), +} + +impl Listener { + pub async fn bind_tcp(addr: net::SocketAddr) -> io::Result { + Ok(Listener::Tcp(TcpListener::bind(addr).await?)) + } + + #[cfg(target_family = "unix")] + pub fn bind_unix>(addr: T) -> io::Result { + match fs::remove_file(&addr) { + Ok(()) => (), + Err(e) if e.kind() == io::ErrorKind::NotFound => (), + Err(e) => return Err(e), + } + Ok(Listener::Unix(UnixListener::bind(addr)?)) + } + + pub async fn accept(&self) -> io::Result<(Stream, SocketAddr)> { + match self { + Listener::Tcp(tcp) => { + let (stream, addr) = tcp.accept().await?; + Ok((Stream::Tcp { tcp: stream }, addr.into())) + } + #[cfg(target_family = "unix")] + Listener::Unix(unix) => { + let (stream, addr) = unix.accept().await?; + Ok((Stream::Unix { unix: stream }, addr.into())) + } + } + } +} From 98438dae7c25dcda706897b4cea4cfacc1981467 Mon Sep 17 00:00:00 2001 From: Tomasz Kramkowski Date: Tue, 4 Jun 2024 00:27:35 +0100 Subject: [PATCH 3/7] support unix socket addresses in listen and connect --- src/config.rs | 20 +++++++++++++++----- src/ext.rs | 13 ++++++++----- src/proxy.rs | 11 +++++++---- src/server.rs | 11 +++++++---- 4 files changed, 37 insertions(+), 18 deletions(-) diff --git a/src/config.rs b/src/config.rs index 9f7f463..00bafc6 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,8 @@ use std::collections::BTreeSet; use std::fs; use std::net::{IpAddr, Ipv4Addr}; +#[cfg(target_family = "unix")] +use std::path::PathBuf; use anyhow::{Context, Result}; use directories::ProjectDirs; @@ -21,12 +23,12 @@ mod default { 10 } - pub fn listen() -> (IpAddr, u16) { + pub fn listen() -> Address { // localhost & some random unprivileged port - (IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 27_631) + Address::Tcp(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 27_631) } - pub fn connect() -> (IpAddr, u16) { + pub fn connect() -> Address { listen() } @@ -82,6 +84,14 @@ mod de { } } +#[derive(Serialize, Deserialize, Debug)] +#[serde(untagged)] +pub enum Address { + Tcp(IpAddr, u16), + #[cfg(target_family = "unix")] + Unix(PathBuf), +} + #[derive(Serialize, Deserialize, Debug)] #[serde(deny_unknown_fields)] pub struct Config { @@ -94,10 +104,10 @@ pub struct Config { pub gc_interval: u32, #[serde(default = "default::listen")] - pub listen: (IpAddr, u16), + pub listen: Address, #[serde(default = "default::connect")] - pub connect: (IpAddr, u16), + pub connect: Address, #[serde(default = "default::log_filters")] pub log_filters: String, diff --git a/src/ext.rs b/src/ext.rs index 0089123..22c799f 100644 --- a/src/ext.rs +++ b/src/ext.rs @@ -4,7 +4,7 @@ use anyhow::{bail, Context, Result}; use serde::de::{DeserializeOwned, IgnoredAny}; use tokio::io::BufReader; -use crate::config::Config; +use crate::config::{Address, Config}; use crate::lsp::ext::{self, LspMuxOptions, StatusResponse}; use crate::lsp::jsonrpc::{Message, Request, RequestId, Version}; use crate::lsp::transport::{LspReader, LspWriter}; @@ -15,10 +15,13 @@ pub async fn ext_request(config: &Config, method: ext::Request) -> Result where T: DeserializeOwned, { - let (reader, writer) = Stream::connect_tcp(config.connect) - .await - .context("connect")? - .into_split(); + let (reader, writer) = match config.connect { + Address::Tcp(ip_addr, port) => Stream::connect_tcp((ip_addr, port)).await, + #[cfg(target_family = "unix")] + Address::Unix(ref path) => Stream::connect_unix(path).await, + } + .context("connect")? + .into_split(); let mut writer = LspWriter::new(writer, "lspmux"); let mut reader = LspReader::new(BufReader::new(reader), "lspmux"); diff --git a/src/proxy.rs b/src/proxy.rs index 13dd35a..318e327 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -4,7 +4,7 @@ use std::env; use anyhow::{bail, Context as _, Result}; use tokio::io::{self, BufStream}; -use crate::config::Config; +use crate::config::{Address, Config}; use crate::lsp::ext::{LspMuxOptions, Request}; use crate::lsp::jsonrpc::Message; use crate::lsp::transport::{LspReader, LspWriter}; @@ -23,9 +23,12 @@ pub async fn run(config: &Config, server: String, args: Vec) -> Result<( } } - let mut stream = Stream::connect_tcp(config.connect) - .await - .context("connect")?; + let mut stream = match config.connect { + Address::Tcp(ip_addr, port) => Stream::connect_tcp((ip_addr, port)).await, + #[cfg(target_family = "unix")] + Address::Unix(ref path) => Stream::connect_unix(path).await, + } + .context("connect")?; let mut stdio = BufStream::new(io::join(io::stdin(), io::stdout())); // Wait for the client to send `initialize` request. diff --git a/src/server.rs b/src/server.rs index c5a9ae0..956662c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -5,7 +5,7 @@ use tokio::task; use tracing::{error, info, info_span, warn, Instrument}; use crate::client; -use crate::config::Config; +use crate::config::{Address, Config}; use crate::instance::InstanceMap; use crate::socketwrapper::Listener; @@ -13,9 +13,12 @@ pub async fn run(config: &Config) -> Result<()> { let instance_map = InstanceMap::new(config).await; let next_client_id = AtomicUsize::new(0); - let listener = Listener::bind_tcp(config.listen.into()) - .await - .context("listen")?; + let listener = match config.listen { + Address::Tcp(ip_addr, port) => Listener::bind_tcp((ip_addr, port).into()).await, + #[cfg(target_family = "unix")] + Address::Unix(ref path) => Listener::bind_unix(path), + } + .context("listen")?; loop { match listener.accept().await { Ok((socket, _addr)) => { From b17e8b8e2785d736e9b1adf1b28bc38779a74e9f Mon Sep 17 00:00:00 2001 From: max Date: Tue, 4 Jun 2024 13:32:50 +0200 Subject: [PATCH 4/7] update pin-project-lite was: replace pin-project with pin-project-lite taken from: https://github.com/EliteTK/ra-multiplex/pull/1 --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ad70898..a427512 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -380,9 +380,9 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pin-project-lite" -version = "0.2.13" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" +checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" [[package]] name = "powerfmt" diff --git a/Cargo.toml b/Cargo.toml index 1f03557..232294f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ anyhow = "1.0.53" clap = { version = "4.3.0", features = ["derive", "env"] } directories = "4.0.1" percent-encoding = "2.3.1" -pin-project-lite = "0.2.13" +pin-project-lite = "0.2.14" serde = { version = "1.0.186" } serde_derive = { version = "1.0.186" } serde_json = "1.0.78" From f5c3831b39fd37657f821fa2dfbe44b6230b53d2 Mon Sep 17 00:00:00 2001 From: Tomasz Kramkowski Date: Tue, 4 Jun 2024 15:53:45 +0100 Subject: [PATCH 5/7] Replace connect_{tcp,unix} with a single function Since config::Address now exists, there's no real need to be generic and accept anything TcpStream::connect and UnixStream::connect accept and so a single function taking config::Address can simplify both call sites. --- src/ext.rs | 13 +++++-------- src/proxy.rs | 9 ++------- src/socketwrapper.rs | 25 +++++++++++++------------ 3 files changed, 20 insertions(+), 27 deletions(-) diff --git a/src/ext.rs b/src/ext.rs index 22c799f..ecb5ba8 100644 --- a/src/ext.rs +++ b/src/ext.rs @@ -4,7 +4,7 @@ use anyhow::{bail, Context, Result}; use serde::de::{DeserializeOwned, IgnoredAny}; use tokio::io::BufReader; -use crate::config::{Address, Config}; +use crate::config::Config; use crate::lsp::ext::{self, LspMuxOptions, StatusResponse}; use crate::lsp::jsonrpc::{Message, Request, RequestId, Version}; use crate::lsp::transport::{LspReader, LspWriter}; @@ -15,13 +15,10 @@ pub async fn ext_request(config: &Config, method: ext::Request) -> Result where T: DeserializeOwned, { - let (reader, writer) = match config.connect { - Address::Tcp(ip_addr, port) => Stream::connect_tcp((ip_addr, port)).await, - #[cfg(target_family = "unix")] - Address::Unix(ref path) => Stream::connect_unix(path).await, - } - .context("connect")? - .into_split(); + let (reader, writer) = Stream::connect(&config.connect) + .await + .context("connect")? + .into_split(); let mut writer = LspWriter::new(writer, "lspmux"); let mut reader = LspReader::new(BufReader::new(reader), "lspmux"); diff --git a/src/proxy.rs b/src/proxy.rs index 318e327..fcc2878 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -4,7 +4,7 @@ use std::env; use anyhow::{bail, Context as _, Result}; use tokio::io::{self, BufStream}; -use crate::config::{Address, Config}; +use crate::config::Config; use crate::lsp::ext::{LspMuxOptions, Request}; use crate::lsp::jsonrpc::Message; use crate::lsp::transport::{LspReader, LspWriter}; @@ -23,12 +23,7 @@ pub async fn run(config: &Config, server: String, args: Vec) -> Result<( } } - let mut stream = match config.connect { - Address::Tcp(ip_addr, port) => Stream::connect_tcp((ip_addr, port)).await, - #[cfg(target_family = "unix")] - Address::Unix(ref path) => Stream::connect_unix(path).await, - } - .context("connect")?; + let mut stream = Stream::connect(&config.connect).await.context("connect")?; let mut stdio = BufStream::new(io::join(io::stdin(), io::stdout())); // Wait for the client to send `initialize` request. diff --git a/src/socketwrapper.rs b/src/socketwrapper.rs index 71cea80..e154da6 100644 --- a/src/socketwrapper.rs +++ b/src/socketwrapper.rs @@ -8,10 +8,12 @@ use std::{io, net}; use pin_project_lite::pin_project; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio::net::{tcp, TcpListener, TcpStream, ToSocketAddrs}; +use tokio::net::{tcp, TcpListener, TcpStream}; #[cfg(target_family = "unix")] use tokio::net::{unix, UnixListener, UnixStream}; +use crate::config::Address; + pub enum SocketAddr { Ip(net::SocketAddr), #[cfg(target_family = "unix")] @@ -136,17 +138,16 @@ pin_project! { } impl Stream { - pub async fn connect_tcp(addr: A) -> io::Result { - Ok(Stream::Tcp { - tcp: TcpStream::connect(addr).await?, - }) - } - - #[cfg(target_family = "unix")] - pub async fn connect_unix>(addr: P) -> io::Result { - Ok(Stream::Unix { - unix: UnixStream::connect(addr).await?, - }) + pub async fn connect(addr: &Address) -> io::Result { + match addr { + Address::Tcp(ip_addr, port) => Ok(Stream::Tcp { + tcp: TcpStream::connect((*ip_addr, *port)).await?, + }), + #[cfg(target_family = "unix")] + Address::Unix(path) => Ok(Stream::Unix { + unix: UnixStream::connect(path).await?, + }), + } } pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { From 5d043bd2fc6d18de2a1b46a743eda058c0670b0a Mon Sep 17 00:00:00 2001 From: Tomasz Kramkowski Date: Tue, 4 Jun 2024 15:59:37 +0100 Subject: [PATCH 6/7] Replace bind_{tcp,unix} with a single function Same reasons as the connect_{tcp,unix} commit. Although there is currently only one call site. --- src/server.rs | 9 ++------- src/socketwrapper.rs | 27 ++++++++++++++------------- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/src/server.rs b/src/server.rs index 956662c..5eb9c75 100644 --- a/src/server.rs +++ b/src/server.rs @@ -5,7 +5,7 @@ use tokio::task; use tracing::{error, info, info_span, warn, Instrument}; use crate::client; -use crate::config::{Address, Config}; +use crate::config::Config; use crate::instance::InstanceMap; use crate::socketwrapper::Listener; @@ -13,12 +13,7 @@ pub async fn run(config: &Config) -> Result<()> { let instance_map = InstanceMap::new(config).await; let next_client_id = AtomicUsize::new(0); - let listener = match config.listen { - Address::Tcp(ip_addr, port) => Listener::bind_tcp((ip_addr, port).into()).await, - #[cfg(target_family = "unix")] - Address::Unix(ref path) => Listener::bind_unix(path), - } - .context("listen")?; + let listener = Listener::bind(&config.listen).await.context("listen")?; loop { match listener.accept().await { Ok((socket, _addr)) => { diff --git a/src/socketwrapper.rs b/src/socketwrapper.rs index e154da6..67b9d51 100644 --- a/src/socketwrapper.rs +++ b/src/socketwrapper.rs @@ -1,7 +1,5 @@ #[cfg(target_family = "unix")] use std::fs; -#[cfg(target_family = "unix")] -use std::path::Path; use std::pin::Pin; use std::task::{Context, Poll}; use std::{io, net}; @@ -234,18 +232,21 @@ pub enum Listener { } impl Listener { - pub async fn bind_tcp(addr: net::SocketAddr) -> io::Result { - Ok(Listener::Tcp(TcpListener::bind(addr).await?)) - } - - #[cfg(target_family = "unix")] - pub fn bind_unix>(addr: T) -> io::Result { - match fs::remove_file(&addr) { - Ok(()) => (), - Err(e) if e.kind() == io::ErrorKind::NotFound => (), - Err(e) => return Err(e), + pub async fn bind(addr: &Address) -> io::Result { + match addr { + Address::Tcp(ip_addr, port) => { + Ok(Listener::Tcp(TcpListener::bind((*ip_addr, *port)).await?)) + } + #[cfg(target_family = "unix")] + Address::Unix(path) => { + match fs::remove_file(&path) { + Ok(()) => (), + Err(e) if e.kind() == io::ErrorKind::NotFound => (), + Err(e) => return Err(e), + } + Ok(Listener::Unix(UnixListener::bind(path)?)) + } } - Ok(Listener::Unix(UnixListener::bind(addr)?)) } pub async fn accept(&self) -> io::Result<(Stream, SocketAddr)> { From 7e486b146b221fab968659dcfd83634b9fb97d69 Mon Sep 17 00:00:00 2001 From: Tomasz Kramkowski Date: Tue, 4 Jun 2024 18:45:34 +0100 Subject: [PATCH 7/7] make next_client_id a lambda This is neater at the call site and makes it clear that this is only ever supposed to be an incremental ID and nothing else. As suggested by max (pr2502) in https://github.com/pr2502/ra-multiplex/pull/66#discussion_r1626394771 --- src/server.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/server.rs b/src/server.rs index 5eb9c75..8e22176 100644 --- a/src/server.rs +++ b/src/server.rs @@ -12,12 +12,13 @@ use crate::socketwrapper::Listener; pub async fn run(config: &Config) -> Result<()> { let instance_map = InstanceMap::new(config).await; let next_client_id = AtomicUsize::new(0); + let next_client_id = || next_client_id.fetch_add(1, Ordering::Relaxed); let listener = Listener::bind(&config.listen).await.context("listen")?; loop { match listener.accept().await { Ok((socket, _addr)) => { - let client_id = next_client_id.fetch_add(1, Ordering::Relaxed); + let client_id = next_client_id(); let instance_map = instance_map.clone(); task::spawn(