From 450cca83f141b263a4add02938ffeb5518cbddc2 Mon Sep 17 00:00:00 2001 From: Chrislearn Young Date: Sat, 16 Sep 2023 20:45:01 +0800 Subject: [PATCH] fix auto conn bug (#36) --- src/server/conn/auto.rs | 95 ++++++++++++++++++++++++++++++----------- 1 file changed, 69 insertions(+), 26 deletions(-) diff --git a/src/server/conn/auto.rs b/src/server/conn/auto.rs index fe5f525..042284a 100644 --- a/src/server/conn/auto.rs +++ b/src/server/conn/auto.rs @@ -1,6 +1,12 @@ //! Http1 or Http2 connection. -use crate::{common::rewind::Rewind, rt::TokioIo}; +use std::future::Future; +use std::io::{Error as IoError, ErrorKind, Result as IoResult}; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; +use std::{error::Error as StdError, marker::Unpin, time::Duration}; + use bytes::Bytes; use http::{Request, Response}; use http_body::Body; @@ -10,8 +16,10 @@ use hyper::{ server::conn::{http1, http2}, service::Service, }; -use std::{error::Error as StdError, marker::Unpin, time::Duration}; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use crate::{common::rewind::Rewind, rt::TokioIo}; type Result = std::result::Result>; @@ -57,7 +65,7 @@ impl Builder { } /// Bind a connection together with a [`Service`]. - pub async fn serve_connection(&self, mut io: I, service: S) -> Result<()> + pub async fn serve_connection(&self, io: I, service: S) -> Result<()> where S: Service, Response = Response> + Send, S::Future: Send + 'static, @@ -68,35 +76,70 @@ impl Builder { I: AsyncRead + AsyncWrite + Unpin + 'static, E: Http2ConnExec, { - enum Protocol { - H1, - H2, + let (version, io) = read_version(io).await?; + let io = TokioIo::new(io); + match version { + Version::H1 => self.http1.serve_connection(io, service).await?, + Version::H2 => self.http2.serve_connection(io, service).await?, } - let mut buf = Vec::new(); + Ok(()) + } +} +#[derive(Copy, Clone)] +enum Version { + H1, + H2, +} +async fn read_version<'a, A>(mut reader: A) -> IoResult<(Version, Rewind)> +where + A: AsyncRead + Unpin, +{ + let mut buf = [0; 24]; + let (version, buf) = ReadVersion { + reader: &mut reader, + buf: ReadBuf::new(&mut buf), + version: Version::H1, + _pin: PhantomPinned, + } + .await?; + Ok((version, Rewind::new_buffered(reader, Bytes::from(buf)))) +} +pin_project! { + struct ReadVersion<'a, A: ?Sized> { + reader: &'a mut A, + buf: ReadBuf<'a>, + version: Version, + // Make this future `!Unpin` for compatibility with async trait methods. + #[pin] + _pin: PhantomPinned, + } +} - let protocol = loop { - if buf.len() < 24 { - io.read_buf(&mut buf).await?; +impl Future for ReadVersion<'_, A> +where + A: AsyncRead + Unpin + ?Sized, +{ + type Output = IoResult<(Version, Vec)>; - let len = buf.len().min(H2_PREFACE.len()); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll)>> { + let this = self.project(); - if buf[0..len] != H2_PREFACE[0..len] { - break Protocol::H1; - } - } else { - break Protocol::H2; + while this.buf.remaining() != 0 { + if this.buf.filled() != &H2_PREFACE[0..this.buf.filled().len()] { + return Poll::Ready(Ok((*this.version, this.buf.filled().to_vec()))); + } + // if our buffer is empty, then we need to read some data to continue. + let rem = this.buf.remaining(); + ready!(Pin::new(&mut *this.reader).poll_read(cx, this.buf))?; + if this.buf.remaining() == rem { + return Err(IoError::new(ErrorKind::UnexpectedEof, "early eof")).into(); } - }; - - let io = TokioIo::new(Rewind::new_buffered(io, Bytes::from(buf))); - - match protocol { - Protocol::H1 => self.http1.serve_connection(io, service).await?, - Protocol::H2 => self.http2.serve_connection(io, service).await?, } - - Ok(()) + if this.buf.filled() == H2_PREFACE { + *this.version = Version::H2; + } + return Poll::Ready(Ok((*this.version, this.buf.filled().to_vec()))); } }