diff --git a/Cargo.lock b/Cargo.lock index c028bb4..a427512 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -380,9 +380,9 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pin-project-lite" -version = "0.2.13" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" +checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" [[package]] name = "powerfmt" @@ -416,6 +416,7 @@ dependencies = [ "clap", "directories", "percent-encoding", + "pin-project-lite", "serde", "serde_derive", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index ee3a792..232294f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ anyhow = "1.0.53" clap = { version = "4.3.0", features = ["derive", "env"] } directories = "4.0.1" percent-encoding = "2.3.1" +pin-project-lite = "0.2.14" serde = { version = "1.0.186" } serde_derive = { version = "1.0.186" } serde_json = "1.0.78" diff --git a/src/client.rs b/src/client.rs index 3b5d1e2..fbe4be9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -6,8 +6,6 @@ use anyhow::{bail, ensure, Context, Result}; use percent_encoding::percent_decode_str; use serde_json::Value; use tokio::io::BufReader; -use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; -use tokio::net::TcpStream; use tokio::sync::mpsc::error::SendError; use tokio::sync::{mpsc, Mutex}; use tokio::task; @@ -21,11 +19,12 @@ use crate::lsp::jsonrpc::{ }; use crate::lsp::transport::{LspReader, LspWriter}; use crate::lsp::InitializeParams; +use crate::socketwrapper::{OwnedReadHalf, OwnedWriteHalf, Stream}; /// Read first client message and dispatch lsp mux commands pub async fn process( - socket: TcpStream, - port: u16, + socket: Stream, + client_id: usize, instance_map: Arc>, ) -> Result<()> { let (socket_read, socket_write) = socket.into_split(); @@ -70,7 +69,7 @@ pub async fn process( cwd, } => { connect( - port, + client_id, instance_map, (server, args, env, cwd), req, @@ -87,18 +86,18 @@ pub async fn process( #[derive(Clone)] pub struct Client { - port: u16, + id: usize, sender: mpsc::Sender, } impl Client { - fn new(port: u16) -> (Client, mpsc::Receiver) { + fn new(id: usize) -> (Client, mpsc::Receiver) { let (sender, receiver) = mpsc::channel(16); - (Client { port, sender }, receiver) + (Client { id, sender }, receiver) } - pub fn port(&self) -> u16 { - self.port + pub fn id(&self) -> usize { + self.id } /// Send a message to the client channel @@ -168,7 +167,7 @@ async fn reload( /// Find or spawn a language server instance and connect the client to it async fn connect( - port: u16, + client_id: usize, instance_map: Arc>, (server, args, env, cwd): ( String, @@ -224,7 +223,7 @@ async fn connect( } info!("initialized client"); - let (client, client_rx) = Client::new(port); + let (client, client_rx) = Client::new(client_id); task::spawn(input_task(client_rx, writer).in_current_span()); instance.add_client(client.clone()).await; @@ -346,7 +345,7 @@ async fn output_task( } Message::Request(mut req) => { - req.id = req.id.tag(Tag::Port(client.port)); + req.id = req.id.tag(Tag::ClientId(client.id)); if instance.send_message(req.into()).await.is_err() { break; } @@ -372,13 +371,13 @@ async fn output_task( } Message::Notification(notif) if notif.method == "textDocument/didOpen" => { - if let Err(err) = instance.open_file(client.port, notif.params).await { + if let Err(err) = instance.open_file(client.id, notif.params).await { warn!(?err, "error opening file"); } } Message::Notification(notif) if notif.method == "textDocument/didClose" => { - if let Err(err) = instance.close_file(client.port, notif.params).await { + if let Err(err) = instance.close_file(client.id, notif.params).await { warn!(?err, "error closing file"); } } diff --git a/src/config.rs b/src/config.rs index 9f7f463..00bafc6 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,8 @@ use std::collections::BTreeSet; use std::fs; use std::net::{IpAddr, Ipv4Addr}; +#[cfg(target_family = "unix")] +use std::path::PathBuf; use anyhow::{Context, Result}; use directories::ProjectDirs; @@ -21,12 +23,12 @@ mod default { 10 } - pub fn listen() -> (IpAddr, u16) { + pub fn listen() -> Address { // localhost & some random unprivileged port - (IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 27_631) + Address::Tcp(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 27_631) } - pub fn connect() -> (IpAddr, u16) { + pub fn connect() -> Address { listen() } @@ -82,6 +84,14 @@ mod de { } } +#[derive(Serialize, Deserialize, Debug)] +#[serde(untagged)] +pub enum Address { + Tcp(IpAddr, u16), + #[cfg(target_family = "unix")] + Unix(PathBuf), +} + #[derive(Serialize, Deserialize, Debug)] #[serde(deny_unknown_fields)] pub struct Config { @@ -94,10 +104,10 @@ pub struct Config { pub gc_interval: u32, #[serde(default = "default::listen")] - pub listen: (IpAddr, u16), + pub listen: Address, #[serde(default = "default::connect")] - pub connect: (IpAddr, u16), + pub connect: Address, #[serde(default = "default::log_filters")] pub log_filters: String, diff --git a/src/ext.rs b/src/ext.rs index 330de3f..ecb5ba8 100644 --- a/src/ext.rs +++ b/src/ext.rs @@ -3,19 +3,19 @@ use std::env; use anyhow::{bail, Context, Result}; use serde::de::{DeserializeOwned, IgnoredAny}; use tokio::io::BufReader; -use tokio::net::TcpStream; use crate::config::Config; use crate::lsp::ext::{self, LspMuxOptions, StatusResponse}; use crate::lsp::jsonrpc::{Message, Request, RequestId, Version}; use crate::lsp::transport::{LspReader, LspWriter}; use crate::lsp::{InitializationOptions, InitializeParams}; +use crate::socketwrapper::Stream; pub async fn ext_request(config: &Config, method: ext::Request) -> Result where T: DeserializeOwned, { - let (reader, writer) = TcpStream::connect(config.connect) + let (reader, writer) = Stream::connect(&config.connect) .await .context("connect")? .into_split(); @@ -102,7 +102,7 @@ pub async fn status(config: &Config, json: bool) -> Result<()> { println!(" clients:"); for client in instance.clients { println!(" - Client"); - println!(" port: {}", client.port); + println!(" id: {}", client.id); println!(" files:"); for file in client.files { println!(" - {}", file); diff --git a/src/instance.rs b/src/instance.rs index 24288dd..e586fbd 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -49,7 +49,7 @@ pub struct Instance { server: mpsc::Sender, /// Data of associated clients - clients: Mutex>, + clients: Mutex>, /// Dynamic capabilities registered by the server dynamic_capabilities: Mutex>, @@ -83,7 +83,7 @@ struct ClientData { impl ClientData { fn get_status(&self) -> ext::Client { ext::Client { - port: self.client.port(), + id: self.client.id(), files: self.files.iter().cloned().collect(), } } @@ -146,8 +146,8 @@ impl Instance { client, files: HashSet::new(), }; - if clients.insert(client.port(), client).is_some() { - unreachable!("BUG: added two clients with the same port"); + if clients.insert(client.id(), client).is_some() { + unreachable!("BUG: added two clients with the same ID"); } } @@ -157,7 +157,7 @@ impl Instance { let mut clients = self.clients.lock().await; - let Some(client) = clients.remove(&client.port()) else { + let Some(client) = clients.remove(&client.id()) else { // TODO This happens for example when the language server died while // client was still connected, and the client cleanup is attempted // with the instance being gone already. We should try notifying @@ -205,7 +205,7 @@ impl Instance { } /// Handle `textDocument/didOpen` client notification - pub async fn open_file(&self, port: u16, params: Value) -> Result<()> { + pub async fn open_file(&self, client_id: usize, params: Value) -> Result<()> { let params = serde_json::from_value::(params) .context("parsing params")?; let uri = ¶ms.text_document.uri; @@ -222,7 +222,7 @@ impl Instance { } clients - .get_mut(&port) + .get_mut(&client_id) .expect("no matching client") .files .insert(uri.clone()); @@ -241,14 +241,14 @@ impl Instance { } /// Handle `textDocument/didClose` client notification - pub async fn close_file(&self, port: u16, params: Value) -> Result<()> { + pub async fn close_file(&self, client_id: usize, params: Value) -> Result<()> { let params = serde_json::from_value::(params) .context("parsing params")?; let mut clients = self.clients.lock().await; clients - .get_mut(&port) + .get_mut(&client_id) .context("no matching client")? .files .remove(¶ms.text_document.uri); @@ -261,7 +261,7 @@ impl Instance { /// definitely closed files async fn close_all_files( &self, - clients: &HashMap, + clients: &HashMap, files: Vec, ) -> Result<()> { for uri in files { @@ -643,12 +643,12 @@ async fn stdout_task(instance: Arc, mut reader: LspReader { + (Some(Tag::ClientId(client_id)), id) => { res.id = id; - if let Some(client) = clients.get(&port) { + if let Some(client) = clients.get(&client_id) { let _ = client.send_message(res.into()).await; } else { - debug!(?port, "no matching client"); + debug!(?client_id, "no matching client"); } } (Some(Tag::Drop), _) => { @@ -664,13 +664,13 @@ async fn stdout_task(instance: Arc, mut reader: LspReader { + (Some(Tag::ClientId(client_id)), id) => { warn!(?res, "server responded with error"); res.id = id; - if let Some(client) = clients.get(&port) { + if let Some(client) = clients.get(&client_id) { let _ = client.send_message(res.into()).await; } else { - debug!(?port, "no matching client"); + debug!(?client_id, "no matching client"); } } (Some(Tag::Drop), _) => { diff --git a/src/lib.rs b/src/lib.rs index c82e045..2596910 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ mod client; mod instance; mod lsp; +mod socketwrapper; pub mod config; pub mod ext; diff --git a/src/lsp/ext.rs b/src/lsp/ext.rs index fd05bc6..c896af3 100644 --- a/src/lsp/ext.rs +++ b/src/lsp/ext.rs @@ -11,8 +11,8 @@ use super::jsonrpc::RequestId; /// Additional metadata inserted into LSP RequestId pub enum Tag { - /// Request is coming from a client connected on this port - Port(u16), + /// Request is coming from a client connected with this ID + ClientId(usize), /// Response to this request should be ignored Drop, /// Response to this request should be forwarded @@ -23,7 +23,7 @@ impl RequestId { /// Serializes the ID to a string and prepends Tag pub fn tag(&self, tag: Tag) -> RequestId { let tag = match tag { - Tag::Port(port) => format!("port:{port}"), + Tag::ClientId(client_id) => format!("client_id:{client_id}"), Tag::Drop => "drop".into(), Tag::Forward => "forward".into(), }; @@ -45,10 +45,10 @@ impl RequestId { }) } - fn parse_port(input: &str) -> Result<(u16, &str)> { - let (port, rest) = input.split_once(':').context("missing`:`")?; - let port = u16::from_str(port).context("invalid port number")?; - Ok((port, rest)) + fn parse_client_id(input: &str) -> Result<(usize, &str)> { + let (client_id, rest) = input.split_once(':').context("missing`:`")?; + let client_id = usize::from_str(client_id).context("invalid client ID")?; + Ok((client_id, rest)) } fn parse_tag(input: &RequestId) -> Result<(Tag, RequestId)> { @@ -56,10 +56,10 @@ impl RequestId { bail!("tagged id must be a String found `{input:?}`"); }; - if let Some(rest) = input.strip_prefix("port:") { - let (port, rest) = parse_port(rest)?; + if let Some(rest) = input.strip_prefix("client_id:") { + let (client_id, rest) = parse_client_id(rest)?; let inner_id = parse_inner_id(rest).context("failed to parse inner ID")?; - return Ok((Tag::Port(port), inner_id)); + return Ok((Tag::ClientId(client_id), inner_id)); } if let Some(rest) = input.strip_prefix("drop:") { @@ -175,7 +175,7 @@ pub struct Instance { #[derive(Serialize, Deserialize, Clone, Debug)] #[serde(rename_all = "camelCase")] pub struct Client { - pub port: u16, + pub id: usize, pub files: Vec, } diff --git a/src/proxy.rs b/src/proxy.rs index 3649d00..fcc2878 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -3,13 +3,13 @@ use std::env; use anyhow::{bail, Context as _, Result}; use tokio::io::{self, BufStream}; -use tokio::net::TcpStream; use crate::config::Config; use crate::lsp::ext::{LspMuxOptions, Request}; use crate::lsp::jsonrpc::Message; use crate::lsp::transport::{LspReader, LspWriter}; use crate::lsp::{InitializationOptions, InitializeParams}; +use crate::socketwrapper::Stream; pub async fn run(config: &Config, server: String, args: Vec) -> Result<()> { let cwd = env::current_dir() @@ -23,9 +23,7 @@ pub async fn run(config: &Config, server: String, args: Vec) -> Result<( } } - let mut stream = TcpStream::connect(config.connect) - .await - .context("connect")?; + let mut stream = Stream::connect(&config.connect).await.context("connect")?; let mut stdio = BufStream::new(io::join(io::stdin(), io::stdout())); // Wait for the client to send `initialize` request. diff --git a/src/server.rs b/src/server.rs index 4a037b3..8e22176 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,31 +1,35 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; + use anyhow::{Context, Result}; -use tokio::net::TcpListener; use tokio::task; use tracing::{error, info, info_span, warn, Instrument}; use crate::client; use crate::config::Config; use crate::instance::InstanceMap; +use crate::socketwrapper::Listener; pub async fn run(config: &Config) -> Result<()> { let instance_map = InstanceMap::new(config).await; + let next_client_id = AtomicUsize::new(0); + let next_client_id = || next_client_id.fetch_add(1, Ordering::Relaxed); - let listener = TcpListener::bind(config.listen).await.context("listen")?; + let listener = Listener::bind(&config.listen).await.context("listen")?; loop { match listener.accept().await { - Ok((socket, addr)) => { - let port = addr.port(); + Ok((socket, _addr)) => { + let client_id = next_client_id(); let instance_map = instance_map.clone(); task::spawn( async move { info!("client connected"); - match client::process(socket, port, instance_map).await { + match client::process(socket, client_id, instance_map).await { Ok(_) => {} Err(err) => error!("client error: {err:?}"), } } - .instrument(info_span!("client", %port)), + .instrument(info_span!("client", %client_id)), ); } Err(err) => match err.kind() { diff --git a/src/socketwrapper.rs b/src/socketwrapper.rs new file mode 100644 index 0000000..67b9d51 --- /dev/null +++ b/src/socketwrapper.rs @@ -0,0 +1,265 @@ +#[cfg(target_family = "unix")] +use std::fs; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{io, net}; + +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::net::{tcp, TcpListener, TcpStream}; +#[cfg(target_family = "unix")] +use tokio::net::{unix, UnixListener, UnixStream}; + +use crate::config::Address; + +pub enum SocketAddr { + Ip(net::SocketAddr), + #[cfg(target_family = "unix")] + Unix(tokio::net::unix::SocketAddr), +} + +impl From for SocketAddr { + fn from(val: net::SocketAddr) -> Self { + SocketAddr::Ip(val) + } +} + +#[cfg(target_family = "unix")] +impl From for SocketAddr { + fn from(val: tokio::net::unix::SocketAddr) -> Self { + SocketAddr::Unix(val) + } +} + +#[cfg(target_family = "unix")] +pin_project! { + #[project = OwnedReadHalfProj] + pub enum OwnedReadHalf { + Tcp{#[pin] tcp: tcp::OwnedReadHalf}, + Unix{#[pin] unix: unix::OwnedReadHalf}, + } +} +#[cfg(not(target_family = "unix"))] +pin_project! { + #[project = OwnedReadHalfProj] + pub enum OwnedReadHalf { + Tcp{#[pin] tcp: tcp::OwnedReadHalf}, + } +} + +impl AsyncRead for OwnedReadHalf { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.project() { + OwnedReadHalfProj::Tcp { tcp } => tcp.poll_read(cx, buf), + #[cfg(target_family = "unix")] + OwnedReadHalfProj::Unix { unix } => unix.poll_read(cx, buf), + } + } +} + +#[cfg(target_family = "unix")] +pin_project! { + #[project = OwnedWriteHalfProj] + pub enum OwnedWriteHalf { + Tcp{#[pin] tcp: tcp::OwnedWriteHalf}, + Unix{#[pin] unix: unix::OwnedWriteHalf}, + } +} +#[cfg(not(target_family = "unix"))] +pin_project! { + #[project = OwnedWriteHalfProj] + pub enum OwnedWriteHalf { + Tcp{#[pin] tcp: tcp::OwnedWriteHalf}, + } +} + +impl AsyncWrite for OwnedWriteHalf { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.project() { + OwnedWriteHalfProj::Tcp { tcp } => tcp.poll_write(cx, buf), + #[cfg(target_family = "unix")] + OwnedWriteHalfProj::Unix { unix } => unix.poll_write(cx, buf), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + match self.project() { + OwnedWriteHalfProj::Tcp { tcp } => tcp.poll_write_vectored(cx, bufs), + #[cfg(target_family = "unix")] + OwnedWriteHalfProj::Unix { unix } => unix.poll_write_vectored(cx, bufs), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + OwnedWriteHalfProj::Tcp { tcp } => tcp.poll_flush(cx), + #[cfg(target_family = "unix")] + OwnedWriteHalfProj::Unix { unix } => unix.poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + OwnedWriteHalfProj::Tcp { tcp } => tcp.poll_shutdown(cx), + #[cfg(target_family = "unix")] + OwnedWriteHalfProj::Unix { unix } => unix.poll_shutdown(cx), + } + } +} + +#[cfg(target_family = "unix")] +pin_project! { + #[project = StreamProj] + pub enum Stream { + Tcp{#[pin] tcp: TcpStream}, + Unix{#[pin] unix: UnixStream}, + } +} +#[cfg(not(target_family = "unix"))] +pin_project! { + #[project = StreamProj] + pub enum Stream { + Tcp{#[pin] tcp: TcpStream}, + } +} + +impl Stream { + pub async fn connect(addr: &Address) -> io::Result { + match addr { + Address::Tcp(ip_addr, port) => Ok(Stream::Tcp { + tcp: TcpStream::connect((*ip_addr, *port)).await?, + }), + #[cfg(target_family = "unix")] + Address::Unix(path) => Ok(Stream::Unix { + unix: UnixStream::connect(path).await?, + }), + } + } + + pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) { + match self { + Stream::Tcp { tcp } => { + let (read, write) = tcp.into_split(); + ( + OwnedReadHalf::Tcp { tcp: read }, + OwnedWriteHalf::Tcp { tcp: write }, + ) + } + #[cfg(target_family = "unix")] + Stream::Unix { unix } => { + let (read, write) = unix.into_split(); + ( + OwnedReadHalf::Unix { unix: read }, + OwnedWriteHalf::Unix { unix: write }, + ) + } + } + } +} + +impl AsyncRead for Stream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.project() { + StreamProj::Tcp { tcp } => tcp.poll_read(cx, buf), + #[cfg(target_family = "unix")] + StreamProj::Unix { unix } => unix.poll_read(cx, buf), + } + } +} + +impl AsyncWrite for Stream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.project() { + StreamProj::Tcp { tcp } => tcp.poll_write(cx, buf), + #[cfg(target_family = "unix")] + StreamProj::Unix { unix } => unix.poll_write(cx, buf), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + match self.project() { + StreamProj::Tcp { tcp } => tcp.poll_write_vectored(cx, bufs), + #[cfg(target_family = "unix")] + StreamProj::Unix { unix } => unix.poll_write_vectored(cx, bufs), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + StreamProj::Tcp { tcp } => tcp.poll_flush(cx), + #[cfg(target_family = "unix")] + StreamProj::Unix { unix } => unix.poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project() { + StreamProj::Tcp { tcp } => tcp.poll_shutdown(cx), + #[cfg(target_family = "unix")] + StreamProj::Unix { unix } => unix.poll_shutdown(cx), + } + } +} + +pub enum Listener { + Tcp(TcpListener), + #[cfg(target_family = "unix")] + Unix(UnixListener), +} + +impl Listener { + pub async fn bind(addr: &Address) -> io::Result { + match addr { + Address::Tcp(ip_addr, port) => { + Ok(Listener::Tcp(TcpListener::bind((*ip_addr, *port)).await?)) + } + #[cfg(target_family = "unix")] + Address::Unix(path) => { + match fs::remove_file(&path) { + Ok(()) => (), + Err(e) if e.kind() == io::ErrorKind::NotFound => (), + Err(e) => return Err(e), + } + Ok(Listener::Unix(UnixListener::bind(path)?)) + } + } + } + + pub async fn accept(&self) -> io::Result<(Stream, SocketAddr)> { + match self { + Listener::Tcp(tcp) => { + let (stream, addr) = tcp.accept().await?; + Ok((Stream::Tcp { tcp: stream }, addr.into())) + } + #[cfg(target_family = "unix")] + Listener::Unix(unix) => { + let (stream, addr) = unix.accept().await?; + Ok((Stream::Unix { unix: stream }, addr.into())) + } + } + } +}