Skip to content

Commit

Permalink
implement wrappers around tcp/unix,listen/stream
Browse files Browse the repository at this point in the history
This allows the wrappers to be used in all circumstances where TcpListen,
TcpStream, tcp::OwnedReadHalf, and tcp::OwnedWriteHalf are currently
used.
  • Loading branch information
EliteTK committed Jun 3, 2024
1 parent 3f0de93 commit 1c6025a
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 11 deletions.
25 changes: 23 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ anyhow = "1.0.53"
clap = { version = "4.3.0", features = ["derive", "env"] }
directories = "4.0.1"
percent-encoding = "2.3.1"
pin-project = "1.1.5"
serde = { version = "1.0.186" }
serde_derive = { version = "1.0.186" }
serde_json = "1.0.78"
Expand Down
5 changes: 2 additions & 3 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ use anyhow::{bail, ensure, Context, Result};
use percent_encoding::percent_decode_str;
use serde_json::Value;
use tokio::io::BufReader;
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::{mpsc, Mutex};
use tokio::task;
Expand All @@ -21,10 +19,11 @@ use crate::lsp::jsonrpc::{
};
use crate::lsp::transport::{LspReader, LspWriter};
use crate::lsp::InitializeParams;
use crate::socketwrapper::{OwnedReadHalf, OwnedWriteHalf, Stream};

