diff --git a/bin/src/cs.rs b/bin/src/cs.rs index f545903d..fd9b3723 100644 --- a/bin/src/cs.rs +++ b/bin/src/cs.rs @@ -25,6 +25,8 @@ use serde::{Deserialize, Serialize}; use std::{ffi::{c_char, c_void, CString}, fmt::Debug, process::ExitCode}; include!("cs_bindings.rs"); +use std::fs; +use std::path::Path; use std::process; use std::sync::Arc; use tokio::sync::Mutex; @@ -71,25 +73,78 @@ fn convert_to_go_slices(vec: &Vec) -> (GoSlice, Vec) { go_slices, ) } + +fn load_config(config_path: &str) -> Result> { + // 验证文件是否存在 + if !Path::new(config_path).exists() { + return Err(format!("Config file '{}' does not exist", config_path).into()); + } + + // 读取文件内容 + let config_content = fs::read_to_string(config_path) + .map_err(|e| format!("Failed to read config file '{}': {}", config_path, e))?; + + // 验证文件不为空 + if config_content.trim().is_empty() { + return Err("Config file is empty".into()); + } + + // 解析 YAML + let config: ConnectConfig = serde_yaml::from_str(&config_content) + .map_err(|e| format!("Failed to parse YAML config: {}", e))?; + + match serde_yaml::from_str::(&config_content) { + Ok(config) => println!("解析成功: {:?}", config), + Err(e) => println!("解析错误: {}", e), + } + + // 验证必要的字段 + validate_config(&config)?; + + Ok(config) +} + +fn validate_config(config: &ConnectConfig) -> Result<(), Box> { + // 配置验证 + if config.options.tcp_forward_addr.trim().is_empty() { + return Err("tcp_forward_addr cannot be empty".into()); + } + if config.options.tcp_forward_host_prefix.trim().is_empty() { + return Err("tcp_forward_host_prefix cannot be empty".into()); + } + Ok(()) +} + + pub fn run_connect(connect_args: ConnectArgs) { - let mut args = if let Some(config) = connect_args.config { - vec!["connect".to_owned(), "-config".to_owned(), config] + let mut args = if let Some(config_path) = &connect_args.config { + match load_config(config_path) { + Ok(config) => { + println!("Successfully loaded config from '{}'", config_path); + println!("Config details:"); + println!(" TCP Forward Address: {}", config.options.tcp_forward_addr); + println!(" TCP Forward Host Prefix: {}", config.options.tcp_forward_host_prefix); + config + }, + Err(e) => { + eprintln!("Error loading config: {}", e); + process::exit(1); + } + } } else { - vec!["connect".to_owned()] + println!("No config file specified, using default configuration"); + ConnectConfig::default() }; - let (args, go_str) = convert_to_go_slices(&args); info!("Run connect cmd."); let rt = tokio::runtime::Runtime::new().unwrap(); rt.block_on(async move { - let addr = "127.0.0.1:7000"; info!("Runtime started."); - let mut stream = tokio::net::TcpStream::connect(&addr).await.unwrap(); - let (connect_reader, connect_writer) = stream.into_split(); - info!("connect to client."); + let connect_reader = tokio::io::stdin(); + let connect_writer = tokio::io::stdout(); // let reader = Arc::new(Mutex::new(connect_reader)); // let writer = Arc::new(Mutex::new(connect_writer)); - if let Err(e) = process(connect_reader, connect_writer).await { - eprintln!("process p2p: {}", e); + if let Err(e) = process_connect(connect_reader, connect_writer, args).await { + eprintln!("process p2p connect: {}", e); process::exit(1); }; }); diff --git a/bin/src/peer/conn.rs b/bin/src/peer/conn.rs index 139111bd..8719069e 100644 --- a/bin/src/peer/conn.rs +++ b/bin/src/peer/conn.rs @@ -149,22 +149,22 @@ where })) } - pub async fn send_offer(self: Arc) -> Result<()> { - let pc = Arc::clone(&self.peer_connection); - let offer = pc.create_offer(None).await.context("create offer")?; - let sdp = serde_json::to_string(&offer).context("serialize answer")?; - let op = OP::OfferSDP(sdp); - write_json( - Arc::clone(&self.writer), - &serde_json::to_string(&op).context("encode op")?, - ) - .await - .context("write answer sdp to stdout")?; - pc.set_local_description(offer) - .await - .context("set local description")?; - Ok(()) - } + // pub async fn send_offer(self: Arc) -> Result<()> { + // let pc = Arc::clone(&self.peer_connection); + // let offer = pc.create_offer(None).await.context("create offer")?; + // let sdp = serde_json::to_string(&offer).context("serialize answer")?; + // let op = OP::OfferSDP(sdp); + // write_json( + // Arc::clone(&self.writer), + // &serde_json::to_string(&op).context("encode op")?, + // ) + // .await + // .context("write answer sdp to stdout")?; + // pc.set_local_description(offer) + // .await + // .context("set local description")?; + // Ok(()) + // } fn setup_data_channel(self: Arc, d: Arc) { let dc = Arc::clone(&d); diff --git a/bin/src/peer/connect.rs b/bin/src/peer/connect.rs index 84dad14e..7099ddf8 100644 --- a/bin/src/peer/connect.rs +++ b/bin/src/peer/connect.rs @@ -46,6 +46,8 @@ use reqwest::{Client, header}; use crate::peer::{read_json, write_json, LibError, OP, Config, ConnectConfig}; + +use super::ConnectOptions; pub(crate) struct ConnectPeerConnHandler { http_routes: HashMap, @@ -56,6 +58,7 @@ no_channel_id: AtomicUsize, peer_connection: Arc, timeout: u16, + options: Arc, } impl ConnectPeerConnHandler @@ -63,18 +66,12 @@ R: AsyncReadExt + Unpin + Send + 'static, W: AsyncWriteExt + Unpin + Send + 'static, { - pub async fn new(reader: R, writer: W) -> Result> { + pub async fn new(reader: R, writer: W, args: ConnectConfig) -> Result> { let reader = Arc::new(Mutex::new(reader)); let writer = Arc::new(Mutex::new(writer)); - // let json = timeout(Duration::from_secs(5), read_json(Arc::clone(&reader))) - // .await - // .context("read config json timeout")? - // .context("read config json")?; - // debug!("config json: {}", &json); - // let op = serde_json::from_str::(&json) - // .with_context(|| format!("deserialize config json failed: {}", json))?; + let tempargs = Arc::clone(&Arc::new(args.options)); let op: OP = OP::Config(Config { - stuns: vec!["stun:127.0.0.1:3478".to_owned()], + stuns: vec![tempargs.stun_addr.to_owned()], http_routes: HashMap::from([("@".to_owned(), "http://www.baidu.com".to_owned())]), ..Default::default() }); @@ -146,6 +143,7 @@ tcp_routes: config.tcp_routes, channel_count: Default::default(), no_channel_id: Default::default(), + options: tempargs, })) } @@ -198,17 +196,18 @@ Ok(body) } - pub async fn forward_data_with_server(self: Arc, yaml: &str) -> Result { - let ya = serde_yaml::from_str::(yaml)?; - let url = ya.options.tcp_forward_addr; + pub async fn forward_data_with_server(self: Arc, msg: &str) -> Result { + // let ya = serde_yaml::from_str::(yaml)?; + let options = Arc::clone(&self.options); + let url = &options.tcp_forward_addr; let method = "GET"; - let host = Some(ya.options.tcp_forward_host_prefix); + let host = Some(options.tcp_forward_host_prefix.as_str()); let headers = Some(vec![ ("Users-Agent".to_string(), "gt-connect".to_string()), ]); - let body = None; + let body = Some(msg); - let resp = self.send_http_request(&url, method, host.as_deref(), headers, body).await?; + let resp = self.send_http_request(&url, method, host, headers, body).await?; info!("Response from remote: {}", resp); Ok(true) } diff --git a/bin/src/peer/mod.rs b/bin/src/peer/mod.rs index 2a745782..111a853d 100644 --- a/bin/src/peer/mod.rs +++ b/bin/src/peer/mod.rs @@ -23,8 +23,9 @@ use log::*; use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::io; -use tokio::io::{stdin, stdout}; +use tokio::io::{stdin, stdout, AsyncBufReadExt, BufReader}; use tokio::sync::Mutex; +use tokio::sync::mpsc; mod conn; mod connect; @@ -54,16 +55,36 @@ where W: io::AsyncWriteExt + Unpin + Send + 'static, { let handler = conn::PeerConnHandler::new(reader, writer).await?; - let offer_res = handler.clone().send_offer().await; - match offer_res { - Ok(()) => { - info!("offer sended."); + handler.handle().await +} + +pub async fn process_connect(reader: R, writer: W, args: ConnectConfig) -> Result<()> +where + R: io::AsyncReadExt + Unpin + Send + 'static, + W: io::AsyncWriteExt + Unpin + Send + 'static, +{ + let handler = connect::ConnectPeerConnHandler::new(reader, writer, args).await?; + let _ = Arc::clone(&handler).send_offer(); + let (tx, mut rx) = mpsc::channel(8); + // 在一个独立的任务中读取标准输入 + tokio::spawn(async move { + let mut stdin = BufReader::new(tokio::io::stdin()).lines(); + while let Some(line) = stdin.next_line().await.unwrap() { + tx.send(line).await.unwrap(); } - Err(err) => { - error!("offer send err: {}", err); + }); + // 在主任务中处理读取到的行 + while let Some(line) = rx.recv().await { + println!("Received line: {}", line); + // 将读取的行发送给服务端转发 + let handler = Arc::clone(&handler); + match handler.forward_data_with_server(&line).await { + Ok(_) => println!("Successfully forwarded: {}", line), + Err(e) => eprintln!("Error forwarding data: {}", e), } } - handler.handle().await + let _ = Arc::clone(&handler).handle().await; + Ok(()) } #[derive(Serialize, Deserialize, Debug, Default)] @@ -91,16 +112,18 @@ pub enum OP { } #[derive(Serialize, Deserialize, Debug, Default)] -#[serde(default, rename_all = "camelCase")] +#[serde(default)] pub struct ConnectConfig { + #[serde(rename = "type")] pub typ: String, pub options: ConnectOptions, } #[derive(Serialize, Deserialize, Debug, Default)] -#[serde(default, rename_all = "camelCase")] +#[serde(default)] pub struct ConnectOptions { pub remote: String, + pub stun_addr: String, pub tcp_forward_addr: String, pub tcp_forward_host_prefix: String, }