Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
getong committed Jan 6, 2025
1 parent 8fd103f commit be9aec9
Showing 1 changed file with 51 additions and 36 deletions.
87 changes: 51 additions & 36 deletions crates/tauri-driver/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,39 +63,46 @@ impl TauriOptions {
}

async fn handle(
client: Client,
client: Client<HttpConnector, Full<Bytes>>,
mut req: Request<Incoming>,
args: Args,
) -> Result<Response<Full<Bytes>>, Error> {
) -> Result<Response<Incoming>, Error> {
// manipulate a new session to convert options to the native driver format
if let (&Method::POST, "/session") = (req.method(), req.uri().path()) {
let (mut parts, body) = req.into_parts();
let new_req: Request<Full<Bytes>> =
if let (&Method::POST, "/session") = (req.method(), req.uri().path()) {
let (mut parts, body) = req.into_parts();

// get the body from the future stream and parse it as json
let body = body.collect().await?.to_bytes().to_vec();
let json: Value = serde_json::from_slice(&body)?;
// get the body from the future stream and parse it as json
let body = body.collect().await?.to_bytes().to_vec();
let json: Value = serde_json::from_slice(&body)?;

// manipulate the json to convert from tauri option to native driver options
let json = map_capabilities(json);
// manipulate the json to convert from tauri option to native driver options
let json = map_capabilities(json);

// serialize json and update the content-length header to be accurate
let bytes = serde_json::to_vec(&json)?;
parts.headers.insert(CONTENT_LENGTH, bytes.len().into());
// serialize json and update the content-length header to be accurate
let bytes = serde_json::to_vec(&json)?;
parts.headers.insert(CONTENT_LENGTH, bytes.len().into());

req = Request::from_parts(parts, bytes);
}
Request::from_parts(parts, Full::new(bytes.into()))
} else {
let (parts, body) = req.into_parts();

let body = body.collect().await?.to_bytes().to_vec();

Request::from_parts(parts, Full::new(body.into()))
};

client
.request(forward_to_native_driver(req, args)?)
.request(forward_to_native_driver(new_req, args)?)
.err_into()
.await
}

/// Transform the request to a request for the native webdriver server.
fn forward_to_native_driver(
mut req: Request<Incoming>,
mut req: Request<Full<Bytes>>,
args: Args,
) -> Result<Request<Incoming>, Error> {
) -> Result<Request<Full<Bytes>>, Error> {
let host: Authority = {
let headers = req.headers_mut();
headers.remove("host").expect("hyper request has host")
Expand Down Expand Up @@ -190,28 +197,36 @@ pub async fn run(args: Args, mut _driver: Child) -> Result<(), Error> {

// set up a http1 server that uses the service we just created
let srv = async move {
let listener = TcpListener::bind(address).await?;
loop {
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);

tokio::task::spawn(async move {
if let Err(err) = auto::Builder::new(TokioExecutor::new())
.http1()
.title_case_headers(true)
.preserve_header_case(true)
.serve_connection(
io,
service_fn(|request| handle(client.clone(), request, args.clone())),
)
.await
{
println!("Error serving connection: {:?}", err);
if let Ok(listener) = TcpListener::bind(address).await {
loop {
let client = client.clone();
let args = args.clone();
if let Ok((stream, _)) = listener.accept().await {
let io = TokioIo::new(stream);

tokio::task::spawn(async move {
if let Err(err) = auto::Builder::new(TokioExecutor::new())
.http1()
.title_case_headers(true)
.preserve_header_case(true)
.serve_connection(
io,
service_fn(|request| handle(client.clone(), request, args.clone())),
)
.await
{
println!("Error serving connection: {:?}", err);
}
});
} else {
println!("accept new stream fail, ignore here");
}
});
}
} else {
println!("can not listen to address: {:?}", address);
}
};
srv.await?;
srv.await;

#[cfg(unix)]
{
Expand Down

0 comments on commit be9aec9

Please sign in to comment.