/// Read first client message and dispatch lsp mux commands
pub async fn process(
socket: TcpStream,
socket: Stream,
client_id: usize,
instance_map: Arc<Mutex<InstanceMap>>,
) -> Result<()> {
Expand Down
4 changes: 2 additions & 2 deletions src/ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ use std::env;
use anyhow::{bail, Context, Result};
use serde::de::{DeserializeOwned, IgnoredAny};
use tokio::io::BufReader;
use tokio::net::TcpStream;

use crate::config::Config;
use crate::lsp::ext::{self, LspMuxOptions, StatusResponse};
use crate::lsp::jsonrpc::{Message, Request, RequestId, Version};
use crate::lsp::transport::{LspReader, LspWriter};
use crate::lsp::{InitializationOptions, InitializeParams};
use crate::socketwrapper::Stream;

pub async fn ext_request<T>(config: &Config, method: ext::Request) -> Result<T>
where
T: DeserializeOwned,
{
let (reader, writer) = TcpStream::connect(config.connect)
let (reader, writer) = Stream::connect_tcp(config.connect)
.await
.context("connect")?
.into_split();
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod client;
mod instance;
mod lsp;
mod socketwrapper;

pub mod config;
pub mod ext;
Expand Down
4 changes: 2 additions & 2 deletions src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ use std::env;

use anyhow::{bail, Context as _, Result};
use tokio::io::{self, BufStream};
use tokio::net::TcpStream;

use crate::config::Config;
use crate::lsp::ext::{LspMuxOptions, Request};
use crate::lsp::jsonrpc::Message;
use crate::lsp::transport::{LspReader, LspWriter};
use crate::lsp::{InitializationOptions, InitializeParams};
use crate::socketwrapper::Stream;

pub async fn run(config: &Config, server: String, args: Vec<String>) -> Result<()> {
let cwd = env::current_dir()
Expand All @@ -23,7 +23,7 @@ pub async fn run(config: &Config, server: String, args: Vec<String>) -> Result<(
}
}

let mut stream = TcpStream::connect(config.connect)
let mut stream = Stream::connect_tcp(config.connect)
.await
.context("connect")?;
let mut stdio = BufStream::new(io::join(io::stdin(), io::stdout()));
Expand Down
6 changes: 4 additions & 2 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
use std::sync::atomic::{AtomicUsize, Ordering};

use anyhow::{Context, Result};
use tokio::net::TcpListener;
use tokio::task;
use tracing::{error, info, info_span, warn, Instrument};

use crate::client;
use crate::config::Config;
use crate::instance::InstanceMap;
use crate::socketwrapper::Listener;

pub async fn run(config: &Config) -> Result<()> {
let instance_map = InstanceMap::new(config).await;
let next_client_id = AtomicUsize::new(0);

let listener = TcpListener::bind(config.listen).await.context("listen")?;
let listener = Listener::bind_tcp(config.listen.into())
.await
.context("listen")?;
loop {
match listener.accept().await {
Ok((socket, _addr)) => {
Expand Down
201 changes: 201 additions & 0 deletions src/socketwrapper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, net};

use pin_project::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::{tcp, unix, TcpListener, TcpStream, ToSocketAddrs, UnixListener, UnixStream};

pub enum SocketAddr {
Ip(net::SocketAddr),
Unix(tokio::net::unix::SocketAddr),
}

impl From<net::SocketAddr> for SocketAddr {
fn from(val: net::SocketAddr) -> Self {
SocketAddr::Ip(val)
}
}

impl From<tokio::net::unix::SocketAddr> for SocketAddr {
fn from(val: tokio::net::unix::SocketAddr) -> Self {
SocketAddr::Unix(val)
}
}

#[pin_project(project = OwnedReadHalfProj)]
pub enum OwnedReadHalf {
Tcp(#[pin] tcp::OwnedReadHalf),
Unix(#[pin] unix::OwnedReadHalf),
}

impl AsyncRead for OwnedReadHalf {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.project() {
OwnedReadHalfProj::Tcp(tcp) => tcp.poll_read(cx, buf),
OwnedReadHalfProj::Unix(unix) => unix.poll_read(cx, buf),
}
}
}

#[pin_project(project = OwnedWriteHalfProj)]
pub enum OwnedWriteHalf {
Tcp(#[pin] tcp::OwnedWriteHalf),
Unix(#[pin] unix::OwnedWriteHalf),
}

impl AsyncWrite for OwnedWriteHalf {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
match self.project() {
OwnedWriteHalfProj::Tcp(tcp) => tcp.poll_write(cx, buf),
OwnedWriteHalfProj::Unix(unix) => unix.poll_write(cx, buf),
}
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
match self.project() {
OwnedWriteHalfProj::Tcp(tcp) => tcp.poll_write_vectored(cx, bufs),
OwnedWriteHalfProj::Unix(unix) => unix.poll_write_vectored(cx, bufs),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
match self.project() {
OwnedWriteHalfProj::Tcp(tcp) => tcp.poll_flush(cx),
OwnedWriteHalfProj::Unix(unix) => unix.poll_flush(cx),
}
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
match self.project() {
OwnedWriteHalfProj::Tcp(tcp) => tcp.poll_shutdown(cx),
OwnedWriteHalfProj::Unix(unix) => unix.poll_shutdown(cx),
}
}
}

#[pin_project(project = StreamProj)]
pub enum Stream {
Tcp(#[pin] TcpStream),
Unix(#[pin] UnixStream),
}

impl Stream {
pub async fn connect_tcp<A: ToSocketAddrs>(addr: A) -> io::Result<Stream> {
Ok(Stream::Tcp(TcpStream::connect(addr).await?))
}

pub async fn connect_unix<P: AsRef<Path>>(addr: P) -> io::Result<Stream> {
Ok(Stream::Unix(UnixStream::connect(addr).await?))
}

pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
match self {
Stream::Tcp(tcp) => {
let (read, write) = tcp.into_split();
(OwnedReadHalf::Tcp(read), OwnedWriteHalf::Tcp(write))
}
Stream::Unix(unix) => {
let (read, write) = unix.into_split();
(OwnedReadHalf::Unix(read), OwnedWriteHalf::Unix(write))
}
}
}
}

impl AsyncRead for Stream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.project() {
StreamProj::Tcp(tcp) => tcp.poll_read(cx, buf),
StreamProj::Unix(unix) => unix.poll_read(cx, buf),
}
}
}

impl AsyncWrite for Stream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
match self.project() {
StreamProj::Tcp(tcp) => tcp.poll_write(cx, buf),
StreamProj::Unix(unix) => unix.poll_write(cx, buf),
}
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
match self.project() {
StreamProj::Tcp(tcp) => tcp.poll_write_vectored(cx, bufs),
StreamProj::Unix(unix) => unix.poll_write_vectored(cx, bufs),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
match self.project() {
StreamProj::Tcp(tcp) => tcp.poll_flush(cx),
StreamProj::Unix(unix) => unix.poll_flush(cx),
}
}

fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
match self.project() {
StreamProj::Tcp(tcp) => tcp.poll_shutdown(cx),
StreamProj::Unix(unix) => unix.poll_shutdown(cx),
}
}
}

pub enum Listener {
Tcp(TcpListener),
Unix(UnixListener),
}

impl Listener {
pub async fn bind_tcp(addr: net::SocketAddr) -> io::Result<Listener> {
Ok(Listener::Tcp(TcpListener::bind(addr).await?))
}

pub fn bind_unix<T: AsRef<Path>>(addr: T) -> io::Result<Listener> {
match fs::remove_file(&addr) {
Ok(()) => (),
Err(e) if e.kind() == io::ErrorKind::NotFound => (),
Err(e) => return Err(e),
}
Ok(Listener::Unix(UnixListener::bind(addr)?))
}

pub async fn accept(&self) -> io::Result<(Stream, SocketAddr)> {
match self {
Listener::Tcp(tcp) => {
let (stream, addr) = tcp.accept().await?;
Ok((Stream::Tcp(stream), addr.into()))
}
Listener::Unix(unix) => {
let (stream, addr) = unix.accept().await?;
Ok((Stream::Unix(stream), addr.into()))
}
}
}
}

0 comments on commit 1c6025a

Please sign in to comment.