Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unix sockets (implements #42) #66

Merged
merged 7 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ anyhow = "1.0.53"
clap = { version = "4.3.0", features = ["derive", "env"] }
directories = "4.0.1"
percent-encoding = "2.3.1"
pin-project-lite = "0.2.14"
serde = { version = "1.0.186" }
serde_derive = { version = "1.0.186" }
serde_json = "1.0.78"
Expand Down
29 changes: 14 additions & 15 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ use anyhow::{bail, ensure, Context, Result};
use percent_encoding::percent_decode_str;
use serde_json::Value;
use tokio::io::BufReader;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Mutex};
use tokio::task;
Expand All @@ -21,11 +19,12 @@ use crate::lsp::jsonrpc::{
};
use crate::lsp::transport::{LspReader, LspWriter};
use crate::lsp::InitializeParams;
use crate::socketwrapper::{OwnedReadHalf, OwnedWriteHalf, Stream};

/// Read first client message and dispatch lsp mux commands
pub async fn process(
socket: TcpStream,
port: u16,
socket: Stream,
client_id: usize,
instance_map: Arc<Mutex<InstanceMap>>,
) -> Result<()> {
let (socket_read, socket_write) = socket.into_split();
Expand Down Expand Up @@ -70,7 +69,7 @@ pub async fn process(
cwd,
} => {
connect(
port,
client_id,
instance_map,
(server, args, env, cwd),
req,
Expand All @@ -87,18 +86,18 @@ pub async fn process(

#[derive(Clone)]
pub struct Client {
port: u16,
id: usize,
sender: mpsc::Sender<Message>,
}

impl Client {
fn new(port: u16) -> (Client, mpsc::Receiver<Message>) {
fn new(id: usize) -> (Client, mpsc::Receiver<Message>) {
let (sender, receiver) = mpsc::channel(16);
(Client { port, sender }, receiver)
(Client { id, sender }, receiver)
}

pub fn port(&self) -> u16 {
self.port
pub fn id(&self) -> usize {
self.id
}

/// Send a message to the client channel
Expand Down Expand Up @@ -168,7 +167,7 @@ async fn reload(

/// Find or spawn a language server instance and connect the client to it
async fn connect(
port: u16,
client_id: usize,
instance_map: Arc<Mutex<InstanceMap>>,
(server, args, env, cwd): (
String,
Expand Down Expand Up @@ -224,7 +223,7 @@ async fn connect(
}
info!("initialized client");

let (client, client_rx) = Client::new(port);
let (client, client_rx) = Client::new(client_id);
task::spawn(input_task(client_rx, writer).in_current_span());
instance.add_client(client.clone()).await;

Expand Down Expand Up @@ -346,7 +345,7 @@ async fn output_task(
}

Message::Request(mut req) => {
req.id = req.id.tag(Tag::Port(client.port));
req.id = req.id.tag(Tag::ClientId(client.id));
if instance.send_message(req.into()).await.is_err() {
break;
}
Expand All @@ -372,13 +371,13 @@ async fn output_task(
}

Message::Notification(notif) if notif.method == "textDocument/didOpen" => {
if let Err(err) = instance.open_file(client.port, notif.params).await {
if let Err(err) = instance.open_file(client.id, notif.params).await {
warn!(?err, "error opening file");
}
}

Message::Notification(notif) if notif.method == "textDocument/didClose" => {
if let Err(err) = instance.close_file(client.port, notif.params).await {
if let Err(err) = instance.close_file(client.id, notif.params).await {
warn!(?err, "error closing file");
}
}
Expand Down
20 changes: 15 additions & 5 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::collections::BTreeSet;
use std::fs;
use std::net::{IpAddr, Ipv4Addr};
#[cfg(target_family = "unix")]
use std::path::PathBuf;

use anyhow::{Context, Result};
use directories::ProjectDirs;
Expand All @@ -21,12 +23,12 @@ mod default {
10
}

pub fn listen() -> (IpAddr, u16) {
pub fn listen() -> Address {
// localhost & some random unprivileged port
(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 27_631)
Address::Tcp(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 27_631)
}

pub fn connect() -> (IpAddr, u16) {
pub fn connect() -> Address {
listen()
}

Expand Down Expand Up @@ -82,6 +84,14 @@ mod de {
}
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum Address {
Tcp(IpAddr, u16),
#[cfg(target_family = "unix")]
Unix(PathBuf),
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct Config {
Expand All @@ -94,10 +104,10 @@ pub struct Config {
pub gc_interval: u32,

#[serde(default = "default::listen")]
pub listen: (IpAddr, u16),
pub listen: Address,

#[serde(default = "default::connect")]
pub connect: (IpAddr, u16),
pub connect: Address,

#[serde(default = "default::log_filters")]
pub log_filters: String,
Expand Down
6 changes: 3 additions & 3 deletions src/ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ use std::env;
use anyhow::{bail, Context, Result};
use serde::de::{DeserializeOwned, IgnoredAny};
use tokio::io::BufReader;
use tokio::net::TcpStream;

use crate::config::Config;
use crate::lsp::ext::{self, LspMuxOptions, StatusResponse};
use crate::lsp::jsonrpc::{Message, Request, RequestId, Version};
use crate::lsp::transport::{LspReader, LspWriter};
use crate::lsp::{InitializationOptions, InitializeParams};
use crate::socketwrapper::Stream;

pub async fn ext_request<T>(config: &Config, method: ext::Request) -> Result<T>
where
T: DeserializeOwned,
{
let (reader, writer) = TcpStream::connect(config.connect)
let (reader, writer) = Stream::connect(&config.connect)
.await
.context("connect")?
.into_split();
Expand Down Expand Up @@ -102,7 +102,7 @@ pub async fn status(config: &Config, json: bool) -> Result<()> {
println!(" clients:");
for client in instance.clients {
println!(" - Client");
println!(" port: {}", client.port);
println!(" id: {}", client.id);
println!(" files:");
for file in client.files {
println!(" - {}", file);
Expand Down
32 changes: 16 additions & 16 deletions src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub struct Instance {
server: mpsc::Sender<Message>,

/// Data of associated clients
clients: Mutex<HashMap<u16, ClientData>>,
clients: Mutex<HashMap<usize, ClientData>>,

/// Dynamic capabilities registered by the server
dynamic_capabilities: Mutex<HashMap<String, lsp::Registration>>,
Expand Down Expand Up @@ -83,7 +83,7 @@ struct ClientData {
impl ClientData {
fn get_status(&self) -> ext::Client {
ext::Client {
port: self.client.port(),
id: self.client.id(),
files: self.files.iter().cloned().collect(),
}
}
Expand Down Expand Up @@ -146,8 +146,8 @@ impl Instance {
client,
files: HashSet::new(),
};
if clients.insert(client.port(), client).is_some() {
unreachable!("BUG: added two clients with the same port");
if clients.insert(client.id(), client).is_some() {
unreachable!("BUG: added two clients with the same ID");
}
}

Expand All @@ -157,7 +157,7 @@ impl Instance {

let mut clients = self.clients.lock().await;

let Some(client) = clients.remove(&client.port()) else {
let Some(client) = clients.remove(&client.id()) else {
// TODO This happens for example when the language server died while
// client was still connected, and the client cleanup is attempted
// with the instance being gone already. We should try notifying
Expand Down Expand Up @@ -205,7 +205,7 @@ impl Instance {
}

/// Handle `textDocument/didOpen` client notification
pub async fn open_file(&self, port: u16, params: Value) -> Result<()> {
pub async fn open_file(&self, client_id: usize, params: Value) -> Result<()> {
let params = serde_json::from_value::<lsp::DidOpenTextDocumentParams>(params)
.context("parsing params")?;
let uri = &params.text_document.uri;
Expand All @@ -222,7 +222,7 @@ impl Instance {
}

clients
.get_mut(&port)
.get_mut(&client_id)
.expect("no matching client")
.files
.insert(uri.clone());
Expand All @@ -241,14 +241,14 @@ impl Instance {
}

/// Handle `textDocument/didClose` client notification
pub async fn close_file(&self, port: u16, params: Value) -> Result<()> {
pub async fn close_file(&self, client_id: usize, params: Value) -> Result<()> {
let params = serde_json::from_value::<lsp::DidCloseTextDocumentParams>(params)
.context("parsing params")?;

let mut clients = self.clients.lock().await;

clients
.get_mut(&port)
.get_mut(&client_id)
.context("no matching client")?
.files
.remove(&params.text_document.uri);
Expand All @@ -261,7 +261,7 @@ impl Instance {
/// definitely closed files
async fn close_all_files(
&self,
clients: &HashMap<u16, ClientData>,
clients: &HashMap<usize, ClientData>,
files: Vec<String>,
) -> Result<()> {
for uri in files {
Expand Down Expand Up @@ -643,12 +643,12 @@ async fn stdout_task(instance: Arc<Instance>, mut reader: LspReader<BufReader<Ch
// Forward successful response to the right client based on the
// Request ID tag.
match res.id.untag() {
(Some(Tag::Port(port)), id) => {
(Some(Tag::ClientId(client_id)), id) => {
res.id = id;
if let Some(client) = clients.get(&port) {
if let Some(client) = clients.get(&client_id) {
let _ = client.send_message(res.into()).await;
} else {
debug!(?port, "no matching client");
debug!(?client_id, "no matching client");
}
}
(Some(Tag::Drop), _) => {
Expand All @@ -664,13 +664,13 @@ async fn stdout_task(instance: Arc<Instance>, mut reader: LspReader<BufReader<Ch
// Forward the error response to the right client based on the
// Request ID tag.
match res.id.untag() {
(Some(Tag::Port(port)), id) => {
(Some(Tag::ClientId(client_id)), id) => {
warn!(?res, "server responded with error");
res.id = id;
if let Some(client) = clients.get(&port) {
if let Some(client) = clients.get(&client_id) {
let _ = client.send_message(res.into()).await;
} else {
debug!(?port, "no matching client");
debug!(?client_id, "no matching client");
}
}
(Some(Tag::Drop), _) => {
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod client;
mod instance;
mod lsp;
mod socketwrapper;

pub mod config;
pub mod ext;
Expand Down
22 changes: 11 additions & 11 deletions src/lsp/ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use super::jsonrpc::RequestId;

/// Additional metadata inserted into LSP RequestId
pub enum Tag {
/// Request is coming from a client connected on this port
Port(u16),
/// Request is coming from a client connected with this ID
ClientId(usize),
/// Response to this request should be ignored
Drop,
/// Response to this request should be forwarded
Expand All @@ -23,7 +23,7 @@ impl RequestId {
/// Serializes the ID to a string and prepends Tag
pub fn tag(&self, tag: Tag) -> RequestId {
let tag = match tag {
Tag::Port(port) => format!("port:{port}"),
Tag::ClientId(client_id) => format!("client_id:{client_id}"),
Tag::Drop => "drop".into(),
Tag::Forward => "forward".into(),
};
Expand All @@ -45,21 +45,21 @@ impl RequestId {
})
}

fn parse_port(input: &str) -> Result<(u16, &str)> {
let (port, rest) = input.split_once(':').context("missing`:`")?;
let port = u16::from_str(port).context("invalid port number")?;
Ok((port, rest))
fn parse_client_id(input: &str) -> Result<(usize, &str)> {
let (client_id, rest) = input.split_once(':').context("missing`:`")?;
let client_id = usize::from_str(client_id).context("invalid client ID")?;
Ok((client_id, rest))
}

fn parse_tag(input: &RequestId) -> Result<(Tag, RequestId)> {
let RequestId::String(input) = input else {
bail!("tagged id must be a String found `{input:?}`");
};

if let Some(rest) = input.strip_prefix("port:") {
let (port, rest) = parse_port(rest)?;
if let Some(rest) = input.strip_prefix("client_id:") {
let (client_id, rest) = parse_client_id(rest)?;
let inner_id = parse_inner_id(rest).context("failed to parse inner ID")?;
return Ok((Tag::Port(port), inner_id));
return Ok((Tag::ClientId(client_id), inner_id));
}

if let Some(rest) = input.strip_prefix("drop:") {
Expand Down Expand Up @@ -175,7 +175,7 @@ pub struct Instance {
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "camelCase")]
pub struct Client {
pub port: u16,
pub id: usize,
pub files: Vec<String>,
}

Expand Down
Loading
Loading