diff --git a/crates/tauri-driver/src/server.rs b/crates/tauri-driver/src/server.rs index 34304f035af6..e6a3b3815674 100644 --- a/crates/tauri-driver/src/server.rs +++ b/crates/tauri-driver/src/server.rs @@ -63,39 +63,46 @@ impl TauriOptions { } async fn handle( - client: Client, + client: Client>, mut req: Request, args: Args, -) -> Result>, Error> { +) -> Result, Error> { // manipulate a new session to convert options to the native driver format - if let (&Method::POST, "/session") = (req.method(), req.uri().path()) { - let (mut parts, body) = req.into_parts(); + let new_req: Request> = + if let (&Method::POST, "/session") = (req.method(), req.uri().path()) { + let (mut parts, body) = req.into_parts(); - // get the body from the future stream and parse it as json - let body = body.collect().await?.to_bytes().to_vec(); - let json: Value = serde_json::from_slice(&body)?; + // get the body from the future stream and parse it as json + let body = body.collect().await?.to_bytes().to_vec(); + let json: Value = serde_json::from_slice(&body)?; - // manipulate the json to convert from tauri option to native driver options - let json = map_capabilities(json); + // manipulate the json to convert from tauri option to native driver options + let json = map_capabilities(json); - // serialize json and update the content-length header to be accurate - let bytes = serde_json::to_vec(&json)?; - parts.headers.insert(CONTENT_LENGTH, bytes.len().into()); + // serialize json and update the content-length header to be accurate + let bytes = serde_json::to_vec(&json)?; + parts.headers.insert(CONTENT_LENGTH, bytes.len().into()); - req = Request::from_parts(parts, bytes); - } + Request::from_parts(parts, Full::new(bytes.into())) + } else { + let (parts, body) = req.into_parts(); + + let body = body.collect().await?.to_bytes().to_vec(); + + Request::from_parts(parts, Full::new(body.into())) + }; client - .request(forward_to_native_driver(req, args)?) + .request(forward_to_native_driver(new_req, args)?) .err_into() .await } /// Transform the request to a request for the native webdriver server. fn forward_to_native_driver( - mut req: Request, + mut req: Request>, args: Args, -) -> Result, Error> { +) -> Result>, Error> { let host: Authority = { let headers = req.headers_mut(); headers.remove("host").expect("hyper request has host") @@ -190,28 +197,36 @@ pub async fn run(args: Args, mut _driver: Child) -> Result<(), Error> { // set up a http1 server that uses the service we just created let srv = async move { - let listener = TcpListener::bind(address).await?; - loop { - let (stream, _) = listener.accept().await?; - let io = TokioIo::new(stream); - - tokio::task::spawn(async move { - if let Err(err) = auto::Builder::new(TokioExecutor::new()) - .http1() - .title_case_headers(true) - .preserve_header_case(true) - .serve_connection( - io, - service_fn(|request| handle(client.clone(), request, args.clone())), - ) - .await - { - println!("Error serving connection: {:?}", err); + if let Ok(listener) = TcpListener::bind(address).await { + loop { + let client = client.clone(); + let args = args.clone(); + if let Ok((stream, _)) = listener.accept().await { + let io = TokioIo::new(stream); + + tokio::task::spawn(async move { + if let Err(err) = auto::Builder::new(TokioExecutor::new()) + .http1() + .title_case_headers(true) + .preserve_header_case(true) + .serve_connection( + io, + service_fn(|request| handle(client.clone(), request, args.clone())), + ) + .await + { + println!("Error serving connection: {:?}", err); + } + }); + } else { + println!("accept new stream fail, ignore here"); } - }); + } + } else { + println!("can not listen to address: {:?}", address); } }; - srv.await?; + srv.await; #[cfg(unix)] {